备份代码

This commit is contained in:
dandan 2024-10-31 23:15:22 +08:00
parent bf525906e4
commit 92ee61df03
7 changed files with 273 additions and 151 deletions

View File

@ -1,149 +0,0 @@
package infrastructure
import (
"context"
"github.com/Superdanda/hade/app/provider/database_connect"
userModule "github.com/Superdanda/hade/app/provider/user"
"github.com/Superdanda/hade/framework"
"github.com/Superdanda/hade/framework/contract"
"github.com/Superdanda/hade/framework/provider/repository"
"gorm.io/gorm"
)
type UserRepository struct {
container framework.Container
db *gorm.DB
contract.OrmRepository[userModule.User, int64]
userModule.Repository
}
func NewOrmUserRepositoryAndRegister(container framework.Container) {
//获取必要服务对象
connectService := container.MustMake(database_connect.DatabaseConnectKey).(database_connect.Service)
infrastructureService := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService)
repositoryService := container.MustMake(contract.RepositoryKey).(contract.RepositoryService)
connect := connectService.DefaultDatabaseConnect()
userOrmService := &UserRepository{container: container, db: connect}
infrastructureService.RegisterOrmRepository(userModule.UserKey, userOrmService)
//注册通用仓储对象
repository.RegisterRepository[userModule.User, int64](repositoryService, userModule.UserKey, userOrmService)
}
func (u *UserRepository) SaveToDB(entity *userModule.User) error {
u.db.Save(entity)
return nil
}
func (u *UserRepository) FindByIDFromDB(id int64) (*userModule.User, error) {
user := &userModule.User{}
u.db.Find(user, id)
return user, nil
}
func (u *UserRepository) FindByIDsFromDB(ids []int64) ([]*userModule.User, error) {
var users []*userModule.User
// 使用 GORM 的 Where 方法查询用户 ID 在给定 ID 列表中的记录
if err := u.db.Where("id IN ?", ids).Find(&users).Error; err != nil {
return nil, err // 如果查询出错,返回错误
}
return users, nil // 返回查询结果和 nil 错误
}
func (u *UserRepository) GetPrimaryKey(entity *userModule.User) int64 {
return entity.ID
}
func (u *UserRepository) GetBaseField() string {
return userModule.UserKey
}
func (u *UserRepository) GetFieldQueryFunc(fieldName string) (func(value string) ([]*userModule.User, error), bool) {
switch fieldName {
case "Email":
return func(value string) ([]*userModule.User, error) {
var users []*userModule.User
// 执行查询,匹配 Email 字段
if err := u.db.Where("email = ?", value).Find(&users).Error; err != nil {
return nil, err
}
return users, nil
}, true
case "UserName":
return func(value string) ([]*userModule.User, error) {
var users []*userModule.User
// 执行查询,匹配 UserName 字段
if err := u.db.Where("user_name = ?", value).Find(&users).Error; err != nil {
return nil, err
}
return users, nil
}, true
default:
// 如果传入的字段名不支持,返回 nil 和 false
return nil, false
}
}
func (u *UserRepository) GetFieldInQueryFunc(fieldName string) (func(values []string) ([]*userModule.User, error), bool) {
switch fieldName {
case "Email":
return func(values []string) ([]*userModule.User, error) {
var users []*userModule.User
// 批量查询 Email 字段匹配的用户
if err := u.db.Where("email IN ?", values).Find(&users).Error; err != nil {
return nil, err
}
return users, nil
}, true
case "UserName":
return func(values []string) ([]*userModule.User, error) {
var users []*userModule.User
// 批量查询 UserName 字段匹配的用户
if err := u.db.Where("user_name IN ?", values).Find(&users).Error; err != nil {
return nil, err
}
return users, nil
}, true
default:
return nil, false // 不支持的字段返回 false
}
}
func (u *UserRepository) GetFieldValueFunc(fieldName string) (func(entity *userModule.User) string, bool) {
switch fieldName {
case "Email":
return func(entity *userModule.User) string {
return entity.Email
}, true
case "UserName":
return func(entity *userModule.User) string {
return entity.UserName
}, true
default:
return nil, false // 不支持的字段返回 false
}
}
func (u *UserRepository) Save(ctx context.Context, user *userModule.User) error {
repositoryService := u.container.MustMake(contract.RepositoryKey).(contract.RepositoryService)
genericRepository := repositoryService.GetGenericRepositoryByKey(userModule.UserKey).(contract.GenericRepository[userModule.User, int64])
if err := genericRepository.Save(ctx, user); err != nil {
return err
}
return nil
}
func (u *UserRepository) FindById(ctx context.Context, id int64) (*userModule.User, error) {
repositoryService := u.container.MustMake(contract.RepositoryKey).(contract.RepositoryService)
genericRepository := repositoryService.GetGenericRepositoryByKey(userModule.UserKey).(contract.GenericRepository[userModule.User, int64])
byID, err := genericRepository.FindByID(ctx, id)
if err != nil {
return nil, err
}
return byID, nil
}

