备份代码
This commit is contained in:
parent
bf525906e4
commit
92ee61df03
|
@ -1,149 +0,0 @@
|
|||
package infrastructure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/Superdanda/hade/app/provider/database_connect"
|
||||
userModule "github.com/Superdanda/hade/app/provider/user"
|
||||
"github.com/Superdanda/hade/framework"
|
||||
"github.com/Superdanda/hade/framework/contract"
|
||||
"github.com/Superdanda/hade/framework/provider/repository"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserRepository struct {
|
||||
container framework.Container
|
||||
db *gorm.DB
|
||||
contract.OrmRepository[userModule.User, int64]
|
||||
userModule.Repository
|
||||
}
|
||||
|
||||
func NewOrmUserRepositoryAndRegister(container framework.Container) {
|
||||
//获取必要服务对象
|
||||
connectService := container.MustMake(database_connect.DatabaseConnectKey).(database_connect.Service)
|
||||
infrastructureService := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService)
|
||||
repositoryService := container.MustMake(contract.RepositoryKey).(contract.RepositoryService)
|
||||
|
||||
connect := connectService.DefaultDatabaseConnect()
|
||||
userOrmService := &UserRepository{container: container, db: connect}
|
||||
infrastructureService.RegisterOrmRepository(userModule.UserKey, userOrmService)
|
||||
|
||||
//注册通用仓储对象
|
||||
repository.RegisterRepository[userModule.User, int64](repositoryService, userModule.UserKey, userOrmService)
|
||||
}
|
||||
|
||||
func (u *UserRepository) SaveToDB(entity *userModule.User) error {
|
||||
u.db.Save(entity)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserRepository) FindByIDFromDB(id int64) (*userModule.User, error) {
|
||||
user := &userModule.User{}
|
||||
u.db.Find(user, id)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (u *UserRepository) FindByIDsFromDB(ids []int64) ([]*userModule.User, error) {
|
||||
var users []*userModule.User
|
||||
// 使用 GORM 的 Where 方法查询用户 ID 在给定 ID 列表中的记录
|
||||
if err := u.db.Where("id IN ?", ids).Find(&users).Error; err != nil {
|
||||
return nil, err // 如果查询出错,返回错误
|
||||
}
|
||||
return users, nil // 返回查询结果和 nil 错误
|
||||
}
|
||||
|
||||
func (u *UserRepository) GetPrimaryKey(entity *userModule.User) int64 {
|
||||
return entity.ID
|
||||
}
|
||||
|
||||
func (u *UserRepository) GetBaseField() string {
|
||||
return userModule.UserKey
|
||||
}
|
||||
|
||||
func (u *UserRepository) GetFieldQueryFunc(fieldName string) (func(value string) ([]*userModule.User, error), bool) {
|
||||
switch fieldName {
|
||||
case "Email":
|
||||
return func(value string) ([]*userModule.User, error) {
|
||||
var users []*userModule.User
|
||||
// 执行查询,匹配 Email 字段
|
||||
if err := u.db.Where("email = ?", value).Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}, true
|
||||
case "UserName":
|
||||
return func(value string) ([]*userModule.User, error) {
|
||||
var users []*userModule.User
|
||||
// 执行查询,匹配 UserName 字段
|
||||
if err := u.db.Where("user_name = ?", value).Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}, true
|
||||
default:
|
||||
// 如果传入的字段名不支持,返回 nil 和 false
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UserRepository) GetFieldInQueryFunc(fieldName string) (func(values []string) ([]*userModule.User, error), bool) {
|
||||
switch fieldName {
|
||||
case "Email":
|
||||
return func(values []string) ([]*userModule.User, error) {
|
||||
var users []*userModule.User
|
||||
// 批量查询 Email 字段匹配的用户
|
||||
if err := u.db.Where("email IN ?", values).Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}, true
|
||||
|
||||
case "UserName":
|
||||
return func(values []string) ([]*userModule.User, error) {
|
||||
var users []*userModule.User
|
||||
// 批量查询 UserName 字段匹配的用户
|
||||
if err := u.db.Where("user_name IN ?", values).Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}, true
|
||||
|
||||
default:
|
||||
return nil, false // 不支持的字段返回 false
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UserRepository) GetFieldValueFunc(fieldName string) (func(entity *userModule.User) string, bool) {
|
||||
switch fieldName {
|
||||
case "Email":
|
||||
return func(entity *userModule.User) string {
|
||||
return entity.Email
|
||||
}, true
|
||||
|
||||
case "UserName":
|
||||
return func(entity *userModule.User) string {
|
||||
return entity.UserName
|
||||
}, true
|
||||
|
||||
default:
|
||||
return nil, false // 不支持的字段返回 false
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UserRepository) Save(ctx context.Context, user *userModule.User) error {
|
||||
repositoryService := u.container.MustMake(contract.RepositoryKey).(contract.RepositoryService)
|
||||
genericRepository := repositoryService.GetGenericRepositoryByKey(userModule.UserKey).(contract.GenericRepository[userModule.User, int64])
|
||||
if err := genericRepository.Save(ctx, user); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserRepository) FindById(ctx context.Context, id int64) (*userModule.User, error) {
|
||||
repositoryService := u.container.MustMake(contract.RepositoryKey).(contract.RepositoryService)
|
||||
genericRepository := repositoryService.GetGenericRepositoryByKey(userModule.UserKey).(contract.GenericRepository[userModule.User, int64])
|
||||
byID, err := genericRepository.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return byID, nil
|
||||
}
|
|
@ -29,9 +29,11 @@ func (s *UserService) SaveUser(ctx context.Context, user *User) error {
|
|||
|
||||
func NewUserService(params ...interface{}) (interface{}, error) {
|
||||
container := params[0].(framework.Container)
|
||||
userService := &UserService{container: container}
|
||||
infrastructureService := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService)
|
||||
ormRepository := infrastructureService.GetModuleOrmRepository(UserKey).(Repository)
|
||||
return &UserService{container: container, repository: ormRepository}, nil
|
||||
userService.repository = ormRepository
|
||||
return userService, nil
|
||||
}
|
||||
|
||||
func (s *UserService) Foo() string {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
url: http://127.0.0.1:8066
|
||||
|
||||
name: hade
|
||||
name: github.com/Superdanda/hade
|
||||
|
||||
swagger_open: true
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
func initProviderCommand() *cobra.Command {
|
||||
providerCommand.AddCommand(providerCreateCommand)
|
||||
providerCommand.AddCommand(providerListCommand)
|
||||
providerCommand.AddCommand(providerRepositoryCommand)
|
||||
return providerCommand
|
||||
}
|
||||
|
||||
|
@ -281,6 +282,100 @@ var providerCreateCommand = &cobra.Command{
|
|||
},
|
||||
}
|
||||
|
||||
var providerRepositoryCommand = &cobra.Command{
|
||||
Use: "repository",
|
||||
Short: "创建仓储层实现",
|
||||
RunE: func(c *cobra.Command, args []string) error {
|
||||
container := c.GetContainer()
|
||||
fmt.Println("创建一个仓储层实现")
|
||||
var name string
|
||||
var idType string
|
||||
{
|
||||
prompt := &survey.Input{
|
||||
Message: "请输入仓储层实现名称(例如:user):",
|
||||
}
|
||||
err := survey.AskOne(prompt, &name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
fmt.Println("服务名称不能为空")
|
||||
return nil
|
||||
}
|
||||
|
||||
{
|
||||
prompt := &survey.Input{
|
||||
Message: "请输入仓储层存储模型的ID类型(默认为:int64):",
|
||||
}
|
||||
err := survey.AskOne(prompt, &idType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
idType = strings.TrimSpace(idType)
|
||||
if idType == "" {
|
||||
idType = "int64"
|
||||
}
|
||||
|
||||
// 检查服务是否存在
|
||||
// 这里可以添加检查逻辑,防止重复创建
|
||||
|
||||
app := container.MustMake(contract.AppKey).(contract.App)
|
||||
config := container.MustMake(contract.ConfigKey).(contract.Config)
|
||||
appName := config.GetAppName()
|
||||
infrastructureDir := app.InfrastructureFolder()
|
||||
|
||||
// 准备模板数据
|
||||
data := map[string]interface{}{
|
||||
"ModuleAlias": fmt.Sprintf("%sModule", name),
|
||||
"ModulePath": fmt.Sprintf("%s/app/provider/%v", appName, name),
|
||||
"StructName": strings.Title(name),
|
||||
"EntityName": strings.Title(name),
|
||||
"EntityKey": fmt.Sprintf("%sKey", strings.Title(name)),
|
||||
"VariableName": name,
|
||||
"AppName": appName,
|
||||
"IDType": idType,
|
||||
}
|
||||
|
||||
// 定义模板函数
|
||||
funcs := template.FuncMap{
|
||||
"title": strings.Title,
|
||||
"lower": strings.ToLower,
|
||||
}
|
||||
|
||||
// 解析模板文件
|
||||
//tmplPath := filepath.Join(app.TemplateFolder(), "repository_template.go.tmpl")
|
||||
tmpl, err := template.New("repository").Funcs(funcs).Parse(repositoryTmp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 确定生成文件的路径
|
||||
infrastructurePath := filepath.Join(infrastructureDir, fmt.Sprintf("%s.go", name))
|
||||
if err := os.MkdirAll(infrastructureDir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建并写入文件
|
||||
file, err := os.Create(infrastructurePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
err = tmpl.Execute(file, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("成功创建服务:%s,文件位于:%s\n", name, infrastructurePath)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func generateControllers(node *RouteNode, pathParts []string, tmpl *template.Template, data map[string]interface{}, moduleFolder string) error {
|
||||
// 更新路径部分
|
||||
newPathParts := append(pathParts, node.Path)
|
||||
|
@ -508,3 +603,83 @@ func Convert{{.packageName | title}}ToDTO({{.packageName}} *{{.packageName}}.{{.
|
|||
return &{{.packageName | title}}DTO{}
|
||||
}
|
||||
`
|
||||
|
||||
var repositoryTmp = `package infrastructure
|
||||
import (
|
||||
"{{.AppName}}/app/provider/database_connect"
|
||||
{{.ModuleAlias}} "{{.ModulePath}}"
|
||||
"github.com/Superdanda/hade/framework"
|
||||
"github.com/Superdanda/hade/framework/contract"
|
||||
"github.com/Superdanda/hade/framework/provider/repository"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type {{.StructName}}Repository struct {
|
||||
container framework.Container
|
||||
db *gorm.DB
|
||||
contract.OrmRepository[{{.ModuleAlias}}.{{.EntityName}}, {{.IDType}}]
|
||||
{{.ModuleAlias}}.Repository
|
||||
}
|
||||
|
||||
func NewOrm{{.StructName}}RepositoryAndRegister(container framework.Container) {
|
||||
// 获取必要的服务对象
|
||||
connectService := container.MustMake(database_connect.DatabaseConnectKey).(database_connect.Service)
|
||||
infrastructureService := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService)
|
||||
repositoryService := container.MustMake(contract.RepositoryKey).(contract.RepositoryService)
|
||||
|
||||
connect := connectService.DefaultDatabaseConnect()
|
||||
{{.VariableName}}OrmService := &{{.StructName}}Repository{container: container, db: connect}
|
||||
infrastructureService.RegisterOrmRepository({{.ModuleAlias}}.{{.EntityKey}}, {{.VariableName}}OrmService)
|
||||
|
||||
// 注册通用仓储对象
|
||||
repository.RegisterRepository[{{.ModuleAlias}}.{{.EntityName}}, {{.IDType}}](repositoryService, {{.ModuleAlias}}.{{.EntityKey}}, {{.VariableName}}OrmService)
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) SaveToDB(entity *{{.ModuleAlias}}.{{.EntityName}}) error {
|
||||
return u.db.Save(entity).Error
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) FindByIDFromDB(id {{.IDType}}) (*{{.ModuleAlias}}.{{.EntityName}}, error) {
|
||||
entity := &{{.ModuleAlias}}.{{.EntityName}}{}
|
||||
err := u.db.First(entity, id).Error
|
||||
return entity, err
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) FindByIDsFromDB(ids []{{.IDType}}) ([]*{{.ModuleAlias}}.{{.EntityName}}, error) {
|
||||
var entities []*{{.ModuleAlias}}.{{.EntityName}}
|
||||
err := u.db.Where("id IN ?", ids).Find(&entities).Error
|
||||
return entities, err
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) GetPrimaryKey(entity *{{.ModuleAlias}}.{{.EntityName}}) {{.IDType}} {
|
||||
return entity.ID
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) GetBaseField() string {
|
||||
return {{.ModuleAlias}}.{{.EntityKey}}
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) GetFieldQueryFunc(fieldName string) (func(value string) ([]*{{.ModuleAlias}}.{{.EntityName}}, error), bool) {
|
||||
switch fieldName {
|
||||
// 根据您的实际情况添加字段查询函数
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) GetFieldInQueryFunc(fieldName string) (func(values []string) ([]*{{.ModuleAlias}}.{{.EntityName}}, error), bool) {
|
||||
switch fieldName {
|
||||
// 根据您的实际情况添加字段批量查询函数
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) GetFieldValueFunc(fieldName string) (func(entity *{{.ModuleAlias}}.{{.EntityName}}) string, bool) {
|
||||
switch fieldName {
|
||||
// 根据您的实际情况添加获取字段值的函数
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
`
|
||||
|
|
|
@ -40,4 +40,10 @@ type App interface {
|
|||
AppFolder() string
|
||||
// DeployFolder 部署文件夹
|
||||
DeployFolder() string
|
||||
|
||||
// InfrastructureFolder 业务层 基础服务设施目录
|
||||
InfrastructureFolder() string
|
||||
|
||||
// TemplateFolder 模板文件夹
|
||||
TemplateFolder() string
|
||||
}
|
||||
|
|
|
@ -105,6 +105,14 @@ func (h HadeApp) HttpModuleFolder() string {
|
|||
return deployFolder
|
||||
}
|
||||
|
||||
func (h HadeApp) InfrastructureFolder() string {
|
||||
return filepath.Join(h.AppFolder(), "infrastructure")
|
||||
}
|
||||
|
||||
func (h HadeApp) TemplateFolder() string {
|
||||
return filepath.Join(h.BaseFolder(), "framework", "template")
|
||||
}
|
||||
|
||||
func (h HadeApp) AppId() string {
|
||||
return h.appId
|
||||
}
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
package infrastructure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"{{.appName}}/app/provider/database_connect"
|
||||
{{.ModuleAlias}} "{{.ModulePath}}"
|
||||
"github.com/Superdanda/hade/framework"
|
||||
"github.com/Superdanda/hade/framework/contract"
|
||||
"github.com/Superdanda/hade/framework/provider/repository"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type {{.StructName}}Repository struct {
|
||||
container framework.Container
|
||||
db *gorm.DB
|
||||
contract.OrmRepository[{{.ModuleAlias}}.{{.EntityName}}, {{.IDType}}]
|
||||
{{.ModuleAlias}}.Repository
|
||||
}
|
||||
|
||||
func NewOrm{{.StructName}}RepositoryAndRegister(container framework.Container) {
|
||||
// 获取必要的服务对象
|
||||
connectService := container.MustMake(database_connect.DatabaseConnectKey).(database_connect.Service)
|
||||
infrastructureService := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService)
|
||||
repositoryService := container.MustMake(contract.RepositoryKey).(contract.RepositoryService)
|
||||
|
||||
connect := connectService.DefaultDatabaseConnect()
|
||||
{{.VariableName}}OrmService := &{{.StructName}}Repository{container: container, db: connect}
|
||||
infrastructureService.RegisterOrmRepository({{.ModuleAlias}}.{{.EntityKey}}, {{.VariableName}}OrmService)
|
||||
|
||||
// 注册通用仓储对象
|
||||
repository.RegisterRepository[{{.ModuleAlias}}.{{.EntityName}}, {{.IDType}}](repositoryService, {{.ModuleAlias}}.{{.EntityKey}}, {{.VariableName}}OrmService)
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) SaveToDB(entity *{{.ModuleAlias}}.{{.EntityName}}) error {
|
||||
return u.db.Save(entity).Error
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) FindByIDFromDB(id {{.IDType}}) (*{{.ModuleAlias}}.{{.EntityName}}, error) {
|
||||
entity := &{{.ModuleAlias}}.{{.EntityName}}{}
|
||||
err := u.db.First(entity, id).Error
|
||||
return entity, err
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) FindByIDsFromDB(ids []{{.IDType}}) ([]*{{.ModuleAlias}}.{{.EntityName}}, error) {
|
||||
var entities []*{{.ModuleAlias}}.{{.EntityName}}
|
||||
err := u.db.Where("id IN ?", ids).Find(&entities).Error
|
||||
return entities, err
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) GetPrimaryKey(entity *{{.ModuleAlias}}.{{.EntityName}}) {{.IDType}} {
|
||||
return entity.ID
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) GetBaseField() string {
|
||||
return {{.ModuleAlias}}.{{.EntityKey}}
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) GetFieldQueryFunc(fieldName string) (func(value string) ([]*{{.ModuleAlias}}.{{.EntityName}}, error), bool) {
|
||||
switch fieldName {
|
||||
// 根据您的实际情况添加字段查询函数
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) GetFieldInQueryFunc(fieldName string) (func(values []string) ([]*{{.ModuleAlias}}.{{.EntityName}}, error), bool) {
|
||||
switch fieldName {
|
||||
// 根据您的实际情况添加字段批量查询函数
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func (u *{{.StructName}}Repository) GetFieldValueFunc(fieldName string) (func(entity *{{.ModuleAlias}}.{{.EntityName}}) string, bool) {
|
||||
switch fieldName {
|
||||
// 根据您的实际情况添加获取字段值的函数
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue