From 92ee61df0336ed45266087802368c8f7ceadb66c Mon Sep 17 00:00:00 2001 From: dandan <1033719135@qq.com> Date: Thu, 31 Oct 2024 23:15:22 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A4=87=E4=BB=BD=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/infrastructure/user.go | 149 --------------- app/provider/user/service.go | 4 +- config/development/app.yaml | 2 +- framework/command/provider.go | 175 ++++++++++++++++++ framework/contract/app.go | 6 + framework/provider/app/service.go | 8 + .../template/repository_template.go.tmpl | 80 ++++++++ 7 files changed, 273 insertions(+), 151 deletions(-) delete mode 100644 app/infrastructure/user.go create mode 100644 framework/template/repository_template.go.tmpl diff --git a/app/infrastructure/user.go b/app/infrastructure/user.go deleted file mode 100644 index 271163c..0000000 --- a/app/infrastructure/user.go +++ /dev/null @@ -1,149 +0,0 @@ -package infrastructure - -import ( - "context" - "github.com/Superdanda/hade/app/provider/database_connect" - userModule "github.com/Superdanda/hade/app/provider/user" - "github.com/Superdanda/hade/framework" - "github.com/Superdanda/hade/framework/contract" - "github.com/Superdanda/hade/framework/provider/repository" - "gorm.io/gorm" -) - -type UserRepository struct { - container framework.Container - db *gorm.DB - contract.OrmRepository[userModule.User, int64] - userModule.Repository -} - -func NewOrmUserRepositoryAndRegister(container framework.Container) { - //获取必要服务对象 - connectService := container.MustMake(database_connect.DatabaseConnectKey).(database_connect.Service) - infrastructureService := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService) - repositoryService := container.MustMake(contract.RepositoryKey).(contract.RepositoryService) - - connect := connectService.DefaultDatabaseConnect() - userOrmService := &UserRepository{container: container, db: connect} - infrastructureService.RegisterOrmRepository(userModule.UserKey, userOrmService) - - //注册通用仓储对象 - repository.RegisterRepository[userModule.User, int64](repositoryService, userModule.UserKey, userOrmService) -} - -func (u *UserRepository) SaveToDB(entity *userModule.User) error { - u.db.Save(entity) - return nil -} - -func (u *UserRepository) FindByIDFromDB(id int64) (*userModule.User, error) { - user := &userModule.User{} - u.db.Find(user, id) - return user, nil -} - -func (u *UserRepository) FindByIDsFromDB(ids []int64) ([]*userModule.User, error) { - var users []*userModule.User - // 使用 GORM 的 Where 方法查询用户 ID 在给定 ID 列表中的记录 - if err := u.db.Where("id IN ?", ids).Find(&users).Error; err != nil { - return nil, err // 如果查询出错,返回错误 - } - return users, nil // 返回查询结果和 nil 错误 -} - -func (u *UserRepository) GetPrimaryKey(entity *userModule.User) int64 { - return entity.ID -} - -func (u *UserRepository) GetBaseField() string { - return userModule.UserKey -} - -func (u *UserRepository) GetFieldQueryFunc(fieldName string) (func(value string) ([]*userModule.User, error), bool) { - switch fieldName { - case "Email": - return func(value string) ([]*userModule.User, error) { - var users []*userModule.User - // 执行查询,匹配 Email 字段 - if err := u.db.Where("email = ?", value).Find(&users).Error; err != nil { - return nil, err - } - return users, nil - }, true - case "UserName": - return func(value string) ([]*userModule.User, error) { - var users []*userModule.User - // 执行查询,匹配 UserName 字段 - if err := u.db.Where("user_name = ?", value).Find(&users).Error; err != nil { - return nil, err - } - return users, nil - }, true - default: - // 如果传入的字段名不支持,返回 nil 和 false - return nil, false - } -} - -func (u *UserRepository) GetFieldInQueryFunc(fieldName string) (func(values []string) ([]*userModule.User, error), bool) { - switch fieldName { - case "Email": - return func(values []string) ([]*userModule.User, error) { - var users []*userModule.User - // 批量查询 Email 字段匹配的用户 - if err := u.db.Where("email IN ?", values).Find(&users).Error; err != nil { - return nil, err - } - return users, nil - }, true - - case "UserName": - return func(values []string) ([]*userModule.User, error) { - var users []*userModule.User - // 批量查询 UserName 字段匹配的用户 - if err := u.db.Where("user_name IN ?", values).Find(&users).Error; err != nil { - return nil, err - } - return users, nil - }, true - - default: - return nil, false // 不支持的字段返回 false - } -} - -func (u *UserRepository) GetFieldValueFunc(fieldName string) (func(entity *userModule.User) string, bool) { - switch fieldName { - case "Email": - return func(entity *userModule.User) string { - return entity.Email - }, true - - case "UserName": - return func(entity *userModule.User) string { - return entity.UserName - }, true - - default: - return nil, false // 不支持的字段返回 false - } -} - -func (u *UserRepository) Save(ctx context.Context, user *userModule.User) error { - repositoryService := u.container.MustMake(contract.RepositoryKey).(contract.RepositoryService) - genericRepository := repositoryService.GetGenericRepositoryByKey(userModule.UserKey).(contract.GenericRepository[userModule.User, int64]) - if err := genericRepository.Save(ctx, user); err != nil { - return err - } - return nil -} - -func (u *UserRepository) FindById(ctx context.Context, id int64) (*userModule.User, error) { - repositoryService := u.container.MustMake(contract.RepositoryKey).(contract.RepositoryService) - genericRepository := repositoryService.GetGenericRepositoryByKey(userModule.UserKey).(contract.GenericRepository[userModule.User, int64]) - byID, err := genericRepository.FindByID(ctx, id) - if err != nil { - return nil, err - } - return byID, nil -} diff --git a/app/provider/user/service.go b/app/provider/user/service.go index 4aee3fc..90ff524 100644 --- a/app/provider/user/service.go +++ b/app/provider/user/service.go @@ -29,9 +29,11 @@ func (s *UserService) SaveUser(ctx context.Context, user *User) error { func NewUserService(params ...interface{}) (interface{}, error) { container := params[0].(framework.Container) + userService := &UserService{container: container} infrastructureService := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService) ormRepository := infrastructureService.GetModuleOrmRepository(UserKey).(Repository) - return &UserService{container: container, repository: ormRepository}, nil + userService.repository = ormRepository + return userService, nil } func (s *UserService) Foo() string { diff --git a/config/development/app.yaml b/config/development/app.yaml index eda46d7..ac74771 100644 --- a/config/development/app.yaml +++ b/config/development/app.yaml @@ -1,6 +1,6 @@ url: http://127.0.0.1:8066 -name: hade +name: github.com/Superdanda/hade swagger_open: true diff --git a/framework/command/provider.go b/framework/command/provider.go index 1d6ea76..cc7f104 100644 --- a/framework/command/provider.go +++ b/framework/command/provider.go @@ -20,6 +20,7 @@ import ( func initProviderCommand() *cobra.Command { providerCommand.AddCommand(providerCreateCommand) providerCommand.AddCommand(providerListCommand) + providerCommand.AddCommand(providerRepositoryCommand) return providerCommand } @@ -281,6 +282,100 @@ var providerCreateCommand = &cobra.Command{ }, } +var providerRepositoryCommand = &cobra.Command{ + Use: "repository", + Short: "创建仓储层实现", + RunE: func(c *cobra.Command, args []string) error { + container := c.GetContainer() + fmt.Println("创建一个仓储层实现") + var name string + var idType string + { + prompt := &survey.Input{ + Message: "请输入仓储层实现名称(例如:user):", + } + err := survey.AskOne(prompt, &name) + if err != nil { + return err + } + } + + name = strings.TrimSpace(name) + if name == "" { + fmt.Println("服务名称不能为空") + return nil + } + + { + prompt := &survey.Input{ + Message: "请输入仓储层存储模型的ID类型(默认为:int64):", + } + err := survey.AskOne(prompt, &idType) + if err != nil { + return err + } + } + idType = strings.TrimSpace(idType) + if idType == "" { + idType = "int64" + } + + // 检查服务是否存在 + // 这里可以添加检查逻辑,防止重复创建 + + app := container.MustMake(contract.AppKey).(contract.App) + config := container.MustMake(contract.ConfigKey).(contract.Config) + appName := config.GetAppName() + infrastructureDir := app.InfrastructureFolder() + + // 准备模板数据 + data := map[string]interface{}{ + "ModuleAlias": fmt.Sprintf("%sModule", name), + "ModulePath": fmt.Sprintf("%s/app/provider/%v", appName, name), + "StructName": strings.Title(name), + "EntityName": strings.Title(name), + "EntityKey": fmt.Sprintf("%sKey", strings.Title(name)), + "VariableName": name, + "AppName": appName, + "IDType": idType, + } + + // 定义模板函数 + funcs := template.FuncMap{ + "title": strings.Title, + "lower": strings.ToLower, + } + + // 解析模板文件 + //tmplPath := filepath.Join(app.TemplateFolder(), "repository_template.go.tmpl") + tmpl, err := template.New("repository").Funcs(funcs).Parse(repositoryTmp) + if err != nil { + return err + } + + // 确定生成文件的路径 + infrastructurePath := filepath.Join(infrastructureDir, fmt.Sprintf("%s.go", name)) + if err := os.MkdirAll(infrastructureDir, 0755); err != nil { + return err + } + + // 创建并写入文件 + file, err := os.Create(infrastructurePath) + if err != nil { + return err + } + defer file.Close() + + err = tmpl.Execute(file, data) + if err != nil { + return err + } + + fmt.Printf("成功创建服务:%s,文件位于:%s\n", name, infrastructurePath) + return nil + }, +} + func generateControllers(node *RouteNode, pathParts []string, tmpl *template.Template, data map[string]interface{}, moduleFolder string) error { // 更新路径部分 newPathParts := append(pathParts, node.Path) @@ -508,3 +603,83 @@ func Convert{{.packageName | title}}ToDTO({{.packageName}} *{{.packageName}}.{{. return &{{.packageName | title}}DTO{} } ` + +var repositoryTmp = `package infrastructure +import ( + "{{.AppName}}/app/provider/database_connect" + {{.ModuleAlias}} "{{.ModulePath}}" + "github.com/Superdanda/hade/framework" + "github.com/Superdanda/hade/framework/contract" + "github.com/Superdanda/hade/framework/provider/repository" + "gorm.io/gorm" +) + +type {{.StructName}}Repository struct { + container framework.Container + db *gorm.DB + contract.OrmRepository[{{.ModuleAlias}}.{{.EntityName}}, {{.IDType}}] + {{.ModuleAlias}}.Repository +} + +func NewOrm{{.StructName}}RepositoryAndRegister(container framework.Container) { + // 获取必要的服务对象 + connectService := container.MustMake(database_connect.DatabaseConnectKey).(database_connect.Service) + infrastructureService := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService) + repositoryService := container.MustMake(contract.RepositoryKey).(contract.RepositoryService) + + connect := connectService.DefaultDatabaseConnect() + {{.VariableName}}OrmService := &{{.StructName}}Repository{container: container, db: connect} + infrastructureService.RegisterOrmRepository({{.ModuleAlias}}.{{.EntityKey}}, {{.VariableName}}OrmService) + + // 注册通用仓储对象 + repository.RegisterRepository[{{.ModuleAlias}}.{{.EntityName}}, {{.IDType}}](repositoryService, {{.ModuleAlias}}.{{.EntityKey}}, {{.VariableName}}OrmService) +} + +func (u *{{.StructName}}Repository) SaveToDB(entity *{{.ModuleAlias}}.{{.EntityName}}) error { + return u.db.Save(entity).Error +} + +func (u *{{.StructName}}Repository) FindByIDFromDB(id {{.IDType}}) (*{{.ModuleAlias}}.{{.EntityName}}, error) { + entity := &{{.ModuleAlias}}.{{.EntityName}}{} + err := u.db.First(entity, id).Error + return entity, err +} + +func (u *{{.StructName}}Repository) FindByIDsFromDB(ids []{{.IDType}}) ([]*{{.ModuleAlias}}.{{.EntityName}}, error) { + var entities []*{{.ModuleAlias}}.{{.EntityName}} + err := u.db.Where("id IN ?", ids).Find(&entities).Error + return entities, err +} + +func (u *{{.StructName}}Repository) GetPrimaryKey(entity *{{.ModuleAlias}}.{{.EntityName}}) {{.IDType}} { + return entity.ID +} + +func (u *{{.StructName}}Repository) GetBaseField() string { + return {{.ModuleAlias}}.{{.EntityKey}} +} + +func (u *{{.StructName}}Repository) GetFieldQueryFunc(fieldName string) (func(value string) ([]*{{.ModuleAlias}}.{{.EntityName}}, error), bool) { + switch fieldName { + // 根据您的实际情况添加字段查询函数 + default: + return nil, false + } +} + +func (u *{{.StructName}}Repository) GetFieldInQueryFunc(fieldName string) (func(values []string) ([]*{{.ModuleAlias}}.{{.EntityName}}, error), bool) { + switch fieldName { + // 根据您的实际情况添加字段批量查询函数 + default: + return nil, false + } +} + +func (u *{{.StructName}}Repository) GetFieldValueFunc(fieldName string) (func(entity *{{.ModuleAlias}}.{{.EntityName}}) string, bool) { + switch fieldName { + // 根据您的实际情况添加获取字段值的函数 + default: + return nil, false + } +} +` diff --git a/framework/contract/app.go b/framework/contract/app.go index 45b9371..e075e0c 100644 --- a/framework/contract/app.go +++ b/framework/contract/app.go @@ -40,4 +40,10 @@ type App interface { AppFolder() string // DeployFolder 部署文件夹 DeployFolder() string + + // InfrastructureFolder 业务层 基础服务设施目录 + InfrastructureFolder() string + + // TemplateFolder 模板文件夹 + TemplateFolder() string } diff --git a/framework/provider/app/service.go b/framework/provider/app/service.go index 5308373..d6900f6 100644 --- a/framework/provider/app/service.go +++ b/framework/provider/app/service.go @@ -105,6 +105,14 @@ func (h HadeApp) HttpModuleFolder() string { return deployFolder } +func (h HadeApp) InfrastructureFolder() string { + return filepath.Join(h.AppFolder(), "infrastructure") +} + +func (h HadeApp) TemplateFolder() string { + return filepath.Join(h.BaseFolder(), "framework", "template") +} + func (h HadeApp) AppId() string { return h.appId } diff --git a/framework/template/repository_template.go.tmpl b/framework/template/repository_template.go.tmpl new file mode 100644 index 0000000..3911b18 --- /dev/null +++ b/framework/template/repository_template.go.tmpl @@ -0,0 +1,80 @@ +package infrastructure + +import ( + "context" + "{{.appName}}/app/provider/database_connect" + {{.ModuleAlias}} "{{.ModulePath}}" + "github.com/Superdanda/hade/framework" + "github.com/Superdanda/hade/framework/contract" + "github.com/Superdanda/hade/framework/provider/repository" + "gorm.io/gorm" +) + +type {{.StructName}}Repository struct { + container framework.Container + db *gorm.DB + contract.OrmRepository[{{.ModuleAlias}}.{{.EntityName}}, {{.IDType}}] + {{.ModuleAlias}}.Repository +} + +func NewOrm{{.StructName}}RepositoryAndRegister(container framework.Container) { + // 获取必要的服务对象 + connectService := container.MustMake(database_connect.DatabaseConnectKey).(database_connect.Service) + infrastructureService := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService) + repositoryService := container.MustMake(contract.RepositoryKey).(contract.RepositoryService) + + connect := connectService.DefaultDatabaseConnect() + {{.VariableName}}OrmService := &{{.StructName}}Repository{container: container, db: connect} + infrastructureService.RegisterOrmRepository({{.ModuleAlias}}.{{.EntityKey}}, {{.VariableName}}OrmService) + + // 注册通用仓储对象 + repository.RegisterRepository[{{.ModuleAlias}}.{{.EntityName}}, {{.IDType}}](repositoryService, {{.ModuleAlias}}.{{.EntityKey}}, {{.VariableName}}OrmService) +} + +func (u *{{.StructName}}Repository) SaveToDB(entity *{{.ModuleAlias}}.{{.EntityName}}) error { + return u.db.Save(entity).Error +} + +func (u *{{.StructName}}Repository) FindByIDFromDB(id {{.IDType}}) (*{{.ModuleAlias}}.{{.EntityName}}, error) { + entity := &{{.ModuleAlias}}.{{.EntityName}}{} + err := u.db.First(entity, id).Error + return entity, err +} + +func (u *{{.StructName}}Repository) FindByIDsFromDB(ids []{{.IDType}}) ([]*{{.ModuleAlias}}.{{.EntityName}}, error) { + var entities []*{{.ModuleAlias}}.{{.EntityName}} + err := u.db.Where("id IN ?", ids).Find(&entities).Error + return entities, err +} + +func (u *{{.StructName}}Repository) GetPrimaryKey(entity *{{.ModuleAlias}}.{{.EntityName}}) {{.IDType}} { + return entity.ID +} + +func (u *{{.StructName}}Repository) GetBaseField() string { + return {{.ModuleAlias}}.{{.EntityKey}} +} + +func (u *{{.StructName}}Repository) GetFieldQueryFunc(fieldName string) (func(value string) ([]*{{.ModuleAlias}}.{{.EntityName}}, error), bool) { + switch fieldName { + // 根据您的实际情况添加字段查询函数 + default: + return nil, false + } +} + +func (u *{{.StructName}}Repository) GetFieldInQueryFunc(fieldName string) (func(values []string) ([]*{{.ModuleAlias}}.{{.EntityName}}, error), bool) { + switch fieldName { + // 根据您的实际情况添加字段批量查询函数 + default: + return nil, false + } +} + +func (u *{{.StructName}}Repository) GetFieldValueFunc(fieldName string) (func(entity *{{.ModuleAlias}}.{{.EntityName}}) string, bool) { + switch fieldName { + // 根据您的实际情况添加获取字段值的函数 + default: + return nil, false + } +}