framework1/framework/provider/repository/service.go

431 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repository
import (
"context"
"encoding/json"
"fmt"
"github.com/Superdanda/hade/framework"
"github.com/Superdanda/hade/framework/contract"
"github.com/pkg/errors"
"time"
)
// RegisterRepository builds a HadeGenericRepository for entity type T with
// primary-key type ID. It backs the repository with the container's cache
// service and the given ORM repository, stores it in service's repository map
// under key, and returns it.
//
// The cache lifetime comes from the "cache.repository.expire" config value,
// parsed with time.ParseDuration (for example "30m" or "6h"). When that value
// is empty or cannot be parsed, a warning is printed and a 6-hour default is
// used.
//
// It panics if ormRepository does not implement contract.OrmRepository[T, ID].
// The unchecked assertions on the container's cache and config services also
// panic if those services have unexpected types.
func RegisterRepository[T any, ID comparable](service contract.RepositoryService, key string, ormRepository interface{}) *HadeGenericRepository[T, ID] {
	const defaultExpiration = 6 * time.Hour

	container := service.GetContainer()
	cacheService := container.MustMake(contract.CacheKey).(contract.CacheService)
	configService := container.MustMake(contract.ConfigKey).(contract.Config)

	// Honour the configured expiration; fall back to the default only when it
	// is missing or malformed.
	duration := defaultExpiration
	expireTime := configService.GetString("cache.repository.expire")
	if expireTime == "" {
		fmt.Println("从配置文件获取缓存失败使用默认缓存时间6个小时")
	} else if parsed, err := time.ParseDuration(expireTime); err != nil {
		fmt.Println("从配置文件获取缓存失败使用默认缓存时间6个小时")
	} else {
		duration = parsed
	}

	orm, ok := ormRepository.(contract.OrmRepository[T, ID])
	if !ok {
		// A repository without a usable backing store is a wiring error; fail
		// loudly with enough context to locate the bad registration.
		panic(fmt.Sprintf("repository %q: %T does not implement contract.OrmRepository", key, ormRepository))
	}

	genericRepository := NewHadeGenericRepository[T, ID](
		key,
		container,
		NewHadeCacheRepository[T, ID](cacheService, duration),
		orm,
	)
	service.GetGenericRepositoryMap()[key] = genericRepository
	return genericRepository
}
// HadeRepositoryService is a registry of generic repositories keyed by string.
// RegisterRepository writes each repository it builds into genericRepositoryMap
// through GetGenericRepositoryMap.
type HadeRepositoryService struct {
	// container is the service container this service was constructed with;
	// it is exposed through GetContainer.
	container framework.Container
	// genericRepositoryMap maps a registration key to a repository value. The
	// value type is interface{} because each entry has its own T and ID.
	// NOTE(review): the map is not guarded by a lock; confirm that registration
	// and lookup never race (e.g. registration only happens at startup).
	genericRepositoryMap map[string]interface{}
	// The embedded interface is left nil by NewHadeRepositoryService; it only
	// lets the struct satisfy contract.RepositoryService. Calling an interface
	// method that this struct does not define would panic on the nil value.
	contract.RepositoryService
}
// NewHadeRepositoryService constructs a HadeRepositoryService with an empty
// repository map.
//
// params[0] must be a framework.Container. An error is returned, rather than a
// panic, when params is empty or params[0] has the wrong type.
func NewHadeRepositoryService(params ...interface{}) (interface{}, error) {
	if len(params) == 0 {
		return nil, errors.New("invalid container parameter")
	}
	container, ok := params[0].(framework.Container)
	if !ok {
		return nil, errors.New("invalid container parameter")
	}
	return &HadeRepositoryService{
		container:            container,
		genericRepositoryMap: make(map[string]interface{}),
	}, nil
}
// GetGenericRepositoryByKey returns the repository registered under key, or
// nil when nothing has been registered under that key.
func (h *HadeRepositoryService) GetGenericRepositoryByKey(key string) interface{} {
	if repo, found := h.genericRepositoryMap[key]; found {
		return repo
	}
	return nil
}
// GetGenericRepositoryMap returns the registry map itself, not a copy.
// RegisterRepository adds entries by writing through it, so any mutation a
// caller makes is visible to the service.
func (h *HadeRepositoryService) GetGenericRepositoryMap() map[string]interface{} {
	registry := h.genericRepositoryMap
	return registry
}
// GetContainer returns the service container this repository service was
// constructed with.
func (h *HadeRepositoryService) GetContainer() framework.Container {
	c := h.container
	return c
}
// HadeGenericRepository combines an ORM-backed store with a cache layer for
// entity type T keyed by ID. Reads consult the cache first and fall back to
// ormRepository. Save writes to the database before refreshing the cache.
type HadeGenericRepository[T any, ID comparable] struct {
	// repositoryKey is the key the repository was created with
	// (RegisterRepository stores it in the service map under the same key).
	repositoryKey string
	// container is the service container the repository was created with.
	container framework.Container
	// cacheRepository stores entities and field-value -> IDs mappings.
	cacheRepository *HadeCacheRepository[T, ID]
	// ormRepository is the database-backed store. It also supplies the cache
	// key prefix (GetBaseField) and entity primary keys (GetPrimaryKey).
	ormRepository contract.OrmRepository[T, ID]
	// The embedded interface is never set by NewHadeGenericRepository; it only
	// lets the struct satisfy contract.GenericRepository[T, ID]. Calling an
	// interface method this struct does not define would panic on the nil value.
	contract.GenericRepository[T, ID]
}
// NewHadeGenericRepository assembles a HadeGenericRepository from its parts:
// the registration key, the service container, the cache layer and the
// ORM-backed store. The embedded contract interface is left nil.
func NewHadeGenericRepository[T any, ID comparable](repositoryKey string, container framework.Container, cacheRepository *HadeCacheRepository[T, ID],
	ormRepository contract.OrmRepository[T, ID]) *HadeGenericRepository[T, ID] {
	repo := new(HadeGenericRepository[T, ID])
	repo.repositoryKey = repositoryKey
	repo.container = container
	repo.cacheRepository = cacheRepository
	repo.ormRepository = ormRepository
	return repo
}
// HadeCacheRepository is the cache layer used by HadeGenericRepository.
// Entities live under "prefix::id" keys. Field-value -> ID(s) mappings live
// under "prefix::field::value" keys.
type HadeCacheRepository[T any, ID comparable] struct {
	// cacheService is the underlying cache backend every method delegates to.
	cacheService contract.CacheService
	// cacheExpiration is passed as the expiration on every SetObj/SetMany call
	// this repository makes.
	cacheExpiration time.Duration
}
// NewHadeCacheRepository returns a cache layer that reads and writes through
// cacheService and applies cacheExpiration to every entry it writes.
func NewHadeCacheRepository[T any, ID comparable](cacheService contract.CacheService, cacheExpiration time.Duration) *HadeCacheRepository[T, ID] {
	repo := new(HadeCacheRepository[T, ID])
	repo.cacheService = cacheService
	repo.cacheExpiration = cacheExpiration
	return repo
}
// Save writes entity to the database and, only if that succeeds, refreshes
// its cache entry. The first error encountered is returned unchanged.
//
// NOTE(review): only the entity's own cache entry is refreshed. Field -> IDs
// mappings cached by FindByField/FindByFieldIn are not touched, so they may go
// stale if an indexed field changes.
func (g *HadeGenericRepository[T, ID]) Save(ctx context.Context, entity *T) error {
	if err := g.ormRepository.SaveToDB(entity); err != nil {
		return err
	}
	return g.updateCache(ctx, entity)
}
// updateCache writes entity to the cache. The key combines the ORM
// repository's base-field prefix with the entity's primary key. The error
// from the cache write is returned.
func (g *HadeGenericRepository[T, ID]) updateCache(ctx context.Context, entity *T) error {
	orm := g.ormRepository
	primaryKey := orm.GetPrimaryKey(entity)
	prefix := orm.GetBaseField()
	return g.cacheRepository.Cache(ctx, prefix, primaryKey, entity)
}
// FindByID returns the entity whose primary key is id.
//
// The cache is consulted first. Any cache error is treated like a miss, and
// the entity is loaded from the database instead. A database hit is written
// back to the cache in a background goroutine whose error is discarded.
// Database errors are returned unchanged. When the database returns neither
// an entity nor an error, FindByID returns (nil, nil) and caches nothing.
//
// NOTE(review): the background cache write reuses ctx. If ctx is cancelled as
// soon as the caller returns (e.g. a request context), that write may be
// dropped.
func (g *HadeGenericRepository[T, ID]) FindByID(ctx context.Context, id ID) (*T, error) {
	ormRepository := g.ormRepository
	// Cache errors are deliberately not surfaced: the database is the source of truth.
	if cached, err := g.cacheRepository.FindFromCache(ctx, ormRepository.GetBaseField(), id); err == nil && cached != nil {
		return cached, nil
	}
	entityFromDB, err := ormRepository.FindByIDFromDB(id)
	if err != nil {
		return nil, err
	}
	if entityFromDB == nil {
		// Never hand a nil entity to the async cache write: GetPrimaryKey(nil)
		// may panic, and a panic in that goroutine would terminate the process.
		return nil, nil
	}
	go g.updateCache(ctx, entityFromDB)
	return entityFromDB, nil
}
// FindByField returns the entities whose fieldName equals value.
//
// It first looks up the cached value -> IDs mapping for fieldName.
//   - On a hit, the IDs are resolved through FindByIDs. That serves entities
//     from the cache, loads evicted ones from the database, and never returns
//     nil entries.
//   - On a miss (or any cache error), it runs the ORM's per-field query and
//     returns those rows. It then asynchronously caches each entity and the
//     value -> IDs mapping.
//
// It returns an error if no query function is registered for fieldName, and
// (nil, nil) when the database has no matching rows.
func (g *HadeGenericRepository[T, ID]) FindByField(ctx context.Context, fieldName string, value string) ([]*T, error) {
	ormRepository := g.ormRepository
	queryFunc, ok := ormRepository.GetFieldQueryFunc(fieldName)
	if !ok {
		return nil, fmt.Errorf("no query function found for field: %s", fieldName)
	}
	baseField := ormRepository.GetBaseField()

	// A cache error is treated as a miss: fall through to the database.
	ids, err := g.cacheRepository.FindIDsFromCache(ctx, baseField, fieldName, value)
	if err == nil && len(ids) > 0 {
		return g.FindByIDs(ctx, ids)
	}

	entitiesFromDB, err := queryFunc(value)
	if err != nil {
		return nil, err
	}
	if len(entitiesFromDB) == 0 {
		return nil, nil // no matching rows in the database
	}

	// Populate the cache in the background so the next lookup is a hit.
	idsToCache := make([]ID, 0, len(entitiesFromDB))
	for _, entity := range entitiesFromDB {
		id := ormRepository.GetPrimaryKey(entity)
		idsToCache = append(idsToCache, id)
		go g.cacheRepository.Cache(ctx, baseField, id, entity)
	}
	go g.cacheRepository.CacheFieldToIDs(ctx, baseField, fieldName, map[string][]ID{value: idsToCache})
	return entitiesFromDB, nil
}
// FindByIDs returns the entities whose primary keys are in ids.
//
// It reads the cache first and loads all cache misses from the database in a
// single batch query. Entities loaded from the database are written back to
// the cache in background goroutines, whose errors are discarded.
//
// The result follows the order of ids. IDs found in neither the cache nor the
// database are omitted, so the result may be shorter than ids. An empty ids
// returns (nil, nil) without contacting the cache.
func (g *HadeGenericRepository[T, ID]) FindByIDs(ctx context.Context, ids []ID) ([]*T, error) {
	if len(ids) == 0 {
		return nil, nil
	}
	ormRepository := g.ormRepository
	baseField := ormRepository.GetBaseField()

	// entities is aligned with ids; nil slots are cache misses.
	entities, err := g.cacheRepository.FindBatchByIds(ctx, baseField, ids)
	if err != nil {
		return nil, err
	}

	var missingIDs []ID
	for i, entity := range entities {
		if entity == nil {
			missingIDs = append(missingIDs, ids[i])
		}
	}

	if len(missingIDs) > 0 {
		missingEntities, err := ormRepository.FindByIDsFromDB(missingIDs)
		if err != nil {
			return nil, err
		}
		loaded := make(map[ID]*T, len(missingEntities))
		for _, entity := range missingEntities {
			if entity == nil {
				continue
			}
			id := ormRepository.GetPrimaryKey(entity)
			loaded[id] = entity
			go g.cacheRepository.Cache(ctx, baseField, id, entity)
		}
		// Fill the miss slots in place so the result keeps the order of ids.
		for i, entity := range entities {
			if entity == nil {
				entities[i] = loaded[ids[i]]
			}
		}
	}

	// Drop IDs that exist in neither the cache nor the database.
	var result []*T
	for _, entity := range entities {
		if entity != nil {
			result = append(result, entity)
		}
	}
	return result, nil
}
// FindByFieldIn returns the entities whose fieldName matches any of values.
//
// For each distinct value it looks up the cached value -> IDs mapping.
//   - Values with cached IDs are resolved together through FindByIDs, which
//     reads the cache and reloads evicted entities from the database.
//   - The remaining values go to the ORM's IN-query for fieldName. The loaded
//     entities, and (when a field getter is registered) their value -> IDs
//     mappings, are cached asynchronously.
//
// Entities resolved via cached mappings come first, followed by those from the
// IN-query. Duplicate entries in values are ignored. It returns an error if no
// IN-query function is registered for fieldName.
func (g *HadeGenericRepository[T, ID]) FindByFieldIn(ctx context.Context, fieldName string, values []string) ([]*T, error) {
	ormRepository := g.ormRepository
	queryFunc, ok := ormRepository.GetFieldInQueryFunc(fieldName)
	if !ok {
		return nil, fmt.Errorf("no query function found for field: %s", fieldName)
	}
	baseField := ormRepository.GetBaseField()

	// Split the distinct values into cache hits (collecting their IDs) and
	// misses. Deduplicating stops repeated values from duplicating results.
	var cachedIDs []ID
	var missingFieldValues []string
	seen := make(map[string]struct{}, len(values))
	for _, value := range values {
		if _, dup := seen[value]; dup {
			continue
		}
		seen[value] = struct{}{}
		ids, err := g.cacheRepository.FindIDsFromCache(ctx, baseField, fieldName, value)
		if err != nil || len(ids) == 0 {
			missingFieldValues = append(missingFieldValues, value)
			continue
		}
		cachedIDs = append(cachedIDs, ids...)
	}

	var allEntities []*T
	// Skip the batch lookup entirely on a cold cache: a zero-key multi-get is
	// rejected by some cache backends.
	if len(cachedIDs) > 0 {
		cachedEntities, err := g.FindByIDs(ctx, cachedIDs)
		if err != nil {
			return nil, err
		}
		allEntities = append(allEntities, cachedEntities...)
	}

	if len(missingFieldValues) > 0 {
		missingEntities, err := queryFunc(missingFieldValues)
		if err != nil {
			return nil, err
		}
		// Without a field getter the entities are still cached, but their
		// value -> IDs mappings cannot be rebuilt.
		fieldGetter, hasGetter := ormRepository.GetFieldValueFunc(fieldName)
		fieldValuesToIDs := make(map[string][]ID)
		for _, entity := range missingEntities {
			allEntities = append(allEntities, entity)
			id := ormRepository.GetPrimaryKey(entity)
			go g.cacheRepository.Cache(ctx, baseField, id, entity)
			if hasGetter {
				fieldValue := fieldGetter(entity)
				fieldValuesToIDs[fieldValue] = append(fieldValuesToIDs[fieldValue], id)
			}
		}
		if len(fieldValuesToIDs) > 0 {
			go g.cacheRepository.CacheFieldToIDs(ctx, baseField, fieldName, fieldValuesToIDs)
		}
	}
	return allEntities, nil
}
// getKey builds the cache key for a single entity: "<prefix>::<value>", with
// value rendered in its default (%v) format.
func (r *HadeCacheRepository[T, ID]) getKey(prefix string, value any) string {
	return prefix + "::" + fmt.Sprint(value)
}
// getKeyWithField builds the cache key for a field-value lookup:
// "<prefix>::<fieldPrefix>::<value>", with value rendered in its default (%v)
// format.
func (r *HadeCacheRepository[T, ID]) getKeyWithField(prefix, fieldPrefix string, value any) string {
	return prefix + "::" + fieldPrefix + "::" + fmt.Sprint(value)
}
// Cache stores entity under the key "prefix::id" through the cache service's
// SetObj, using the repository's configured expiration.
func (r *HadeCacheRepository[T, ID]) Cache(ctx context.Context, prefix string, id ID, entity *T) error {
	return r.cacheService.SetObj(ctx, r.getKey(prefix, id), entity, r.cacheExpiration)
}
// CacheEvict deletes the cached entity stored under "prefix::id" and returns
// the cache service's error.
func (r *HadeCacheRepository[T, ID]) CacheEvict(ctx context.Context, prefix string, id ID) error {
	return r.cacheService.Del(ctx, r.getKey(prefix, id))
}
// FindFromCache decodes the entity cached under "prefix::id" via the cache
// service's GetObj. Any error from GetObj is returned with a nil entity.
func (r *HadeCacheRepository[T, ID]) FindFromCache(ctx context.Context, prefix string, id ID) (*T, error) {
	cacheKey := r.getKey(prefix, id)
	entity := new(T)
	if err := r.cacheService.GetObj(ctx, cacheKey, entity); err != nil {
		return nil, err
	}
	return entity, nil
}
// FindBatchByIds looks up the cached entities for ids in a single GetMany call.
//
// The returned slice has len(ids) elements and is aligned with ids: results[i]
// is the entity cached for ids[i], or nil when that key is absent, empty, or
// cannot be decoded as JSON into T. Treating an undecodable entry as a miss
// lets callers reload it from the database (and overwrite the bad entry)
// instead of failing the whole batch.
//
// An error is returned only when GetMany itself fails. An empty ids returns an
// empty slice without contacting the cache, because a multi-key GET with no
// keys is rejected by some backends (e.g. Redis MGET).
func (r *HadeCacheRepository[T, ID]) FindBatchByIds(ctx context.Context, prefix string, ids []ID) ([]*T, error) {
	if len(ids) == 0 {
		return make([]*T, 0), nil
	}
	keys := make([]string, len(ids))
	for i, id := range ids {
		keys[i] = r.getKey(prefix, id)
	}
	keyValueMap, err := r.cacheService.GetMany(ctx, keys)
	if err != nil {
		return nil, err
	}
	results := make([]*T, len(ids))
	for i, key := range keys {
		data, ok := keyValueMap[key]
		if !ok || data == "" {
			continue // cache miss: leave results[i] nil
		}
		entity := new(T)
		if err := json.Unmarshal([]byte(data), entity); err != nil {
			continue // corrupt or outdated entry: treat as a miss
		}
		results[i] = entity
	}
	return results, nil
}
// CacheFieldToID stores a single ID under "prefix::fieldPrefix::fieldValue"
// through SetObj, using the repository's configured expiration.
func (r *HadeCacheRepository[T, ID]) CacheFieldToID(ctx context.Context, prefix, fieldPrefix, fieldValue string, id ID) error {
	return r.cacheService.SetObj(ctx, r.getKeyWithField(prefix, fieldPrefix, fieldValue), id, r.cacheExpiration)
}
// CacheEvictFieldToID deletes the field-value mapping stored under
// "prefix::fieldPrefix::fieldValue" and returns the cache service's error.
func (r *HadeCacheRepository[T, ID]) CacheEvictFieldToID(ctx context.Context, prefix, fieldPrefix, fieldValue string) error {
	return r.cacheService.Del(ctx, r.getKeyWithField(prefix, fieldPrefix, fieldValue))
}
// FindIDFromCache reads the single ID cached under
// "prefix::fieldPrefix::fieldValue" via GetObj. The error from GetObj is
// returned together with the ID value as GetObj left it (the zero ID if
// nothing was decoded).
func (r *HadeCacheRepository[T, ID]) FindIDFromCache(ctx context.Context, prefix, fieldPrefix, fieldValue string) (ID, error) {
	var id ID
	cacheKey := r.getKeyWithField(prefix, fieldPrefix, fieldValue)
	if err := r.cacheService.GetObj(ctx, cacheKey, &id); err != nil {
		return id, err
	}
	return id, nil
}
// CacheFieldToIDs stores, for each field value in fieldValuesToIDs, the JSON
// array of its IDs under "prefix::fieldPrefix::value". All entries are written
// in one SetMany call using the repository's configured expiration.
//
// An empty map is a no-op, because a multi-key SET with no pairs is rejected
// by some backends (e.g. Redis MSET). A JSON encoding error is returned
// wrapped, identifying the field value, and nothing is written in that case.
func (r *HadeCacheRepository[T, ID]) CacheFieldToIDs(ctx context.Context, prefix, fieldPrefix string, fieldValuesToIDs map[string][]ID) error {
	if len(fieldValuesToIDs) == 0 {
		return nil
	}
	data := make(map[string]string, len(fieldValuesToIDs))
	for fieldValue, ids := range fieldValuesToIDs {
		idsBytes, err := json.Marshal(ids)
		if err != nil {
			return fmt.Errorf("encoding ids for %s=%q: %w", fieldPrefix, fieldValue, err)
		}
		data[r.getKeyWithField(prefix, fieldPrefix, fieldValue)] = string(idsBytes)
	}
	return r.cacheService.SetMany(ctx, data, r.cacheExpiration)
}
// CacheEvictFieldsToIDsBatch deletes the field-value -> IDs mappings for every
// value in fieldValues (keys "prefix::fieldPrefix::value") in one DelMany
// call.
//
// An empty fieldValues is a no-op, because a multi-key DEL with no keys is
// rejected by some backends (e.g. Redis DEL).
func (r *HadeCacheRepository[T, ID]) CacheEvictFieldsToIDsBatch(ctx context.Context, prefix, fieldPrefix string, fieldValues []string) error {
	if len(fieldValues) == 0 {
		return nil
	}
	keys := make([]string, 0, len(fieldValues))
	for _, fieldValue := range fieldValues {
		keys = append(keys, r.getKeyWithField(prefix, fieldPrefix, fieldValue))
	}
	return r.cacheService.DelMany(ctx, keys)
}
// FindIDsFromCache reads the ID list cached under
// "prefix::fieldPrefix::fieldValue" via GetObj. The error from GetObj is
// returned together with the slice as GetObj left it (nil if nothing was
// decoded).
func (r *HadeCacheRepository[T, ID]) FindIDsFromCache(ctx context.Context, prefix, fieldPrefix, fieldValue string) ([]ID, error) {
	var idList []ID
	cacheKey := r.getKeyWithField(prefix, fieldPrefix, fieldValue)
	if err := r.cacheService.GetObj(ctx, cacheKey, &idList); err != nil {
		return idList, err
	}
	return idList, nil
}