framework1/app/infrastructure/user.go

144 lines
4.4 KiB
Go
Raw Normal View History

2024-10-30 16:56:41 +08:00
package infrastructure
import (
"context"
"github.com/Superdanda/hade/app/provider/database_connect"
userModule "github.com/Superdanda/hade/app/provider/user"
"github.com/Superdanda/hade/framework"
"github.com/Superdanda/hade/framework/contract"
"gorm.io/gorm"
)
// UserRepository is the user-module repository implementation backed by a
// GORM connection. It provides the direct database accessors (SaveToDB,
// FindByIDFromDB, ...) and the field lookup hooks defined in this file, and
// delegates Save / FindById to the generic repository resolved from the
// container.
//
// NOTE(review): the two embedded interfaces are not assigned a value in
// NewUserRepository, so any promoted method that is not overridden on
// *UserRepository would be called on a nil interface value — confirm that
// every method reachable by callers is implemented explicitly.
type UserRepository struct {
// container resolves framework services on demand (see Save / FindById).
container framework.Container
// db is the GORM handle used by the *FromDB methods and the field queries.
db *gorm.DB
// Embedded so that *UserRepository satisfies the generic ORM repository contract.
contract.OrmRepository[userModule.User, int64]
// Embedded so that *UserRepository satisfies the user module's repository interface.
userModule.Repository
}
// NewUserRepository builds a UserRepository on top of the default database
// connection and registers it with the infrastructure service under
// userModule.UserKey before returning it.
//
// Both services are resolved with MustMake and an unchecked type assertion,
// so a missing or mistyped binding panics.
func NewUserRepository(container framework.Container) contract.OrmRepository[userModule.User, int64] {
	// Resolve the default database connection first.
	dbService := container.MustMake(database_connect.DatabaseConnectKey).(database_connect.Service)
	repo := &UserRepository{
		container: container,
		db:        dbService.DefaultDatabaseConnect(),
	}

	// Register the repository under the user key.
	infra := container.MustMake(contract.InfrastructureKey).(contract.InfrastructureService)
	infra.RegisterOrmRepository(userModule.UserKey, repo)

	return repo
}
// SaveToDB writes entity to the database through GORM's Save.
//
// It returns the error reported by GORM, or nil when the write succeeded.
// Previously the result of Save was discarded and nil was always returned,
// which hid failed writes from the caller.
func (u *UserRepository) SaveToDB(entity *userModule.User) error {
	return u.db.Save(entity).Error
}
// FindByIDFromDB loads the user whose primary key equals id.
//
// It returns (nil, err) when the query itself fails. Previously the query
// error was ignored and a zero-value user was returned with a nil error.
//
// The lookup deliberately keeps using Find: per the GORM documentation Find
// does not report ErrRecordNotFound, so a missing row still yields a
// zero-value user and a nil error, exactly as before.
func (u *UserRepository) FindByIDFromDB(id int64) (*userModule.User, error) {
	user := &userModule.User{}
	if err := u.db.Find(user, id).Error; err != nil {
		return nil, err
	}
	return user, nil
}
// FindByID64sFromDB loads every user whose id is contained in ids.
//
// It returns (nil, err) when the query fails, otherwise the matched users
// and a nil error.
func (u *UserRepository) FindByID64sFromDB(ids []int64) ([]*userModule.User, error) {
	var found []*userModule.User

	// Select the rows whose id is in the given list.
	result := u.db.Where("id IN ?", ids).Find(&found)
	if result.Error != nil {
		return nil, result.Error
	}
	return found, nil
}
// GetPrimaryKey returns the primary key of entity, i.e. its ID field.
func (u *UserRepository) GetPrimaryKey(entity *userModule.User) int64 {
	id := entity.ID
	return id
}
// GetBaseField returns the base key for this repository, which is
// userModule.UserKey.
func (u *UserRepository) GetBaseField() string {
	baseField := userModule.UserKey
	return baseField
}
// GetFieldQueryFunc returns a function that fetches all users whose given
// field equals a single value.
//
// Supported field names are "Email" and "UserName". For any other name it
// returns (nil, false). The returned function yields (nil, err) when the
// query fails.
func (u *UserRepository) GetFieldQueryFunc(fieldName string) (func(value string) ([]*userModule.User, error), bool) {
	// Pick the equality condition for the requested field.
	var condition string
	switch fieldName {
	case "Email":
		condition = "email = ?"
	case "UserName":
		condition = "user_name = ?"
	default:
		// Unsupported field name.
		return nil, false
	}

	lookup := func(value string) ([]*userModule.User, error) {
		var matched []*userModule.User
		result := u.db.Where(condition, value).Find(&matched)
		if result.Error != nil {
			return nil, result.Error
		}
		return matched, nil
	}
	return lookup, true
}
// GetFieldInQueryFunc returns a function that fetches all users whose given
// field matches any of a list of values.
//
// Supported field names are "Email" and "UserName". For any other name it
// returns (nil, false). The returned function yields (nil, err) when the
// query fails.
func (u *UserRepository) GetFieldInQueryFunc(fieldName string) (func(values []string) ([]*userModule.User, error), bool) {
	// Pick the IN condition for the requested field.
	var condition string
	switch fieldName {
	case "Email":
		condition = "email IN ?"
	case "UserName":
		condition = "user_name IN ?"
	default:
		// Unsupported field name.
		return nil, false
	}

	batchLookup := func(values []string) ([]*userModule.User, error) {
		var matched []*userModule.User
		result := u.db.Where(condition, values).Find(&matched)
		if result.Error != nil {
			return nil, result.Error
		}
		return matched, nil
	}
	return batchLookup, true
}
// GetFieldValueFunc returns an accessor that reads the named field from a
// user entity.
//
// Supported field names are "Email" and "UserName". For any other name it
// returns (nil, false).
func (u *UserRepository) GetFieldValueFunc(fieldName string) (func(entity *userModule.User) string, bool) {
	if fieldName == "Email" {
		readEmail := func(entity *userModule.User) string {
			return entity.Email
		}
		return readEmail, true
	}
	if fieldName == "UserName" {
		readUserName := func(entity *userModule.User) string {
			return entity.UserName
		}
		return readUserName, true
	}
	// Unsupported field name.
	return nil, false
}
// Save stores user through the generic repository service resolved from the
// container, using userModule.UserKey as the repository key.
//
// The service is resolved with MustMake and an unchecked type assertion, so
// a missing or mistyped binding panics. Any error from the underlying Save
// is returned unchanged.
func (u *UserRepository) Save(ctx context.Context, user *userModule.User) error {
	service := u.container.MustMake(contract.RepositoryKey)
	generic := service.(contract.Repository[userModule.User, int64])

	err := generic.Save(ctx, userModule.UserKey, user)
	if err != nil {
		return err
	}
	return nil
}
// FindById looks up the user with the given id through the generic
// repository service resolved from the container, using userModule.UserKey
// as the repository key.
//
// The service is resolved with MustMake and an unchecked type assertion, so
// a missing or mistyped binding panics. On error it returns (nil, err).
func (u *UserRepository) FindById(ctx context.Context, id int64) (*userModule.User, error) {
	service := u.container.MustMake(contract.RepositoryKey)
	generic := service.(contract.Repository[userModule.User, int64])

	found, err := generic.FindByID(ctx, userModule.UserKey, id)
	if err != nil {
		return nil, err
	}
	return found, nil
}