View File

@ -29,9 +29,11 @@ func (s *UserService) SaveUser(ctx context.Context, user *User) error {
func NewUserService(params ...interface{}) (interface{}, error) { func NewUserService(params ...interface{}) (interface{}, error) {
container := params[0].(framework.Container) container := params[0].(framework.Container)
userService := &UserService{container: container}
infrastructureService := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService) infrastructureService := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService)
ormRepository := infrastructureService.GetModuleOrmRepository(UserKey).(Repository) ormRepository := infrastructureService.GetModuleOrmRepository(UserKey).(Repository)
return &UserService{container: container, repository: ormRepository}, nil userService.repository = ormRepository
return userService, nil
} }
func (s *UserService) Foo() string { func (s *UserService) Foo() string {

View File

@ -1,6 +1,6 @@
url: http://127.0.0.1:8066 url: http://127.0.0.1:8066
name: hade name: github.com/Superdanda/hade
swagger_open: true swagger_open: true

View File

@ -20,6 +20,7 @@ import (
func initProviderCommand() *cobra.Command { func initProviderCommand() *cobra.Command {
providerCommand.AddCommand(providerCreateCommand) providerCommand.AddCommand(providerCreateCommand)
providerCommand.AddCommand(providerListCommand) providerCommand.AddCommand(providerListCommand)
providerCommand.AddCommand(providerRepositoryCommand)
return providerCommand return providerCommand
} }
@ -281,6 +282,100 @@ var providerCreateCommand = &cobra.Command{
}, },
} }
var providerRepositoryCommand = &cobra.Command{
Use: "repository",
Short: "创建仓储层实现",
RunE: func(c *cobra.Command, args []string) error {
container := c.GetContainer()
fmt.Println("创建一个仓储层实现")
var name string
var idType string
{
prompt := &survey.Input{
Message: "请输入仓储层实现名称例如user",
}
err := survey.AskOne(prompt, &name)
if err != nil {
return err
}
}
name = strings.TrimSpace(name)
if name == "" {
fmt.Println("服务名称不能为空")
return nil
}
{
prompt := &survey.Input{
Message: "请输入仓储层存储模型的ID类型默认为int64",
}
err := survey.AskOne(prompt, &idType)
if err != nil {
return err
}
}
idType = strings.TrimSpace(idType)
if idType == "" {
idType = "int64"
}
// 检查服务是否存在
// 这里可以添加检查逻辑,防止重复创建
app := container.MustMake(contract.AppKey).(contract.App)
config := container.MustMake(contract.ConfigKey).(contract.Config)
appName := config.GetAppName()
infrastructureDir := app.InfrastructureFolder()
// 准备模板数据
data := map[string]interface{}{
"ModuleAlias": fmt.Sprintf("%sModule", name),
"ModulePath": fmt.Sprintf("%s/app/provider/%v", appName, name),
"StructName": strings.Title(name),
"EntityName": strings.Title(name),
"EntityKey": fmt.Sprintf("%sKey", strings.Title(name)),
"VariableName": name,
"AppName": appName,
"IDType": idType,
}
// 定义模板函数
funcs := template.FuncMap{
"title": strings.Title,
"lower": strings.ToLower,
}
// 解析模板文件
//tmplPath := filepath.Join(app.TemplateFolder(), "repository_template.go.tmpl")
tmpl, err := template.New("repository").Funcs(funcs).Parse(repositoryTmp)
if err != nil {
return err
}
// 确定生成文件的路径
infrastructurePath := filepath.Join(infrastructureDir, fmt.Sprintf("%s.go", name))
if err := os.MkdirAll(infrastructureDir, 0755); err != nil {
return err
}
// 创建并写入文件
file, err := os.Create(infrastructurePath)
if err != nil {
return err
}
defer file.Close()
err = tmpl.Execute(file, data)
if err != nil {
return err
}
fmt.Printf("成功创建服务:%s文件位于%s\n", name, infrastructurePath)
return nil
},
}
func generateControllers(node *RouteNode, pathParts []string, tmpl *template.Template, data map[string]interface{}, moduleFolder string) error { func generateControllers(node *RouteNode, pathParts []string, tmpl *template.Template, data map[string]interface{}, moduleFolder string) error {
// 更新路径部分 // 更新路径部分
newPathParts := append(pathParts, node.Path) newPathParts := append(pathParts, node.Path)
@ -508,3 +603,83 @@ func Convert{{.packageName | title}}ToDTO({{.packageName}} *{{.packageName}}.{{.
return &{{.packageName | title}}DTO{} return &{{.packageName | title}}DTO{}
} }
` `
var repositoryTmp = `package infrastructure
import (
"{{.AppName}}/app/provider/database_connect"
{{.ModuleAlias}} "{{.ModulePath}}"
"github.com/Superdanda/hade/framework"
"github.com/Superdanda/hade/framework/contract"
"github.com/Superdanda/hade/framework/provider/repository"
"gorm.io/gorm"
)
type {{.StructName}}Repository struct {
container framework.Container
db *gorm.DB
contract.OrmRepository[{{.ModuleAlias}}.{{.EntityName}}, {{.IDType}}]
{{.ModuleAlias}}.Repository
}
func NewOrm{{.StructName}}RepositoryAndRegister(container framework.Container) {
// 获取必要的服务对象
connectService := container.MustMake(database_connect.DatabaseConnectKey).(database_connect.Service)
infrastructureService := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService)
repositoryService := container.MustMake(contract.RepositoryKey).(contract.RepositoryService)
connect := connectService.DefaultDatabaseConnect()
{{.VariableName}}OrmService := &{{.StructName}}Repository{container: container, db: connect}
infrastructureService.RegisterOrmRepository({{.ModuleAlias}}.{{.EntityKey}}, {{.VariableName}}OrmService)
// 注册通用仓储对象
repository.RegisterRepository[{{.ModuleAlias}}.{{.EntityName}}, {{.IDType}}](repositoryService, {{.ModuleAlias}}.{{.EntityKey}}, {{.VariableName}}OrmService)
}
func (u *{{.StructName}}Repository) SaveToDB(entity *{{.ModuleAlias}}.{{.EntityName}}) error {
return u.db.Save(entity).Error
}
func (u *{{.StructName}}Repository) FindByIDFromDB(id {{.IDType}}) (*{{.ModuleAlias}}.{{.EntityName}}, error) {
entity := &{{.ModuleAlias}}.{{.EntityName}}{}
err := u.db.First(entity, id).Error
return entity, err
}
func (u *{{.StructName}}Repository) FindByIDsFromDB(ids []{{.IDType}}) ([]*{{.ModuleAlias}}.{{.EntityName}}, error) {
var entities []*{{.ModuleAlias}}.{{.EntityName}}
err := u.db.Where("id IN ?", ids).Find(&entities).Error
return entities, err
}
func (u *{{.StructName}}Repository) GetPrimaryKey(entity *{{.ModuleAlias}}.{{.EntityName}}) {{.IDType}} {
return entity.ID
}
func (u *{{.StructName}}Repository) GetBaseField() string {
return {{.ModuleAlias}}.{{.EntityKey}}
}
func (u *{{.StructName}}Repository) GetFieldQueryFunc(fieldName string) (func(value string) ([]*{{.ModuleAlias}}.{{.EntityName}}, error), bool) {
switch fieldName {
// 根据您的实际情况添加字段查询函数
default:
return nil, false
}
}
func (u *{{.StructName}}Repository) GetFieldInQueryFunc(fieldName string) (func(values []string) ([]*{{.ModuleAlias}}.{{.EntityName}}, error), bool) {
switch fieldName {
// 根据您的实际情况添加字段批量查询函数
default:
return nil, false
}
}
func (u *{{.StructName}}Repository) GetFieldValueFunc(fieldName string) (func(entity *{{.ModuleAlias}}.{{.EntityName}}) string, bool) {
switch fieldName {
// 根据您的实际情况添加获取字段值的函数
default:
return nil, false
}
}
`

View File

@ -40,4 +40,10 @@ type App interface {
AppFolder() string AppFolder() string
// DeployFolder 部署文件夹 // DeployFolder 部署文件夹
DeployFolder() string DeployFolder() string
// InfrastructureFolder 业务层 基础服务设施目录
InfrastructureFolder() string
// TemplateFolder 模板文件夹
TemplateFolder() string
} }

View File

@ -105,6 +105,14 @@ func (h HadeApp) HttpModuleFolder() string {
return deployFolder return deployFolder
} }
func (h HadeApp) InfrastructureFolder() string {
return filepath.Join(h.AppFolder(), "infrastructure")
}
func (h HadeApp) TemplateFolder() string {
return filepath.Join(h.BaseFolder(), "framework", "template")
}
func (h HadeApp) AppId() string { func (h HadeApp) AppId() string {
return h.appId return h.appId
} }

View File

@ -0,0 +1,80 @@
package infrastructure
import (
"context"
"{{.appName}}/app/provider/database_connect"
{{.ModuleAlias}} "{{.ModulePath}}"
"github.com/Superdanda/hade/framework"
"github.com/Superdanda/hade/framework/contract"
"github.com/Superdanda/hade/framework/provider/repository"
"gorm.io/gorm"
)
type {{.StructName}}Repository struct {
container framework.Container
db *gorm.DB
contract.OrmRepository[{{.ModuleAlias}}.{{.EntityName}}, {{.IDType}}]
{{.ModuleAlias}}.Repository
}
func NewOrm{{.StructName}}RepositoryAndRegister(container framework.Container) {
// 获取必要的服务对象
connectService := container.MustMake(database_connect.DatabaseConnectKey).(database_connect.Service)
infrastructureService := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService)
repositoryService := container.MustMake(contract.RepositoryKey).(contract.RepositoryService)
connect := connectService.DefaultDatabaseConnect()
{{.VariableName}}OrmService := &{{.StructName}}Repository{container: container, db: connect}
infrastructureService.RegisterOrmRepository({{.ModuleAlias}}.{{.EntityKey}}, {{.VariableName}}OrmService)
// 注册通用仓储对象
repository.RegisterRepository[{{.ModuleAlias}}.{{.EntityName}}, {{.IDType}}](repositoryService, {{.ModuleAlias}}.{{.EntityKey}}, {{.VariableName}}OrmService)
}
func (u *{{.StructName}}Repository) SaveToDB(entity *{{.ModuleAlias}}.{{.EntityName}}) error {
return u.db.Save(entity).Error
}
func (u *{{.StructName}}Repository) FindByIDFromDB(id {{.IDType}}) (*{{.ModuleAlias}}.{{.EntityName}}, error) {
entity := &{{.ModuleAlias}}.{{.EntityName}}{}
err := u.db.First(entity, id).Error
return entity, err
}
func (u *{{.StructName}}Repository) FindByIDsFromDB(ids []{{.IDType}}) ([]*{{.ModuleAlias}}.{{.EntityName}}, error) {
var entities []*{{.ModuleAlias}}.{{.EntityName}}
err := u.db.Where("id IN ?", ids).Find(&entities).Error
return entities, err
}
func (u *{{.StructName}}Repository) GetPrimaryKey(entity *{{.ModuleAlias}}.{{.EntityName}}) {{.IDType}} {
return entity.ID
}
func (u *{{.StructName}}Repository) GetBaseField() string {
return {{.ModuleAlias}}.{{.EntityKey}}
}
func (u *{{.StructName}}Repository) GetFieldQueryFunc(fieldName string) (func(value string) ([]*{{.ModuleAlias}}.{{.EntityName}}, error), bool) {
switch fieldName {
// 根据您的实际情况添加字段查询函数
default:
return nil, false
}
}
func (u *{{.StructName}}Repository) GetFieldInQueryFunc(fieldName string) (func(values []string) ([]*{{.ModuleAlias}}.{{.EntityName}}, error), bool) {
switch fieldName {
// 根据您的实际情况添加字段批量查询函数
default:
return nil, false
}
}
func (u *{{.StructName}}Repository) GetFieldValueFunc(fieldName string) (func(entity *{{.ModuleAlias}}.{{.EntityName}}) string, bool) {
switch fieldName {
// 根据您的实际情况添加获取字段值的函数
default:
return nil, false
}
}