diff --git a/app/http/module/user/api.go b/app/http/module/user/api.go new file mode 100644 index 0000000..9a77f29 --- /dev/null +++ b/app/http/module/user/api.go @@ -0,0 +1,31 @@ +package user + +import ( + "github.com/Superdanda/hade/app/provider/user" + "github.com/Superdanda/hade/framework/gin" +) + +type UserApi struct{} + +// 注册路由 +func RegisterRoutes(r *gin.Engine) error { + + api := UserApi{} + + if !r.IsBind(user.UserKey) { + r.Bind(&user.UserProvider{}) + } + + Group := r.Group("/") + { + + userGroup := Group.Group("/user") + { + + userGroup.POST("/login", api.UserLogin) + + } + } + + return nil +} diff --git a/app/http/module/user/api_user_login.go b/app/http/module/user/api_user_login.go new file mode 100644 index 0000000..4dea43c --- /dev/null +++ b/app/http/module/user/api_user_login.go @@ -0,0 +1,10 @@ +package user + +import ( + "github.com/Superdanda/hade/framework/gin" +) + +// UserLogin handler +func (api *UserApi) UserLogin(c *gin.Context) { + // TODO: Implement UserLogin +} diff --git a/app/http/module/user/dto.go b/app/http/module/user/dto.go new file mode 100644 index 0000000..dbe698c --- /dev/null +++ b/app/http/module/user/dto.go @@ -0,0 +1,3 @@ +package user + +type UserDTO struct{} diff --git a/app/http/module/user/mapper.go b/app/http/module/user/mapper.go new file mode 100644 index 0000000..516c5c1 --- /dev/null +++ b/app/http/module/user/mapper.go @@ -0,0 +1,10 @@ +package user + +import "github.com/Superdanda/hade/app/provider/user" + +func ConvertUserToDTO(user *user.User) *UserDTO { + if user == nil { + return nil + } + return &UserDTO{} +} diff --git a/app/http/route.go b/app/http/route.go index be80f18..fe4eb8f 100644 --- a/app/http/route.go +++ b/app/http/route.go @@ -2,6 +2,7 @@ package http import ( "github.com/Superdanda/hade/app/http/module/demo" + "github.com/Superdanda/hade/app/http/module/user" "github.com/Superdanda/hade/framework/contract" "github.com/Superdanda/hade/framework/gin" ginSwagger "github.com/Superdanda/hade/framework/middleware/gin-swagger" @@ -23,6 +24,8 @@ func Routes(core *gin.Engine) { err := demo.Register(core) + err = user.RegisterRoutes(core) + if err != nil { return } diff --git a/app/infrastructure/user.go b/app/infrastructure/user.go new file mode 100644 index 0000000..1a45eef --- /dev/null +++ b/app/infrastructure/user.go @@ -0,0 +1,143 @@ +package infrastructure + +import ( + "context" + "github.com/Superdanda/hade/app/provider/database_connect" + userModule "github.com/Superdanda/hade/app/provider/user" + "github.com/Superdanda/hade/framework" + "github.com/Superdanda/hade/framework/contract" + "gorm.io/gorm" +) + +type UserRepository struct { + container framework.Container + db *gorm.DB + contract.OrmRepository[userModule.User, int64] + userModule.Repository +} + +func NewUserRepository(container framework.Container) contract.OrmRepository[userModule.User, int64] { + connectService := container.MustMake(database_connect.DatabaseConnectKey).(database_connect.Service) + connect := connectService.DefaultDatabaseConnect() + userOrmService := &UserRepository{container: container, db: connect} + infrastructureService := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService) + infrastructureService.RegisterOrmRepository(userModule.UserKey, userOrmService) + + //repository.RegisterRepository[userModule.User, int64](userModule.UserKey,) + return userOrmService +} + +func (u *UserRepository) SaveToDB(entity *userModule.User) error { + u.db.Save(entity) + return nil +} + +func (u *UserRepository) FindByIDFromDB(id int64) (*userModule.User, error) { + user := &userModule.User{} + u.db.Find(user, id) + return user, nil +} + +func (u *UserRepository) FindByID64sFromDB(ids []int64) ([]*userModule.User, error) { + var users []*userModule.User + // 使用 GORM 的 Where 方法查询用户 ID 在给定 ID 列表中的记录 + if err := u.db.Where("id IN ?", ids).Find(&users).Error; err != nil { + return nil, err // 如果查询出错,返回错误 + } + return users, nil // 返回查询结果和 nil 错误 +} + +func (u *UserRepository) GetPrimaryKey(entity *userModule.User) int64 { + return entity.ID +} + +func (u *UserRepository) GetBaseField() string { + return userModule.UserKey +} + +func (u *UserRepository) GetFieldQueryFunc(fieldName string) (func(value string) ([]*userModule.User, error), bool) { + switch fieldName { + case "Email": + return func(value string) ([]*userModule.User, error) { + var users []*userModule.User + // 执行查询,匹配 Email 字段 + if err := u.db.Where("email = ?", value).Find(&users).Error; err != nil { + return nil, err + } + return users, nil + }, true + case "UserName": + return func(value string) ([]*userModule.User, error) { + var users []*userModule.User + // 执行查询,匹配 UserName 字段 + if err := u.db.Where("user_name = ?", value).Find(&users).Error; err != nil { + return nil, err + } + return users, nil + }, true + default: + // 如果传入的字段名不支持,返回 nil 和 false + return nil, false + } +} + +func (u *UserRepository) GetFieldInQueryFunc(fieldName string) (func(values []string) ([]*userModule.User, error), bool) { + switch fieldName { + case "Email": + return func(values []string) ([]*userModule.User, error) { + var users []*userModule.User + // 批量查询 Email 字段匹配的用户 + if err := u.db.Where("email IN ?", values).Find(&users).Error; err != nil { + return nil, err + } + return users, nil + }, true + + case "UserName": + return func(values []string) ([]*userModule.User, error) { + var users []*userModule.User + // 批量查询 UserName 字段匹配的用户 + if err := u.db.Where("user_name IN ?", values).Find(&users).Error; err != nil { + return nil, err + } + return users, nil + }, true + + default: + return nil, false // 不支持的字段返回 false + } +} + +func (u *UserRepository) GetFieldValueFunc(fieldName string) (func(entity *userModule.User) string, bool) { + switch fieldName { + case "Email": + return func(entity *userModule.User) string { + return entity.Email + }, true + + case "UserName": + return func(entity *userModule.User) string { + return entity.UserName + }, true + + default: + return nil, false // 不支持的字段返回 false + } +} + +func (u *UserRepository) Save(ctx context.Context, user *userModule.User) error { + repository := u.container.MustMake(contract.RepositoryKey).(contract.Repository[userModule.User, int64]) + if err := repository.Save(ctx, userModule.UserKey, user); err != nil { + return err + } + return nil +} + +func (u *UserRepository) FindById(ctx context.Context, id int64) (*userModule.User, error) { + repository := u.container.MustMake(contract.RepositoryKey).(contract.Repository[userModule.User, int64]) + byID, err := repository.FindByID(ctx, userModule.UserKey, id) + if err != nil { + return nil, err + } + return byID, nil +} diff --git a/app/provider/database_connect/contract.go b/app/provider/database_connect/contract.go index 457a516..d1556a8 100644 --- a/app/provider/database_connect/contract.go +++ b/app/provider/database_connect/contract.go @@ -5,6 +5,7 @@ import "gorm.io/gorm" const DatabaseConnectKey = "hade:database_connect" type Service interface { + DefaultDatabaseConnect() *gorm.DB LocalDatabaseConnect() *gorm.DB AliDataBaseConnect() *gorm.DB } diff --git a/app/provider/database_connect/service.go b/app/provider/database_connect/service.go index 133d212..aaaba9b 100644 --- a/app/provider/database_connect/service.go +++ b/app/provider/database_connect/service.go @@ -20,6 +20,10 @@ func (d DatabaseConnectService) AliDataBaseConnect() *gorm.DB { return getDatabaseConnectByYaml("database.ali", d) } +func (d DatabaseConnectService) DefaultDatabaseConnect() *gorm.DB { + return getDatabaseConnectByYaml("database.ali", d) +} + func NewDatabaseConnectService(params ...interface{}) (interface{}, error) { container := params[0].(framework.Container) return &DatabaseConnectService{container: container}, nil diff --git a/app/provider/user/contract.go b/app/provider/user/contract.go new file mode 100644 index 0000000..c3f5285 --- /dev/null +++ b/app/provider/user/contract.go @@ -0,0 +1,26 @@ +package user + +import ( + "context" + "time" +) + +const UserKey = "user" + +type Service interface { + // GetUser 获取用户信息 + GetUser(ctx context.Context, userID int64) (*User, error) + + // SaveUser 保存用户信息 + SaveUser(ctx context.Context, user *User) error +} + +type User struct { + ID int64 `gorm:"column:id;primary_key;auto_increment" json:"id"` // 代表用户id, 只有注册成功之后才有这个id,唯一表示一个用户 + UserName string `gorm:"column:username;type:varchar(255);comment:用户名;not null" json:"username"` + NickName string `gorm:"column:username;type:varchar(255);comment:昵称;not null" json:"nickname"` + Avatar string `gorm:"column:username;type:varchar(255);comment:头像" json:"avatar"` + Password string `gorm:"column:password;type:varchar(255);comment:密码;not null" json:"password"` + Email string `gorm:"column:email;type:varchar(255);comment:邮箱;not null" json:"email"` + CreatedAt time.Time `gorm:"column:created_at;type:datetime;comment:创建时间;not null;<-:create" json:"createdAt"` +} diff --git a/app/provider/user/provider.go b/app/provider/user/provider.go new file mode 100644 index 0000000..5aca166 --- /dev/null +++ b/app/provider/user/provider.go @@ -0,0 +1,31 @@ +package user + +import ( + "github.com/Superdanda/hade/framework" +) + +type UserProvider struct { + framework.ServiceProvider + + c framework.Container +} + +func (sp *UserProvider) Name() string { + return UserKey +} + +func (sp *UserProvider) Register(c framework.Container) framework.NewInstance { + return NewUserService +} + +func (sp *UserProvider) IsDefer() bool { + return false +} + +func (sp *UserProvider) Params(c framework.Container) []interface{} { + return []interface{}{c} +} + +func (sp *UserProvider) Boot(c framework.Container) error { + return nil +} diff --git a/app/provider/user/repository.go b/app/provider/user/repository.go new file mode 100644 index 0000000..9b9acb9 --- /dev/null +++ b/app/provider/user/repository.go @@ -0,0 +1,8 @@ +package user + +import "context" + +type Repository interface { + Save(ctx context.Context, user *User) error + FindById(ctx context.Context, id int64) (*User, error) +} diff --git a/app/provider/user/service.go b/app/provider/user/service.go new file mode 100644 index 0000000..aac3dc1 --- /dev/null +++ b/app/provider/user/service.go @@ -0,0 +1,40 @@ +package user + +import ( + "context" + "github.com/Superdanda/hade/framework" + "github.com/Superdanda/hade/framework/contract" + "github.com/Superdanda/hade/framework/provider/infrastructure" +) + +type UserService struct { + container framework.Container + repository Repository +} + +func (s *UserService) GetUser(ctx context.Context, userID int64) (*User, error) { + user, err := s.repository.FindById(ctx, userID) + if err != nil { + return nil, err + } + return user, nil +} + +func (s *UserService) SaveUser(ctx context.Context, user *User) error { + err := s.repository.Save(ctx, user) + if err != nil { + return err + } + return nil +} + +func NewUserService(params ...interface{}) (interface{}, error) { + container := params[0].(framework.Container) + infrastructureService := container.MustMake(contract.InfrastructureKey).(infrastructure.Service) + ormRepository := infrastructureService.GetModuleOrmRepository(UserKey).(Repository) + return &UserService{container: container, repository: ormRepository}, nil +} + +func (s *UserService) Foo() string { + return "" +} diff --git a/config/development/cache.yaml b/config/development/cache.yaml index 6f6f71a..a312892 100644 --- a/config/development/cache.yaml +++ b/config/development/cache.yaml @@ -1 +1,4 @@ driver: memory + +repository: + expire: 6h \ No newline at end of file diff --git a/framework/contract/infrastructure.go b/framework/contract/infrastructure.go new file mode 100644 index 0000000..cd46802 --- /dev/null +++ b/framework/contract/infrastructure.go @@ -0,0 +1,9 @@ +package contract + +const InfrastructureKey = "hade:infrastructure" + +type InfrastructureService interface { + // GetModuleOrmRepository 通过模块名称来获取 对应的基础设施 -仓储层实现类 + GetModuleOrmRepository(moduleName string) interface{} + RegisterOrmRepository(moduleName string, repository interface{}) +} diff --git a/framework/contract/repository.go b/framework/contract/repository.go index 88ff470..d565657 100644 --- a/framework/contract/repository.go +++ b/framework/contract/repository.go @@ -1,13 +1,33 @@ package contract -import "context" +import ( + "context" + "github.com/Superdanda/hade/framework" +) const RepositoryKey = "hade:repository" -type Repository[T any, ID comparable] interface { +type RepositoryService interface { + GetGenericRepositoryByKey(key string) interface{} + GetGenericRepositoryMap() map[string]interface{} + GetContainer() framework.Container +} + +type GenericRepository[T any, ID comparable] interface { Save(ctx context.Context, entity *T) error FindByID(ctx context.Context, id ID) (*T, error) - FindByField(ctx context.Context, fieldName string, value any) ([]*T, error) + FindByField(ctx context.Context, fieldName string, value string) ([]*T, error) FindByIDs(ctx context.Context, ids []ID) ([]*T, error) - FindByFieldIn(ctx context.Context, fieldName string, values []any) ([]*T, error) + FindByFieldIn(ctx context.Context, fieldName string, values []string) ([]*T, error) +} + +type OrmRepository[T any, ID comparable] interface { + SaveToDB(entity *T) error + FindByIDFromDB(id ID) (*T, error) + FindByIDsFromDB(ids []ID) ([]*T, error) + GetPrimaryKey(entity *T) ID + GetBaseField() string + GetFieldQueryFunc(fieldName string) (func(value string) ([]*T, error), bool) + GetFieldInQueryFunc(fieldName string) (func(values []string) ([]*T, error), bool) + GetFieldValueFunc(fieldName string) (func(entity *T) string, bool) } diff --git a/framework/provider/id/provider.go b/framework/provider/id/provider.go index ef14491..2ca34f9 100644 --- a/framework/provider/id/provider.go +++ b/framework/provider/id/provider.go @@ -1,12 +1,15 @@ package id -import "github.com/Superdanda/hade/framework" +import ( + "github.com/Superdanda/hade/framework" + "github.com/Superdanda/hade/framework/contract" +) type HadeIDProvider struct { } func (h HadeIDProvider) Register(container framework.Container) framework.NewInstance { - return NewHadeIDService + return nil } func (h HadeIDProvider) Boot(container framework.Container) error { diff --git a/framework/provider/infrastructure/provider.go b/framework/provider/infrastructure/provider.go new file mode 100644 index 0000000..78ca861 --- /dev/null +++ b/framework/provider/infrastructure/provider.go @@ -0,0 +1,28 @@ +package infrastructure + +import "github.com/Superdanda/hade/framework" + +type InfrastructureProvider struct { +} + +const InfrastructureKey = "hade:infrastructure" + +func (i *InfrastructureProvider) Register(container framework.Container) framework.NewInstance { + return NewInfrastructureService +} + +func (i *InfrastructureProvider) Boot(container framework.Container) error { + return nil +} + +func (i *InfrastructureProvider) IsDefer() bool { + return false +} + +func (i *InfrastructureProvider) Params(container framework.Container) []interface{} { + return []interface{}{container} +} + +func (i *InfrastructureProvider) Name() string { + return InfrastructureKey +} diff --git a/framework/provider/infrastructure/service.go b/framework/provider/infrastructure/service.go new file mode 100644 index 0000000..51923de --- /dev/null +++ b/framework/provider/infrastructure/service.go @@ -0,0 +1,25 @@ +package infrastructure + +import ( + "github.com/Superdanda/hade/framework" + "github.com/Superdanda/hade/framework/contract" + _ "github.com/Superdanda/hade/framework/provider/repository" +) + +type Service struct { + container framework.Container + contract.InfrastructureService + ormRepositoryMap map[string]interface{} +} + +func NewInfrastructureService(params ...interface{}) (interface{}, error) { + return &Service{container: params[0].(framework.Container)}, nil +} + +func (i *Service) GetModuleOrmRepository(moduleName string) interface{} { + return i.ormRepositoryMap[moduleName] +} + +func (i *Service) RegisterOrmRepository(moduleName string, repository interface{}) { + i.ormRepositoryMap[moduleName] = repository +} diff --git a/framework/provider/repository/provider.go b/framework/provider/repository/provider.go index 28e1a7a..ff664ef 100644 --- a/framework/provider/repository/provider.go +++ b/framework/provider/repository/provider.go @@ -8,8 +8,7 @@ import ( type RepositoryProvider struct{} func (r RepositoryProvider) Register(container framework.Container) framework.NewInstance { - //TODO implement me - panic("implement me") + return NewHadeRepositoryService } func (r RepositoryProvider) Boot(container framework.Container) error { diff --git a/framework/provider/repository/service.go b/framework/provider/repository/service.go index fe755ae..f10cce5 100644 --- a/framework/provider/repository/service.go +++ b/framework/provider/repository/service.go @@ -2,17 +2,45 @@ package repository import ( "context" + "encoding/json" + "fmt" "github.com/Superdanda/hade/framework" "github.com/Superdanda/hade/framework/contract" "github.com/pkg/errors" + "time" ) -type HadeRepositoryService[T any, ID comparable] struct { - container framework.Container - cacheService contract.CacheService +func RegisterRepository[T any, ID comparable](service contract.RepositoryService, key string, ormRepository interface{}) { + container := service.GetContainer() + cacheService := container.MustMake(contract.CacheKey).(contract.CacheService) + configService := container.MustMake(contract.ConfigKey).(contract.Config) + expireTime := configService.GetString("cache.repository.expire") + if expireTime == "" { + fmt.Println("从配置文件获取缓存失败,使用默认缓存时间6个小时") + } + expireTime = "6h" + duration, err := time.ParseDuration(expireTime) + if err != nil { + fmt.Println("从配置文件获取缓存失败,使用默认缓存时间6个小时") + } else { + duration = 6 * time.Hour + } + genericRepository := NewHadeGenericRepository[T, ID]( + key, + container, + NewHadeCacheRepository[T, ID](cacheService, duration), + ormRepository.(contract.OrmRepository[T, ID]), + ) + service.GetGenericRepositoryMap()[key] = genericRepository } -func NewHadeRepositoryService[T any, ID comparable](params ...interface{}) (interface{}, error) { +type HadeRepositoryService struct { + container framework.Container + genericRepositoryMap map[string]interface{} + contract.RepositoryService +} + +func NewHadeRepositoryService(params ...interface{}) (interface{}, error) { if len(params) < 2 { return nil, errors.New("insufficient parameters") } @@ -20,41 +48,385 @@ func NewHadeRepositoryService[T any, ID comparable](params ...interface{}) (inte if !ok { return nil, errors.New("invalid container parameter") } - cacheService, ok := params[1].(contract.CacheService) - if !ok { - return nil, errors.New("invalid cacheService parameter") - } - return &HadeRepositoryService[T, ID]{ - container: container, - cacheService: cacheService, + return &HadeRepositoryService{ + container: container, + genericRepositoryMap: make(map[string]interface{}), }, nil } -func (h *HadeRepositoryService[T, ID]) Save(ctx context.Context, entity *T) error { - //TODO implement me - panic("implement me") +func (h *HadeRepositoryService) GetGenericRepositoryByKey(key string) interface{} { + return h.genericRepositoryMap[key] } -func (h HadeRepositoryService[T, ID]) FindByID(ctx context.Context, id ID) (*T, error) { - //TODO implement me - panic("implement me") +func (h *HadeRepositoryService) GetGenericRepositoryMap() map[string]interface{} { + return h.genericRepositoryMap +} +func (h *HadeRepositoryService) GetContainer() framework.Container { + return h.container } -func (h HadeRepositoryService[T, ID]) FindByField(ctx context.Context, fieldName string, value any) ([]*T, error) { - //TODO implement me - panic("implement me") +type HadeGenericRepository[T any, ID comparable] struct { + repositoryKey string + container framework.Container + cacheRepository *HadeCacheRepository[T, ID] + ormRepository contract.OrmRepository[T, ID] + contract.GenericRepository[T, ID] } -func (h HadeRepositoryService[T, ID]) FindByIDs(ctx context.Context, ids []ID) ([]*T, error) { - //TODO implement me - panic("implement me") -} - -func (h HadeRepositoryService[T, ID]) FindByFieldIn(ctx context.Context, fieldName string, values []any) ([]*T, error) { - //TODO implement me - panic("implement me") +func NewHadeGenericRepository[T any, ID comparable](repositoryKey string, container framework.Container, cacheRepository *HadeCacheRepository[T, ID], + ormRepository contract.OrmRepository[T, ID]) *HadeGenericRepository[T, ID] { + return &HadeGenericRepository[T, ID]{repositoryKey: repositoryKey, container: container, cacheRepository: cacheRepository, ormRepository: ormRepository} } type HadeCacheRepository[T any, ID comparable] struct { - cacheService contract.CacheService + cacheService contract.CacheService + cacheExpiration time.Duration +} + +func NewHadeCacheRepository[T any, ID comparable](cacheService contract.CacheService, cacheExpiration time.Duration) *HadeCacheRepository[T, ID] { + return &HadeCacheRepository[T, ID]{ + cacheService: cacheService, + cacheExpiration: cacheExpiration, + } +} + +func (g *HadeGenericRepository[T, ID]) Save(ctx context.Context, entity *T) error { + ormRepository := g.ormRepository + err := ormRepository.SaveToDB(entity) + if err != nil { + return err + } + err = g.updateCache(ctx, entity) + if err != nil { + return err + } + return nil +} + +func (g *HadeGenericRepository[T, ID]) updateCache(ctx context.Context, entity *T) error { + ormRepository := g.ormRepository + id := ormRepository.GetPrimaryKey(entity) + return g.cacheRepository.Cache(ctx, ormRepository.GetBaseField(), id, entity) +} + +func (g *HadeGenericRepository[T, ID]) FindByID(ctx context.Context, id ID) (*T, error) { + ormRepository := g.ormRepository + cache, _ := g.cacheRepository.FindFromCache(ctx, ormRepository.GetBaseField(), id) + if cache != nil { + return cache, nil + } + entityFromDB, err := ormRepository.FindByIDFromDB(id) + if err != nil { + return nil, err + } + go g.updateCache(ctx, entityFromDB) + return entityFromDB, nil +} + +func (g *HadeGenericRepository[T, ID]) FindByField(ctx context.Context, fieldName string, value string) ([]*T, error) { + ormRepository := g.ormRepository + // 获取字段查询函数 + queryFunc, ok := ormRepository.GetFieldQueryFunc(fieldName) + if !ok { + return nil, fmt.Errorf("no query function found for field: %s", fieldName) + } + + // 从缓存中获取 IDs + ids, err := g.cacheRepository.FindIDsFromCache(ctx, ormRepository.GetBaseField(), fieldName, value) + if err != nil || ids == nil || len(ids) == 0 { + // 缓存未命中,从数据库查询 + entitiesFromDB, err := queryFunc(value) + if err != nil { + return nil, err + } + + if len(entitiesFromDB) == 0 { + return nil, nil // 数据库中没有数据 + } + + // 更新缓存 + var idsToCache []ID + for _, entity := range entitiesFromDB { + id := ormRepository.GetPrimaryKey(entity) + idsToCache = append(idsToCache, id) + // 异步更新实体缓存 + go g.cacheRepository.Cache(ctx, ormRepository.GetBaseField(), id, entity) + } + + // 缓存字段到 IDs 的映射 + fieldValuesToIDs := map[string][]ID{value: idsToCache} + go g.cacheRepository.CacheFieldToIDs(ctx, ormRepository.GetBaseField(), fieldName, fieldValuesToIDs) + + return entitiesFromDB, nil + } + + // 从缓存中获取实体 + entities, err := g.cacheRepository.FindBatchByIds(ctx, ormRepository.GetBaseField(), ids) + if err != nil { + return nil, err + } + + // 检查缓存中是否有缺失的实体 + var missingIDs []ID + for i, entity := range entities { + if entity == nil { + missingIDs = append(missingIDs, ids[i]) + } + } + + if len(missingIDs) > 0 { + // 从数据库获取缺失的实体 + missingEntities, err := ormRepository.FindByIDsFromDB(missingIDs) + if err != nil { + return nil, err + } + + // 更新缓存并合并结果 + for _, entity := range missingEntities { + entities = append(entities, entity) + id := ormRepository.GetPrimaryKey(entity) + go g.cacheRepository.Cache(ctx, ormRepository.GetBaseField(), id, entity) + } + } + + return entities, nil +} + +func (g *HadeGenericRepository[T, ID]) FindByIDs(ctx context.Context, ids []ID) ([]*T, error) { + ormRepository := g.ormRepository + // 从缓存中获取实体 + entities, err := g.cacheRepository.FindBatchByIds(ctx, ormRepository.GetBaseField(), ids) + if err != nil { + return nil, err + } + + // 记录缓存中缺失的 IDs + var missingIDs []ID + for i, entity := range entities { + if entity == nil { + missingIDs = append(missingIDs, ids[i]) + } + } + + // 从数据库获取缺失的实体 + if len(missingIDs) > 0 { + missingEntities, err := ormRepository.FindByIDsFromDB(missingIDs) + if err != nil { + return nil, err + } + + // 更新缓存并合并结果 + for _, entity := range missingEntities { + entities = append(entities, entity) + id := ormRepository.GetPrimaryKey(entity) + go g.cacheRepository.Cache(ctx, ormRepository.GetBaseField(), id, entity) + } + } + + // 过滤掉 nil 值的实体 + var result []*T + for _, entity := range entities { + if entity != nil { + result = append(result, entity) + } + } + + return result, nil +} + +func (g *HadeGenericRepository[T, ID]) FindByFieldIn(ctx context.Context, fieldName string, values []string) ([]*T, error) { + ormRepository := g.ormRepository + // 获取字段查询函数 + queryFunc, ok := ormRepository.GetFieldInQueryFunc(fieldName) + if !ok { + return nil, fmt.Errorf("no query function found for field: %s", fieldName) + } + + var allEntities []*T + var allIDs []ID + var missingFieldValues []string + + // 对每个字段值,尝试从缓存中获取 IDs + fieldValuesToIDs := make(map[string][]ID) + for _, value := range values { + ids, err := g.cacheRepository.FindIDsFromCache(ctx, ormRepository.GetBaseField(), fieldName, value) + if err != nil || ids == nil || len(ids) == 0 { + // 缓存未命中,记录缺失的字段值 + missingFieldValues = append(missingFieldValues, value) + } else { + fieldValuesToIDs[value] = ids + allIDs = append(allIDs, ids...) + } + } + + // 从缓存中获取实体 + entities, err := g.cacheRepository.FindBatchByIds(ctx, ormRepository.GetBaseField(), allIDs) + if err != nil { + return nil, err + } + + // 记录缓存中缺失的 IDs + var missingIDs []ID + idEntityMap := make(map[ID]*T) + for i, entity := range entities { + if entity != nil { + allEntities = append(allEntities, entity) + id := ormRepository.GetPrimaryKey(entity) + idEntityMap[id] = entity + } else { + missingIDs = append(missingIDs, allIDs[i]) + } + } + + // 从数据库获取缺失的实体 + if len(missingIDs) > 0 { + missingEntities, err := ormRepository.FindByIDsFromDB(missingIDs) + if err != nil { + return nil, err + } + + // 更新缓存并合并结果 + for _, entity := range missingEntities { + allEntities = append(allEntities, entity) + id := ormRepository.GetPrimaryKey(entity) + idEntityMap[id] = entity + go g.cacheRepository.Cache(ctx, ormRepository.GetBaseField(), id, entity) + } + } + + // 处理缓存中缺失的字段值 + if len(missingFieldValues) > 0 { + missingEntities, err := queryFunc(missingFieldValues) + if err != nil { + return nil, err + } + + // 更新缓存并合并结果 + fieldValuesToIDsToCache := make(map[string][]ID) + for _, entity := range missingEntities { + allEntities = append(allEntities, entity) + id := ormRepository.GetPrimaryKey(entity) + fieldGetter, ok := ormRepository.GetFieldValueFunc(fieldName) + if !ok { + continue + } + fieldValue := fieldGetter(entity) + fieldValuesToIDsToCache[fieldValue] = append(fieldValuesToIDsToCache[fieldValue], id) + go g.cacheRepository.Cache(ctx, ormRepository.GetBaseField(), id, entity) + } + + // 更新字段到 IDs 的缓存 + go g.cacheRepository.CacheFieldToIDs(ctx, ormRepository.GetBaseField(), fieldName, fieldValuesToIDsToCache) + } + + return allEntities, nil +} + +// getKey 生成缓存键 +func (r *HadeCacheRepository[T, ID]) getKey(prefix string, value any) string { + return fmt.Sprintf("%s::%v", prefix, value) +} + +func (r *HadeCacheRepository[T, ID]) getKeyWithField(prefix, fieldPrefix string, value any) string { + return fmt.Sprintf("%s::%s::%v", prefix, fieldPrefix, value) +} + +// Cache 将实体缓存到 Redis +func (r *HadeCacheRepository[T, ID]) Cache(ctx context.Context, prefix string, id ID, entity *T) error { + key := r.getKey(prefix, id) + return r.cacheService.SetObj(ctx, key, entity, r.cacheExpiration) +} + +// CacheEvict 从缓存中删除某个实体 +func (r *HadeCacheRepository[T, ID]) CacheEvict(ctx context.Context, prefix string, id ID) error { + key := r.getKey(prefix, id) + return r.cacheService.Del(ctx, key) +} + +// FindFromCache 根据 ID 从缓存中获取实体 +func (r *HadeCacheRepository[T, ID]) FindFromCache(ctx context.Context, prefix string, id ID) (*T, error) { + key := r.getKey(prefix, id) + var entity T + err := r.cacheService.GetObj(ctx, key, &entity) + if err != nil { + return nil, err + } + return &entity, nil +} + +// FindBatchByIds 从缓存中批量获取实体 +func (r *HadeCacheRepository[T, ID]) FindBatchByIds(ctx context.Context, prefix string, ids []ID) ([]*T, error) { + keys := make([]string, len(ids)) + for i, id := range ids { + keys[i] = r.getKey(prefix, id) + } + keyValueMap, err := r.cacheService.GetMany(ctx, keys) + if err != nil { + return nil, err + } + results := make([]*T, len(ids)) + for i, key := range keys { + data, ok := keyValueMap[key] + if ok && data != "" { + var entity T + if err := json.Unmarshal([]byte(data), &entity); err != nil { + return nil, err + } + results[i] = &entity + } else { + results[i] = nil + } + } + return results, nil +} + +// CacheFieldToID 缓存字段到 ID 的映射 +func (r *HadeCacheRepository[T, ID]) CacheFieldToID(ctx context.Context, prefix, fieldPrefix, fieldValue string, id ID) error { + key := r.getKeyWithField(prefix, fieldPrefix, fieldValue) + return r.cacheService.SetObj(ctx, key, id, r.cacheExpiration) +} + +// CacheEvictFieldToID 从缓存中删除字段到 ID 的映射 +func (r *HadeCacheRepository[T, ID]) CacheEvictFieldToID(ctx context.Context, prefix, fieldPrefix, fieldValue string) error { + key := r.getKeyWithField(prefix, fieldPrefix, fieldValue) + return r.cacheService.Del(ctx, key) +} + +// FindIDFromCache 根据字段获取对应的 ID +func (r *HadeCacheRepository[T, ID]) FindIDFromCache(ctx context.Context, prefix, fieldPrefix, fieldValue string) (ID, error) { + key := r.getKeyWithField(prefix, fieldPrefix, fieldValue) + var id ID + err := r.cacheService.GetObj(ctx, key, &id) + return id, err +} + +// CacheFieldToIDs 缓存字段到多个 ID 的映射 +func (r *HadeCacheRepository[T, ID]) CacheFieldToIDs(ctx context.Context, prefix, fieldPrefix string, fieldValuesToIDs map[string][]ID) error { + data := make(map[string]string) + for fieldValue, ids := range fieldValuesToIDs { + key := r.getKeyWithField(prefix, fieldPrefix, fieldValue) + idsBytes, err := json.Marshal(ids) + if err != nil { + return err + } + data[key] = string(idsBytes) + } + return r.cacheService.SetMany(ctx, data, r.cacheExpiration) +} + +// CacheEvictFieldsToIDsBatch 从缓存中批量删除字段到 ID 的映射 +func (r *HadeCacheRepository[T, ID]) CacheEvictFieldsToIDsBatch(ctx context.Context, prefix, fieldPrefix string, fieldValues []string) error { + var keys []string + for _, fieldValue := range fieldValues { + keys = append(keys, r.getKeyWithField(prefix, fieldPrefix, fieldValue)) + } + return r.cacheService.DelMany(ctx, keys) +} + +// FindIDsFromCache 根据字段获取对应的 ID 列表 +func (r *HadeCacheRepository[T, ID]) FindIDsFromCache(ctx context.Context, prefix, fieldPrefix, fieldValue string) ([]ID, error) { + key := r.getKeyWithField(prefix, fieldPrefix, fieldValue) + var ids []ID + err := r.cacheService.GetObj(ctx, key, &ids) + return ids, err } diff --git a/main.go b/main.go index 98c4ba8..7079477 100644 --- a/main.go +++ b/main.go @@ -9,10 +9,12 @@ import ( "github.com/Superdanda/hade/framework/provider/config" "github.com/Superdanda/hade/framework/provider/distributed" "github.com/Superdanda/hade/framework/provider/env" + "github.com/Superdanda/hade/framework/provider/infrastructure" "github.com/Superdanda/hade/framework/provider/kernel" "github.com/Superdanda/hade/framework/provider/log" "github.com/Superdanda/hade/framework/provider/orm" "github.com/Superdanda/hade/framework/provider/redis" + "github.com/Superdanda/hade/framework/provider/repository" "github.com/Superdanda/hade/framework/provider/ssh" "github.com/Superdanda/hade/framework/provider/type_register" ) @@ -34,6 +36,8 @@ func main() { container.Bind(&cache.HadeCacheProvider{}) container.Bind(&ssh.SSHProvider{}) container.Bind(&type_register.TypeRegisterProvider{}) + container.Bind(&infrastructure.InfrastructureProvider{}) + container.Bind(&repository.RepositoryProvider{}) // 将HTTP引擎初始化,并且作为服务提供者绑定到服务容器中 if engine, err := http.NewHttpEngine(container); err == nil {