From 9a1c264368c7c8f3345f01c872b007917abd1763 Mon Sep 17 00:00:00 2001 From: ydy <517697206@qq.com> Date: Sat, 15 Feb 2025 14:35:59 +0800 Subject: [PATCH 1/7] feat:Add multi db switching --- constants/constants.go | 9 +-- gplus/cache.go | 34 +++++---- gplus/dao.go | 50 +++++++++++-- gplus/option.go | 36 +++++++++- gplus/query.go | 17 ++++- gplus/tool.go | 22 +++--- tests/dao_test.go | 156 +++++++++++++++++++++++++++++++++++++++++ 7 files changed, 288 insertions(+), 36 deletions(-) diff --git a/constants/constants.go b/constants/constants.go index 4fe8d1d..bd02c53 100644 --- a/constants/constants.go +++ b/constants/constants.go @@ -18,8 +18,9 @@ package constants const ( - Comma = "," - LeftBracket = "(" - RightBracket = ")" - DefaultPrimaryName = "id" + Comma = "," + LeftBracket = "(" + RightBracket = ")" + DefaultPrimaryName = "id" + DefaultGormPlusConnName = "DefaultGormPlusConnName" //内置的gorm-plus 数据库连接名 ) diff --git a/gplus/cache.go b/gplus/cache.go index feed184..bb987c3 100644 --- a/gplus/cache.go +++ b/gplus/cache.go @@ -18,6 +18,7 @@ package gplus import ( + "github.com/acmestack/gorm-plus/constants" "gorm.io/gorm/schema" "reflect" "sync" @@ -31,9 +32,9 @@ var columnNameCache sync.Map var modelInstanceCache sync.Map // Cache 缓存实体对象所有的字段名 -func Cache(models ...any) { +func Cache(dbConnName string, models ...any) { for _, model := range models { - columnNameMap := getColumnNameMap(model) + columnNameMap := getColumnNameMap(model, dbConnName) for pointer, columnName := range columnNameMap { columnNameCache.Store(pointer, columnName) } @@ -43,7 +44,7 @@ func Cache(models ...any) { } } -func getColumnNameMap(model any) map[uintptr]string { +func getColumnNameMap(model any, dbConnName string) map[uintptr]string { var columnNameMap = make(map[uintptr]string) valueOf := reflect.ValueOf(model).Elem() typeOf := reflect.TypeOf(model).Elem() @@ -52,14 +53,14 @@ func getColumnNameMap(model any) map[uintptr]string { // 如果当前实体嵌入了其他实体,同样需要缓存它的字段名 if field.Anonymous { // 如果存在多重嵌套,通过递归方式获取他们的字段名 - subFieldMap := getSubFieldColumnNameMap(valueOf, field) + subFieldMap := getSubFieldColumnNameMap(valueOf, field, dbConnName) for pointer, columnName := range subFieldMap { columnNameMap[pointer] = columnName } } else { // 获取对象字段指针值 pointer := valueOf.Field(i).Addr().Pointer() - columnName := parseColumnName(field) + columnName := parseColumnName(field, dbConnName) columnNameMap[pointer] = columnName } } @@ -68,6 +69,11 @@ func getColumnNameMap(model any) map[uintptr]string { // GetModel 获取 func GetModel[T any]() *T { + return GetModelBaseDb[T]("") +} + +// GetModelBaseDb 获取根据数据库连接名 +func GetModelBaseDb[T any](dbConnName string) *T { modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String() if model, ok := modelInstanceCache.Load(modelTypeStr); ok { m, isReal := model.(*T) @@ -76,12 +82,12 @@ func GetModel[T any]() *T { } } t := new(T) - Cache(t) + Cache(dbConnName, t) return t } // 递归获取嵌套字段名 -func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField) map[uintptr]string { +func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField, dbConnName string) map[uintptr]string { result := make(map[uintptr]string) modelType := field.Type if modelType.Kind() == reflect.Ptr { @@ -90,13 +96,13 @@ func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField) for j := 0; j < modelType.NumField(); j++ { subField := modelType.Field(j) if subField.Anonymous { - nestedFields := getSubFieldColumnNameMap(valueOf, subField) + nestedFields := getSubFieldColumnNameMap(valueOf, subField, dbConnName) for key, value := range nestedFields { result[key] = value } } else { pointer := valueOf.FieldByName(modelType.Field(j).Name).Addr().Pointer() - name := parseColumnName(modelType.Field(j)) + name := parseColumnName(modelType.Field(j), dbConnName) result[pointer] = name } } @@ -104,14 +110,18 @@ func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField) return result } -// 解析字段名称 -func parseColumnName(field reflect.StructField) string { +// 解析字段名称 兼容多数据库切换 +func parseColumnName(field reflect.StructField, dbConnName string) string { tagSetting := schema.ParseTagSetting(field.Tag.Get("gorm"), ";") name, ok := tagSetting["COLUMN"] if ok { return name } - return globalDb.Config.NamingStrategy.ColumnName("", field.Name) + if len(dbConnName) == 0 { + dbConnName = constants.DefaultGormPlusConnName + } + db, _ := GetDb(dbConnName) + return db.Config.NamingStrategy.ColumnName("", field.Name) } func getColumnName(v any) string { diff --git a/gplus/dao.go b/gplus/dao.go index 341496a..18573e8 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -18,6 +18,7 @@ package gplus import ( "database/sql" + "errors" "fmt" "github.com/acmestack/gorm-plus/constants" "gorm.io/gorm" @@ -28,11 +29,33 @@ import ( "time" ) -var globalDb *gorm.DB +var globalDbMap = make(map[string]*gorm.DB) var defaultBatchSize = 1000 func Init(db *gorm.DB) { - globalDb = db + InitDb(db, constants.DefaultGormPlusConnName) +} + +func InitDb(db *gorm.DB, dbConnName string) error { + if len(dbConnName) == 0 { + return errors.New("InitMultiple dbConnName is empty please check") + } + _, exists := globalDbMap[dbConnName] + if !exists { + // db instance register to global variable + globalDbMap[dbConnName] = db + return nil + } + return errors.New("InitMultiple have same name:" + dbConnName + ",please check") +} + +// GetDb 获取数据库连接 +func GetDb(dbConnName string) (*gorm.DB, error) { + db, exists := globalDbMap[dbConnName] + if exists { + return db, nil + } + return nil, errors.New("MultipleDbChange not exists dbConn:" + dbConnName + ",please check") } type Page[T any] struct { @@ -49,6 +72,10 @@ func (dao Dao[T]) NewQuery() (*QueryCond[T], *T) { return NewQuery[T]() } +func (dao Dao[T]) NewQueryBaseDb(opt OptionFunc) (*QueryCond[T], *T) { + return NewQueryBaseDb[T](opt) +} + func NewPage[T any](current, size int) *Page[T] { if current <= 0 { current = 1 @@ -157,7 +184,7 @@ func UpdateZeroById[T any](entity *T, opts ...OptionFunc) *gorm.DB { func updateAllIfNeed(entity any, opts []OptionFunc, db *gorm.DB) { option := getOption(opts) if len(option.Selects) == 0 { - columnNameMap := getColumnNameMap(entity) + columnNameMap := getColumnNameMap(entity, option.DbConnName) var columnNames []string for _, columnName := range columnNameMap { columnNames = append(columnNames, columnName) @@ -449,12 +476,19 @@ func buildSqlAndArgs[T any](expressions []any, sqlBuilder *strings.Builder, quer } func getDb(opts ...OptionFunc) *gorm.DB { + var db *gorm.DB option := getOption(opts) - // Clauses()目的是为了初始化Db,如果db已经被初始化了,会直接返回db - var db = globalDb.Clauses() if option.Db != nil { db = option.Db.Clauses() + } else { + if len(option.DbConnName) == 0 { + option.DbConnName = constants.DefaultGormPlusConnName + } + + db, _ = GetDb(option.DbConnName) + // Clauses()目的是为了初始化Db,如果db已经被初始化了,会直接返回db + db = db.Clauses() } // 设置需要忽略的字段 @@ -474,6 +508,12 @@ func getOption(opts []OptionFunc) Option { return config } +func getOneOption(opt OptionFunc) Option { + var config Option + opt(&config) + return config +} + func setSelectIfNeed(option Option, db *gorm.DB) { if len(option.Selects) > 0 { var columnNames []string diff --git a/gplus/option.go b/gplus/option.go index 1a89ee7..f4941cb 100644 --- a/gplus/option.go +++ b/gplus/option.go @@ -17,13 +17,17 @@ package gplus -import "gorm.io/gorm" +import ( + "github.com/acmestack/gorm-plus/constants" + "gorm.io/gorm" +) type Option struct { Db *gorm.DB Selects []any Omits []any IgnoreTotal bool + DbConnName string } type OptionFunc func(*Option) @@ -35,10 +39,12 @@ func Db(db *gorm.DB) OptionFunc { } } -// Session 创建回话 +// Session 创建会话 func Session(session *gorm.Session) OptionFunc { return func(o *Option) { - o.Db = globalDb.Session(session) + //兼容之前的设计 + db, _ := GetDb(constants.DefaultGormPlusConnName) + o.Db = db.Session(session) } } @@ -62,3 +68,27 @@ func IgnoreTotal() OptionFunc { o.IgnoreTotal = true } } + +// DbConnName 多个数据库连接根据自定义连接名称选择切换 +func DbConnName(dbConnName string) OptionFunc { + return func(o *Option) { + o.DbConnName = dbConnName + } +} + +// DbSessionBaseName 创建特定的Db会话 +func DbSessionBaseName(dbConnName string, session *gorm.Session) OptionFunc { + return func(o *Option) { + o.DbConnName = dbConnName + db, _ := GetDb(dbConnName) + o.Db = db.Session(session) + } +} + +// DbBaseName 使用特定的Db对象 +func DbBaseName(dbConnName string) OptionFunc { + return func(o *Option) { + o.DbConnName = dbConnName + o.Db, _ = GetDb(dbConnName) + } +} diff --git a/gplus/query.go b/gplus/query.go index d40a2b8..badfe07 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -47,6 +47,11 @@ func (q *QueryCond[T]) getSqlSegment() string { // NewQuery 构建查询条件 func NewQuery[T any]() (*QueryCond[T], *T) { + return NewQueryBaseDb[T](DbBaseName(constants.DefaultGormPlusConnName)) +} + +// NewQueryBaseDb 构建查询条件 +func NewQueryBaseDb[T any](opt OptionFunc) (*QueryCond[T], *T) { q := &QueryCond[T]{} modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String() if model, ok := modelInstanceCache.Load(modelTypeStr); ok { @@ -56,12 +61,18 @@ func NewQuery[T any]() (*QueryCond[T], *T) { } } m := new(T) - Cache(m) + option := getOneOption(opt) + Cache(option.DbConnName, m) return q, m } // NewQueryModel 构建查询条件 func NewQueryModel[T any, R any]() (*QueryCond[T], *T, *R) { + return NewQueryModelBaseDb[T, R]("") +} + +// NewQueryModelBaseDb 构建查询条件 +func NewQueryModelBaseDb[T any, R any](dbConnName string) (*QueryCond[T], *T, *R) { q := &QueryCond[T]{} var t *T var r *R @@ -83,12 +94,12 @@ func NewQueryModel[T any, R any]() (*QueryCond[T], *T, *R) { if t == nil { t = new(T) - Cache(t) + Cache(dbConnName, t) } if r == nil { r = new(R) - Cache(r) + Cache(dbConnName, r) } return q, t, r diff --git a/gplus/tool.go b/gplus/tool.go index 5e7917c..19933b7 100644 --- a/gplus/tool.go +++ b/gplus/tool.go @@ -56,12 +56,16 @@ var builders = map[string]func(query *QueryCond[any], name string, value any){ } func BuildQuery[T any](queryParams url.Values) *QueryCond[T] { + return BuildQueryBaseDb[T](queryParams, "") +} + +func BuildQueryBaseDb[T any](queryParams url.Values, dbConnName string) *QueryCond[T] { columnCondMap, conditionMap, gcond := parseParams(queryParams) parentQuery := buildParentQuery[T](conditionMap) - queryCondMap := buildQueryCondMap[T](columnCondMap) + queryCondMap := buildQueryCondMap[T](columnCondMap, dbConnName) // 如果没有分组条件,直接返回默认的查询条件 if len(gcond) == 0 { @@ -159,9 +163,9 @@ func getCurrentOp(value string) string { return currentOperator } -func buildQueryCondMap[T any](columnCondMap map[string][]*Condition) map[string]*QueryCond[T] { +func buildQueryCondMap[T any](columnCondMap map[string][]*Condition, dbConnName string) map[string]*QueryCond[T] { var queryCondMap = make(map[string]*QueryCond[T]) - columnTypeMap := getColumnTypeMap[T]() + columnTypeMap := getColumnTypeMap[T](dbConnName) for key, conditions := range columnCondMap { query := &QueryCond[any]{} query.columnTypeMap = columnTypeMap @@ -273,7 +277,7 @@ func buildGroupQuery[T any](gcond string, queryMaps map[string]*QueryCond[T], qu return query } -func getColumnTypeMap[T any]() map[string]reflect.Type { +func getColumnTypeMap[T any](dbConnName string) map[string]reflect.Type { modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String() if model, ok := columnTypeCache.Load(modelTypeStr); ok { if columnNameMap, isOk := model.(map[string]reflect.Type); isOk { @@ -285,19 +289,19 @@ func getColumnTypeMap[T any]() map[string]reflect.Type { for i := 0; i < typeOf.NumField(); i++ { field := typeOf.Field(i) if field.Anonymous { - nestedFields := getSubFieldColumnTypeMap(field) + nestedFields := getSubFieldColumnTypeMap(field, dbConnName) for key, value := range nestedFields { columnTypeMap[key] = value } } - columnName := parseColumnName(field) + columnName := parseColumnName(field, dbConnName) columnTypeMap[columnName] = field.Type } columnTypeCache.Store(modelTypeStr, columnTypeMap) return columnTypeMap } -func getSubFieldColumnTypeMap(field reflect.StructField) map[string]reflect.Type { +func getSubFieldColumnTypeMap(field reflect.StructField, dbConnName string) map[string]reflect.Type { columnTypeMap := make(map[string]reflect.Type) modelType := field.Type if modelType.Kind() == reflect.Ptr { @@ -306,12 +310,12 @@ func getSubFieldColumnTypeMap(field reflect.StructField) map[string]reflect.Type for j := 0; j < modelType.NumField(); j++ { subField := modelType.Field(j) if subField.Anonymous { - nestedFields := getSubFieldColumnTypeMap(subField) + nestedFields := getSubFieldColumnTypeMap(subField, dbConnName) for key, value := range nestedFields { columnTypeMap[key] = value } } else { - columnName := parseColumnName(subField) + columnName := parseColumnName(subField, dbConnName) columnTypeMap[columnName] = subField.Type } } diff --git a/tests/dao_test.go b/tests/dao_test.go index 8ad5133..25d8d23 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -31,6 +31,7 @@ import ( ) var gormDb *gorm.DB +var gormDbConnName = "test1" func init() { dsn := "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8mb4&parseTime=True&loc=Local" @@ -44,6 +45,21 @@ func init() { var u User gormDb.AutoMigrate(u) gplus.Init(gormDb) + initDb() +} + +func initDb() { + dsn := "root:123456@tcp(127.0.0.1:3306)/test1?charset=utf8mb4&parseTime=True&loc=Local" + var err error + gormDb1, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Info), + }) + if err != nil { + fmt.Println(err) + } + var u User + gormDb1.AutoMigrate(u) + gplus.InitDb(gormDb1, gormDbConnName) } func TestInsert(t *testing.T) { @@ -589,12 +605,152 @@ func TestTx(t *testing.T) { } } +func TestInsertBaseDb(t *testing.T) { + deleteOldDataBaseDb() + + user := &User{Username: "afumu", Password: "123456", Age: 18, Score: 100, Dept: "开发部门"} + resultDb := gplus.Insert(user, gplus.DbBaseName(gormDbConnName)) + + if resultDb.Error != nil { + t.Fatalf("errors happened when insert: %v", resultDb.Error) + } else if resultDb.RowsAffected != 1 { + t.Fatalf("rows affected expects: %v, got %v", 1, resultDb.RowsAffected) + } + + newUser, db := gplus.SelectById[User](user.ID, gplus.DbBaseName(gormDbConnName)) + if db.Error != nil { + t.Fatalf("errors happened when SelectById: %v", db.Error) + } + AssertObjEqual(t, newUser, user, "ID", "Username", "Password", "Address", "Age", "Phone", "Score", "Dept", "CreatedAt", "UpdatedAt") +} + +func TestInsertBatchBaseDb(t *testing.T) { + deleteOldDataBaseDb() + users := getUsers() + resultDb := gplus.InsertBatch[User](users, gplus.DbBaseName(gormDbConnName)) + if resultDb.RowsAffected != int64(len(users)) { + t.Errorf("affected rows should be %v, but got %v", len(users), resultDb.RowsAffected) + } + + for _, user := range users { + newUser, db := gplus.SelectById[User](user.ID, gplus.DbBaseName(gormDbConnName)) + if db.Error != nil { + t.Fatalf("errors happened when SelectById: %v", db.Error) + } + AssertObjEqual(t, newUser, user, "ID", "Username", "Password", "Address", "Age", "Phone", "Score", "Dept", "CreatedAt", "UpdatedAt") + } +} + +func TestDeleteByIdBaseDb(t *testing.T) { + deleteOldDataBaseDb() + users := getUsers() + gplus.InsertBatchSize[User](users, 2, gplus.DbBaseName(gormDbConnName)) + + if res := gplus.DeleteById[User](users[1].ID, gplus.DbBaseName(gormDbConnName)); res.Error != nil || res.RowsAffected != 1 { + t.Errorf("errors happened when deleteById: %v, affected: %v", res.Error, res.RowsAffected) + } + + _, resultDb := gplus.SelectById[User](users[1].ID, gplus.DbBaseName(gormDbConnName)) + if !errors.Is(resultDb.Error, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", resultDb.Error) + } +} + +func TestDeleteBaseDb(t *testing.T) { + deleteOldDataBaseDb() + users := getUsers() + opt := gplus.DbBaseName(gormDbConnName) + gplus.InsertBatch[User](users, opt) + + query, u := gplus.NewQueryBaseDb[User](opt) + query.Eq(&u.Username, "afumu1") + if res := gplus.Delete[User](query, gplus.DbBaseName(gormDbConnName)); res.Error != nil || res.RowsAffected != 1 { + t.Errorf("errors happened when Delete: %v, affected: %v", res.Error, res.RowsAffected) + } + + _, resultDb := gplus.SelectOne[User](query, gplus.DbBaseName(gormDbConnName)) + if !errors.Is(resultDb.Error, gorm.ErrRecordNotFound) { + t.Errorf("should returns record not found error, but got %v", resultDb.Error) + } +} + +func TestUpdateByIdBaseDb(t *testing.T) { + deleteOldDataBaseDb() + users := getUsers() + gplus.InsertBatch[User](users, gplus.DbBaseName(gormDbConnName)) + + user := users[0] + user.Score = 100 + user.Age = 25 + + if res := gplus.UpdateById[User](user, gplus.DbBaseName(gormDbConnName)); res.Error != nil || res.RowsAffected != 1 { + t.Errorf("errors happened when deleteByIds: %v, affected: %v", res.Error, res.RowsAffected) + } + + newUser, db := gplus.SelectById[User](user.ID, gplus.DbBaseName(gormDbConnName)) + if db.Error != nil { + t.Fatalf("errors happened when SelectById: %v", db.Error) + } + AssertObjEqual(t, newUser, user, "ID", "Username", "Password", "Address", "Age", "Phone", "Score", "Dept", "CreatedAt", "UpdatedAt") + +} + +func TestSelectByIdBaseDb(t *testing.T) { + deleteOldDataBaseDb() + users := getUsers() + gplus.InsertBatch[User](users, gplus.DbBaseName(gormDbConnName)) + user := users[0] + resultUser, db := gplus.SelectById[User](user.ID, gplus.DbBaseName(gormDbConnName)) + if db.Error != nil { + t.Errorf("errors happened when selectById : %v", db.Error) + } else { + AssertObjEqual(t, resultUser, user, "ID", "Username", "Password", "Address", "Age", "Phone", "Score", "Dept", "CreatedAt", "UpdatedAt") + } +} + +func TestSelectGeneric6BaseName(t *testing.T) { + deleteOldDataBaseDb() + users := getUsers() + opt := gplus.DbBaseName(gormDbConnName) + gplus.InsertBatch[User](users, opt) + type UserVo struct { + Dept string + Score int + } + var userMap = make(map[string]int) + for _, user := range users { + userMap[user.Dept] += user.Score + } + query, u := gplus.NewQueryBaseDb[User](opt) + uvo := gplus.GetModelBaseDb[UserVo](gormDbConnName) + query.Select(&u.Dept, gplus.Sum(&u.Score).As(&uvo.Score)).Group(&u.Dept) + UserVos, resultDb := gplus.SelectGeneric[User, []UserVo](query, gplus.DbBaseName(gormDbConnName)) + + if resultDb.Error != nil { + t.Errorf("errors happened when resultDb : %v", resultDb.Error) + } + + for _, userVo := range UserVos { + score := userMap[userVo.Dept] + if userVo.Score != score { + t.Errorf("errors happened when SelectGeneric") + } + } +} + func deleteOldData() { q, u := gplus.NewQuery[User]() q.IsNotNull(&u.ID) gplus.Delete(q) } +func deleteOldDataBaseDb() { + opt := gplus.DbBaseName(gormDbConnName) + q, u := gplus.NewQueryBaseDb[User](opt) + q.IsNotNull(&u.ID) + gplus.Delete(q, gplus.DbBaseName(gormDbConnName)) +} + func getUsers() []*User { user1 := &User{Username: "afumu1", Password: "123456", Age: 18, Score: 12, Dept: "开发部门"} user2 := &User{Username: "afumu2", Password: "123456", Age: 16, Score: 34, Dept: "行政部门"} From c6d7436155dccc718c6854421e9364f322b2ab34 Mon Sep 17 00:00:00 2001 From: ydy <517697206@qq.com> Date: Sat, 15 Feb 2025 15:18:52 +0800 Subject: [PATCH 2/7] feat:change NewQuery and GetModel for getDbConnName --- gplus/cache.go | 20 ++++++++++++++++---- gplus/dao.go | 2 ++ gplus/query.go | 3 ++- tests/dao_test.go | 13 +++++++++++-- 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/gplus/cache.go b/gplus/cache.go index bb987c3..5a25606 100644 --- a/gplus/cache.go +++ b/gplus/cache.go @@ -69,11 +69,12 @@ func getColumnNameMap(model any, dbConnName string) map[uintptr]string { // GetModel 获取 func GetModel[T any]() *T { - return GetModelBaseDb[T]("") + dbConnName := getDefaultDbConnName() + return GetModelBaseDb[T](DbBaseName(dbConnName)) } // GetModelBaseDb 获取根据数据库连接名 -func GetModelBaseDb[T any](dbConnName string) *T { +func GetModelBaseDb[T any](opt OptionFunc) *T { modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String() if model, ok := modelInstanceCache.Load(modelTypeStr); ok { m, isReal := model.(*T) @@ -82,7 +83,8 @@ func GetModelBaseDb[T any](dbConnName string) *T { } } t := new(T) - Cache(dbConnName, t) + option := getOneOption(opt) + Cache(option.DbConnName, t) return t } @@ -118,7 +120,7 @@ func parseColumnName(field reflect.StructField, dbConnName string) string { return name } if len(dbConnName) == 0 { - dbConnName = constants.DefaultGormPlusConnName + dbConnName = getDefaultDbConnName() } db, _ := GetDb(dbConnName) return db.Config.NamingStrategy.ColumnName("", field.Name) @@ -142,3 +144,13 @@ func getColumnName(v any) string { } return columnName } + +func getDefaultDbConnName() string { + dbConnName := constants.DefaultGormPlusConnName + //如果用户没传数据库连接名称,优先从全局globalDbKeys里获取第一个连接名 + //避免用户使用InitDb方法初始化数据库 自定义数据库连接名,然后方法里不传是哪个数据库连接名 则只能默认取第一条 + if len(globalDbKeys) >= 1 { + dbConnName = globalDbKeys[0] + } + return dbConnName +} diff --git a/gplus/dao.go b/gplus/dao.go index 18573e8..d9c5152 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -30,6 +30,7 @@ import ( ) var globalDbMap = make(map[string]*gorm.DB) +var globalDbKeys []string var defaultBatchSize = 1000 func Init(db *gorm.DB) { @@ -44,6 +45,7 @@ func InitDb(db *gorm.DB, dbConnName string) error { if !exists { // db instance register to global variable globalDbMap[dbConnName] = db + globalDbKeys = append(globalDbKeys, dbConnName) return nil } return errors.New("InitMultiple have same name:" + dbConnName + ",please check") diff --git a/gplus/query.go b/gplus/query.go index badfe07..aa61153 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -47,7 +47,8 @@ func (q *QueryCond[T]) getSqlSegment() string { // NewQuery 构建查询条件 func NewQuery[T any]() (*QueryCond[T], *T) { - return NewQueryBaseDb[T](DbBaseName(constants.DefaultGormPlusConnName)) + dbConnName := getDefaultDbConnName() + return NewQueryBaseDb[T](DbBaseName(dbConnName)) } // NewQueryBaseDb 构建查询条件 diff --git a/tests/dao_test.go b/tests/dao_test.go index 25d8d23..26806de 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -722,9 +722,9 @@ func TestSelectGeneric6BaseName(t *testing.T) { userMap[user.Dept] += user.Score } query, u := gplus.NewQueryBaseDb[User](opt) - uvo := gplus.GetModelBaseDb[UserVo](gormDbConnName) + uvo := gplus.GetModelBaseDb[UserVo](opt) query.Select(&u.Dept, gplus.Sum(&u.Score).As(&uvo.Score)).Group(&u.Dept) - UserVos, resultDb := gplus.SelectGeneric[User, []UserVo](query, gplus.DbBaseName(gormDbConnName)) + UserVos, resultDb := gplus.SelectGeneric[User, []UserVo](query, opt) if resultDb.Error != nil { t.Errorf("errors happened when resultDb : %v", resultDb.Error) @@ -736,6 +736,15 @@ func TestSelectGeneric6BaseName(t *testing.T) { t.Errorf("errors happened when SelectGeneric") } } + //如果还是使用旧有的方法测试 + query, u = gplus.NewQuery[User]() + uvo = gplus.GetModel[UserVo]() + query.Select(&u.Dept, gplus.Sum(&u.Score).As(&uvo.Score)).Group(&u.Dept) + UserVos, resultDb = gplus.SelectGeneric[User, []UserVo](query, opt) + + if resultDb.Error != nil { + t.Errorf("errors happened when resultDb : %v", resultDb.Error) + } } func deleteOldData() { From 2be0972ae0d9db927b1045906d983f37110b10cb Mon Sep 17 00:00:00 2001 From: ydy <517697206@qq.com> Date: Sat, 15 Feb 2025 17:07:21 +0800 Subject: [PATCH 3/7] feat: 1.modify BuildQueryBaseDb and getDefaultDbConnName for getDefaultConnName first 2.add testDemo for BuildQueryBaseDb --- gplus/cache.go | 12 ++++++++---- gplus/tool.go | 8 +++++--- tests/dao_test.go | 16 ++++++++++++++-- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/gplus/cache.go b/gplus/cache.go index 5a25606..8641bab 100644 --- a/gplus/cache.go +++ b/gplus/cache.go @@ -147,10 +147,14 @@ func getColumnName(v any) string { func getDefaultDbConnName() string { dbConnName := constants.DefaultGormPlusConnName - //如果用户没传数据库连接名称,优先从全局globalDbKeys里获取第一个连接名 - //避免用户使用InitDb方法初始化数据库 自定义数据库连接名,然后方法里不传是哪个数据库连接名 则只能默认取第一条 - if len(globalDbKeys) >= 1 { - dbConnName = globalDbKeys[0] + //如果用户没传数据库连接名称,优先判断全局自定义的连接名是否存在, + //如果上面不存在其次从全局globalDbKeys里获取第一个连接名 + //1.避免用户使用InitDb方法初始化数据库 自定义数据库连接名 ,然后方法里不传是哪个数据库连接名 则只能默认取第一条 + //2.再混用单库Init取初始化,做方法兼容 + _, exists := globalDbMap[dbConnName] + if exists { + return dbConnName } + dbConnName = globalDbKeys[0] return dbConnName } diff --git a/gplus/tool.go b/gplus/tool.go index 19933b7..7b042cf 100644 --- a/gplus/tool.go +++ b/gplus/tool.go @@ -56,16 +56,18 @@ var builders = map[string]func(query *QueryCond[any], name string, value any){ } func BuildQuery[T any](queryParams url.Values) *QueryCond[T] { - return BuildQueryBaseDb[T](queryParams, "") + dbConnName := getDefaultDbConnName() + return BuildQueryBaseDb[T](queryParams, DbBaseName(dbConnName)) } -func BuildQueryBaseDb[T any](queryParams url.Values, dbConnName string) *QueryCond[T] { +func BuildQueryBaseDb[T any](queryParams url.Values, opt OptionFunc) *QueryCond[T] { columnCondMap, conditionMap, gcond := parseParams(queryParams) parentQuery := buildParentQuery[T](conditionMap) - queryCondMap := buildQueryCondMap[T](columnCondMap, dbConnName) + option := getOneOption(opt) + queryCondMap := buildQueryCondMap[T](columnCondMap, option.DbConnName) // 如果没有分组条件,直接返回默认的查询条件 if len(gcond) == 0 { diff --git a/tests/dao_test.go b/tests/dao_test.go index 26806de..bc4345f 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -24,6 +24,7 @@ import ( "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" + "net/url" "reflect" "sort" "strconv" @@ -32,9 +33,12 @@ import ( var gormDb *gorm.DB var gormDbConnName = "test1" +var dbAddress = "127.0.0.1:3306" +var dbUser = "root" +var dbPassword = "123456" func init() { - dsn := "root:123456@tcp(127.0.0.1:3306)/test?charset=utf8mb4&parseTime=True&loc=Local" + dsn := fmt.Sprintf("%s:%s@tcp(%s)/test?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPassword, dbAddress) var err error gormDb, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), @@ -49,7 +53,7 @@ func init() { } func initDb() { - dsn := "root:123456@tcp(127.0.0.1:3306)/test1?charset=utf8mb4&parseTime=True&loc=Local" + dsn := fmt.Sprintf("%s:%s@tcp(%s)/test1?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPassword, dbAddress) var err error gormDb1, err := gorm.Open(mysql.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), @@ -747,6 +751,14 @@ func TestSelectGeneric6BaseName(t *testing.T) { } } +func TestQueryByIdBaseDb(t *testing.T) { + opt := gplus.DbBaseName(gormDbConnName) + values := url.Values{} + values["q"] = []string{"id=1"} + query := gplus.BuildQueryBaseDb[User](values, opt) + gplus.SelectList[User](query, opt) +} + func deleteOldData() { q, u := gplus.NewQuery[User]() q.IsNotNull(&u.ID) From 8d2718307bbec81b1c9fe8474c490fc956610001 Mon Sep 17 00:00:00 2001 From: ydy <517697206@qq.com> Date: Sun, 16 Feb 2025 10:46:11 +0800 Subject: [PATCH 4/7] =?UTF-8?q?feat:add=20option=20parameter=20DbSession?= =?UTF-8?q?=20and=20=20Implementation=20it=20in=20getDb=20method=EF=BC=8Ca?= =?UTF-8?q?nd=20adjust=20same=20logical=20and=20remove=20function?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gplus/cache.go | 2 +- gplus/dao.go | 19 ++++++++++++++----- gplus/option.go | 23 ++--------------------- gplus/query.go | 2 +- gplus/tool.go | 2 +- tests/dao_test.go | 40 ++++++++++++++++++++-------------------- 6 files changed, 39 insertions(+), 49 deletions(-) diff --git a/gplus/cache.go b/gplus/cache.go index 8641bab..fe94c56 100644 --- a/gplus/cache.go +++ b/gplus/cache.go @@ -70,7 +70,7 @@ func getColumnNameMap(model any, dbConnName string) map[uintptr]string { // GetModel 获取 func GetModel[T any]() *T { dbConnName := getDefaultDbConnName() - return GetModelBaseDb[T](DbBaseName(dbConnName)) + return GetModelBaseDb[T](DbConnName(dbConnName)) } // GetModelBaseDb 获取根据数据库连接名 diff --git a/gplus/dao.go b/gplus/dao.go index d9c5152..9097753 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -482,17 +482,20 @@ func getDb(opts ...OptionFunc) *gorm.DB { option := getOption(opts) if option.Db != nil { - db = option.Db.Clauses() + db = option.Db } else { if len(option.DbConnName) == 0 { - option.DbConnName = constants.DefaultGormPlusConnName + option.DbConnName = getDefaultDbConnName() } - db, _ = GetDb(option.DbConnName) - // Clauses()目的是为了初始化Db,如果db已经被初始化了,会直接返回db - db = db.Clauses() } + //设置session,如果需要子句仅在当前会话生效,先调用 Session(),再调用 Clauses()。 + setSessionIfNeed(option, db) + + // Clauses()目的是为了初始化Db,如果db已经被初始化了,会直接返回db + db = db.Clauses() + // 设置需要忽略的字段 setOmitIfNeed(option, db) @@ -538,6 +541,12 @@ func setOmitIfNeed(option Option, db *gorm.DB) { } } +func setSessionIfNeed(option Option, db *gorm.DB) { + if option.DbSession != nil { + db.Session(option.DbSession) + } +} + func getPkColumnName[T any]() string { var entity T entityType := reflect.TypeOf(entity) diff --git a/gplus/option.go b/gplus/option.go index f4941cb..675530e 100644 --- a/gplus/option.go +++ b/gplus/option.go @@ -18,7 +18,6 @@ package gplus import ( - "github.com/acmestack/gorm-plus/constants" "gorm.io/gorm" ) @@ -28,6 +27,7 @@ type Option struct { Omits []any IgnoreTotal bool DbConnName string + DbSession *gorm.Session } type OptionFunc func(*Option) @@ -42,9 +42,7 @@ func Db(db *gorm.DB) OptionFunc { // Session 创建会话 func Session(session *gorm.Session) OptionFunc { return func(o *Option) { - //兼容之前的设计 - db, _ := GetDb(constants.DefaultGormPlusConnName) - o.Db = db.Session(session) + o.DbSession = session //调整session 在dao类的getDb方法那边处理 } } @@ -75,20 +73,3 @@ func DbConnName(dbConnName string) OptionFunc { o.DbConnName = dbConnName } } - -// DbSessionBaseName 创建特定的Db会话 -func DbSessionBaseName(dbConnName string, session *gorm.Session) OptionFunc { - return func(o *Option) { - o.DbConnName = dbConnName - db, _ := GetDb(dbConnName) - o.Db = db.Session(session) - } -} - -// DbBaseName 使用特定的Db对象 -func DbBaseName(dbConnName string) OptionFunc { - return func(o *Option) { - o.DbConnName = dbConnName - o.Db, _ = GetDb(dbConnName) - } -} diff --git a/gplus/query.go b/gplus/query.go index aa61153..cf177b2 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -48,7 +48,7 @@ func (q *QueryCond[T]) getSqlSegment() string { // NewQuery 构建查询条件 func NewQuery[T any]() (*QueryCond[T], *T) { dbConnName := getDefaultDbConnName() - return NewQueryBaseDb[T](DbBaseName(dbConnName)) + return NewQueryBaseDb[T](DbConnName(dbConnName)) } // NewQueryBaseDb 构建查询条件 diff --git a/gplus/tool.go b/gplus/tool.go index 7b042cf..c86aeeb 100644 --- a/gplus/tool.go +++ b/gplus/tool.go @@ -57,7 +57,7 @@ var builders = map[string]func(query *QueryCond[any], name string, value any){ func BuildQuery[T any](queryParams url.Values) *QueryCond[T] { dbConnName := getDefaultDbConnName() - return BuildQueryBaseDb[T](queryParams, DbBaseName(dbConnName)) + return BuildQueryBaseDb[T](queryParams, DbConnName(dbConnName)) } func BuildQueryBaseDb[T any](queryParams url.Values, opt OptionFunc) *QueryCond[T] { diff --git a/tests/dao_test.go b/tests/dao_test.go index bc4345f..af2c86a 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -613,7 +613,7 @@ func TestInsertBaseDb(t *testing.T) { deleteOldDataBaseDb() user := &User{Username: "afumu", Password: "123456", Age: 18, Score: 100, Dept: "开发部门"} - resultDb := gplus.Insert(user, gplus.DbBaseName(gormDbConnName)) + resultDb := gplus.Insert(user, gplus.DbConnName(gormDbConnName)) if resultDb.Error != nil { t.Fatalf("errors happened when insert: %v", resultDb.Error) @@ -621,7 +621,7 @@ func TestInsertBaseDb(t *testing.T) { t.Fatalf("rows affected expects: %v, got %v", 1, resultDb.RowsAffected) } - newUser, db := gplus.SelectById[User](user.ID, gplus.DbBaseName(gormDbConnName)) + newUser, db := gplus.SelectById[User](user.ID, gplus.DbConnName(gormDbConnName)) if db.Error != nil { t.Fatalf("errors happened when SelectById: %v", db.Error) } @@ -631,13 +631,13 @@ func TestInsertBaseDb(t *testing.T) { func TestInsertBatchBaseDb(t *testing.T) { deleteOldDataBaseDb() users := getUsers() - resultDb := gplus.InsertBatch[User](users, gplus.DbBaseName(gormDbConnName)) + resultDb := gplus.InsertBatch[User](users, gplus.DbConnName(gormDbConnName)) if resultDb.RowsAffected != int64(len(users)) { t.Errorf("affected rows should be %v, but got %v", len(users), resultDb.RowsAffected) } for _, user := range users { - newUser, db := gplus.SelectById[User](user.ID, gplus.DbBaseName(gormDbConnName)) + newUser, db := gplus.SelectById[User](user.ID, gplus.DbConnName(gormDbConnName)) if db.Error != nil { t.Fatalf("errors happened when SelectById: %v", db.Error) } @@ -648,13 +648,13 @@ func TestInsertBatchBaseDb(t *testing.T) { func TestDeleteByIdBaseDb(t *testing.T) { deleteOldDataBaseDb() users := getUsers() - gplus.InsertBatchSize[User](users, 2, gplus.DbBaseName(gormDbConnName)) + gplus.InsertBatchSize[User](users, 2, gplus.DbConnName(gormDbConnName)) - if res := gplus.DeleteById[User](users[1].ID, gplus.DbBaseName(gormDbConnName)); res.Error != nil || res.RowsAffected != 1 { + if res := gplus.DeleteById[User](users[1].ID, gplus.DbConnName(gormDbConnName)); res.Error != nil || res.RowsAffected != 1 { t.Errorf("errors happened when deleteById: %v, affected: %v", res.Error, res.RowsAffected) } - _, resultDb := gplus.SelectById[User](users[1].ID, gplus.DbBaseName(gormDbConnName)) + _, resultDb := gplus.SelectById[User](users[1].ID, gplus.DbConnName(gormDbConnName)) if !errors.Is(resultDb.Error, gorm.ErrRecordNotFound) { t.Errorf("should returns record not found error, but got %v", resultDb.Error) } @@ -663,16 +663,16 @@ func TestDeleteByIdBaseDb(t *testing.T) { func TestDeleteBaseDb(t *testing.T) { deleteOldDataBaseDb() users := getUsers() - opt := gplus.DbBaseName(gormDbConnName) + opt := gplus.DbConnName(gormDbConnName) gplus.InsertBatch[User](users, opt) query, u := gplus.NewQueryBaseDb[User](opt) query.Eq(&u.Username, "afumu1") - if res := gplus.Delete[User](query, gplus.DbBaseName(gormDbConnName)); res.Error != nil || res.RowsAffected != 1 { + if res := gplus.Delete[User](query, gplus.DbConnName(gormDbConnName)); res.Error != nil || res.RowsAffected != 1 { t.Errorf("errors happened when Delete: %v, affected: %v", res.Error, res.RowsAffected) } - _, resultDb := gplus.SelectOne[User](query, gplus.DbBaseName(gormDbConnName)) + _, resultDb := gplus.SelectOne[User](query, gplus.DbConnName(gormDbConnName)) if !errors.Is(resultDb.Error, gorm.ErrRecordNotFound) { t.Errorf("should returns record not found error, but got %v", resultDb.Error) } @@ -681,17 +681,17 @@ func TestDeleteBaseDb(t *testing.T) { func TestUpdateByIdBaseDb(t *testing.T) { deleteOldDataBaseDb() users := getUsers() - gplus.InsertBatch[User](users, gplus.DbBaseName(gormDbConnName)) + gplus.InsertBatch[User](users, gplus.DbConnName(gormDbConnName)) user := users[0] user.Score = 100 user.Age = 25 - if res := gplus.UpdateById[User](user, gplus.DbBaseName(gormDbConnName)); res.Error != nil || res.RowsAffected != 1 { + if res := gplus.UpdateById[User](user, gplus.DbConnName(gormDbConnName)); res.Error != nil || res.RowsAffected != 1 { t.Errorf("errors happened when deleteByIds: %v, affected: %v", res.Error, res.RowsAffected) } - newUser, db := gplus.SelectById[User](user.ID, gplus.DbBaseName(gormDbConnName)) + newUser, db := gplus.SelectById[User](user.ID, gplus.DbConnName(gormDbConnName)) if db.Error != nil { t.Fatalf("errors happened when SelectById: %v", db.Error) } @@ -702,9 +702,9 @@ func TestUpdateByIdBaseDb(t *testing.T) { func TestSelectByIdBaseDb(t *testing.T) { deleteOldDataBaseDb() users := getUsers() - gplus.InsertBatch[User](users, gplus.DbBaseName(gormDbConnName)) + gplus.InsertBatch[User](users, gplus.DbConnName(gormDbConnName)) user := users[0] - resultUser, db := gplus.SelectById[User](user.ID, gplus.DbBaseName(gormDbConnName)) + resultUser, db := gplus.SelectById[User](user.ID, gplus.DbConnName(gormDbConnName)) if db.Error != nil { t.Errorf("errors happened when selectById : %v", db.Error) } else { @@ -712,10 +712,10 @@ func TestSelectByIdBaseDb(t *testing.T) { } } -func TestSelectGeneric6BaseName(t *testing.T) { +func TestSelectGeneric6BaseDb(t *testing.T) { deleteOldDataBaseDb() users := getUsers() - opt := gplus.DbBaseName(gormDbConnName) + opt := gplus.DbConnName(gormDbConnName) gplus.InsertBatch[User](users, opt) type UserVo struct { Dept string @@ -752,7 +752,7 @@ func TestSelectGeneric6BaseName(t *testing.T) { } func TestQueryByIdBaseDb(t *testing.T) { - opt := gplus.DbBaseName(gormDbConnName) + opt := gplus.DbConnName(gormDbConnName) values := url.Values{} values["q"] = []string{"id=1"} query := gplus.BuildQueryBaseDb[User](values, opt) @@ -766,10 +766,10 @@ func deleteOldData() { } func deleteOldDataBaseDb() { - opt := gplus.DbBaseName(gormDbConnName) + opt := gplus.DbConnName(gormDbConnName) q, u := gplus.NewQueryBaseDb[User](opt) q.IsNotNull(&u.ID) - gplus.Delete(q, gplus.DbBaseName(gormDbConnName)) + gplus.Delete(q, gplus.DbConnName(gormDbConnName)) } func getUsers() []*User { From c68b420a3f56809e2cdc13f442555850a782ff20 Mon Sep 17 00:00:00 2001 From: ydy <517697206@qq.com> Date: Sun, 16 Feb 2025 12:12:01 +0800 Subject: [PATCH 5/7] feat:adjust methods and parameter --- gplus/cache.go | 43 ++++++++++++------------------------------- gplus/dao.go | 38 +++++++++++++++++++++++++++++++++----- gplus/query.go | 12 ++++++------ gplus/tool.go | 22 ++++++++++++---------- tests/dao_test.go | 13 +++++++++++++ 5 files changed, 76 insertions(+), 52 deletions(-) diff --git a/gplus/cache.go b/gplus/cache.go index fe94c56..d237533 100644 --- a/gplus/cache.go +++ b/gplus/cache.go @@ -18,7 +18,6 @@ package gplus import ( - "github.com/acmestack/gorm-plus/constants" "gorm.io/gorm/schema" "reflect" "sync" @@ -32,9 +31,10 @@ var columnNameCache sync.Map var modelInstanceCache sync.Map // Cache 缓存实体对象所有的字段名 -func Cache(dbConnName string, models ...any) { +func Cache(opt OptionFunc, models ...any) { + db, _, _ := getDefaultDbByOpt(opt) for _, model := range models { - columnNameMap := getColumnNameMap(model, dbConnName) + columnNameMap := getColumnNameMap(model, db.Config.NamingStrategy) for pointer, columnName := range columnNameMap { columnNameCache.Store(pointer, columnName) } @@ -44,7 +44,7 @@ func Cache(dbConnName string, models ...any) { } } -func getColumnNameMap(model any, dbConnName string) map[uintptr]string { +func getColumnNameMap(model any, namingStrategy schema.Namer) map[uintptr]string { var columnNameMap = make(map[uintptr]string) valueOf := reflect.ValueOf(model).Elem() typeOf := reflect.TypeOf(model).Elem() @@ -53,14 +53,14 @@ func getColumnNameMap(model any, dbConnName string) map[uintptr]string { // 如果当前实体嵌入了其他实体,同样需要缓存它的字段名 if field.Anonymous { // 如果存在多重嵌套,通过递归方式获取他们的字段名 - subFieldMap := getSubFieldColumnNameMap(valueOf, field, dbConnName) + subFieldMap := getSubFieldColumnNameMap(valueOf, field, namingStrategy) for pointer, columnName := range subFieldMap { columnNameMap[pointer] = columnName } } else { // 获取对象字段指针值 pointer := valueOf.Field(i).Addr().Pointer() - columnName := parseColumnName(field, dbConnName) + columnName := parseColumnName(field, namingStrategy) columnNameMap[pointer] = columnName } } @@ -83,13 +83,12 @@ func GetModelBaseDb[T any](opt OptionFunc) *T { } } t := new(T) - option := getOneOption(opt) - Cache(option.DbConnName, t) + Cache(opt, t) return t } // 递归获取嵌套字段名 -func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField, dbConnName string) map[uintptr]string { +func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField, namingStrategy schema.Namer) map[uintptr]string { result := make(map[uintptr]string) modelType := field.Type if modelType.Kind() == reflect.Ptr { @@ -98,13 +97,13 @@ func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField, for j := 0; j < modelType.NumField(); j++ { subField := modelType.Field(j) if subField.Anonymous { - nestedFields := getSubFieldColumnNameMap(valueOf, subField, dbConnName) + nestedFields := getSubFieldColumnNameMap(valueOf, subField, namingStrategy) for key, value := range nestedFields { result[key] = value } } else { pointer := valueOf.FieldByName(modelType.Field(j).Name).Addr().Pointer() - name := parseColumnName(modelType.Field(j), dbConnName) + name := parseColumnName(modelType.Field(j), namingStrategy) result[pointer] = name } } @@ -113,17 +112,13 @@ func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField, } // 解析字段名称 兼容多数据库切换 -func parseColumnName(field reflect.StructField, dbConnName string) string { +func parseColumnName(field reflect.StructField, namingStrategy schema.Namer) string { tagSetting := schema.ParseTagSetting(field.Tag.Get("gorm"), ";") name, ok := tagSetting["COLUMN"] if ok { return name } - if len(dbConnName) == 0 { - dbConnName = getDefaultDbConnName() - } - db, _ := GetDb(dbConnName) - return db.Config.NamingStrategy.ColumnName("", field.Name) + return namingStrategy.ColumnName("", field.Name) } func getColumnName(v any) string { @@ -144,17 +139,3 @@ func getColumnName(v any) string { } return columnName } - -func getDefaultDbConnName() string { - dbConnName := constants.DefaultGormPlusConnName - //如果用户没传数据库连接名称,优先判断全局自定义的连接名是否存在, - //如果上面不存在其次从全局globalDbKeys里获取第一个连接名 - //1.避免用户使用InitDb方法初始化数据库 自定义数据库连接名 ,然后方法里不传是哪个数据库连接名 则只能默认取第一条 - //2.再混用单库Init取初始化,做方法兼容 - _, exists := globalDbMap[dbConnName] - if exists { - return dbConnName - } - dbConnName = globalDbKeys[0] - return dbConnName -} diff --git a/gplus/dao.go b/gplus/dao.go index 9097753..4193999 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -186,7 +186,7 @@ func UpdateZeroById[T any](entity *T, opts ...OptionFunc) *gorm.DB { func updateAllIfNeed(entity any, opts []OptionFunc, db *gorm.DB) { option := getOption(opts) if len(option.Selects) == 0 { - columnNameMap := getColumnNameMap(entity, option.DbConnName) + columnNameMap := getColumnNameMap(entity, db.Config.NamingStrategy) var columnNames []string for _, columnName := range columnNameMap { columnNames = append(columnNames, columnName) @@ -484,10 +484,7 @@ func getDb(opts ...OptionFunc) *gorm.DB { if option.Db != nil { db = option.Db } else { - if len(option.DbConnName) == 0 { - option.DbConnName = getDefaultDbConnName() - } - db, _ = GetDb(option.DbConnName) + db, option.DbConnName, _ = getDefaultDbByName(option.DbConnName) } //设置session,如果需要子句仅在当前会话生效,先调用 Session(),再调用 Clauses()。 @@ -571,3 +568,34 @@ func getPkColumnName[T any]() string { } return columnName } + +func getDefaultDbConnName() string { + dbConnName := constants.DefaultGormPlusConnName + //如果用户没传数据库连接名称,优先判断全局自定义的连接名是否存在, + //如果上面不存在其次从全局globalDbKeys里获取第一个连接名 + //1.避免用户使用InitDb方法初始化数据库 自定义数据库连接名 ,然后方法里不传是哪个数据库连接名 则只能默认取第一条 + //2.再混用单库Init取初始化,做方法兼容 + _, exists := globalDbMap[dbConnName] + if exists { + return dbConnName + } + dbConnName = globalDbKeys[0] + return dbConnName +} + +func getDefaultDbByOpt(opt OptionFunc) (*gorm.DB, string, error) { + option := getOneOption(opt) + if len(option.DbConnName) == 0 { + option.DbConnName = getDefaultDbConnName() + } + db, err := GetDb(option.DbConnName) + return db, option.DbConnName, err +} + +func getDefaultDbByName(dbConnName string) (*gorm.DB, string, error) { + if len(dbConnName) == 0 { + dbConnName = getDefaultDbConnName() + } + db, err := GetDb(dbConnName) + return db, dbConnName, err +} diff --git a/gplus/query.go b/gplus/query.go index cf177b2..bf8d69a 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -62,18 +62,18 @@ func NewQueryBaseDb[T any](opt OptionFunc) (*QueryCond[T], *T) { } } m := new(T) - option := getOneOption(opt) - Cache(option.DbConnName, m) + Cache(opt, m) return q, m } // NewQueryModel 构建查询条件 func NewQueryModel[T any, R any]() (*QueryCond[T], *T, *R) { - return NewQueryModelBaseDb[T, R]("") + dbConnName := getDefaultDbConnName() //兼容之前设计 + return NewQueryModelBaseDb[T, R](DbConnName(dbConnName)) } // NewQueryModelBaseDb 构建查询条件 -func NewQueryModelBaseDb[T any, R any](dbConnName string) (*QueryCond[T], *T, *R) { +func NewQueryModelBaseDb[T any, R any](opt OptionFunc) (*QueryCond[T], *T, *R) { q := &QueryCond[T]{} var t *T var r *R @@ -95,12 +95,12 @@ func NewQueryModelBaseDb[T any, R any](dbConnName string) (*QueryCond[T], *T, *R if t == nil { t = new(T) - Cache(dbConnName, t) + Cache(opt, t) } if r == nil { r = new(R) - Cache(dbConnName, r) + Cache(opt, r) } return q, t, r diff --git a/gplus/tool.go b/gplus/tool.go index c86aeeb..9ea546c 100644 --- a/gplus/tool.go +++ b/gplus/tool.go @@ -19,6 +19,7 @@ package gplus import ( "fmt" + "gorm.io/gorm/schema" "net/url" "reflect" "strconv" @@ -66,8 +67,9 @@ func BuildQueryBaseDb[T any](queryParams url.Values, opt OptionFunc) *QueryCond[ parentQuery := buildParentQuery[T](conditionMap) - option := getOneOption(opt) - queryCondMap := buildQueryCondMap[T](columnCondMap, option.DbConnName) + db, _, _ := getDefaultDbByOpt(opt) + + queryCondMap := buildQueryCondMap[T](columnCondMap, db.Config.NamingStrategy) // 如果没有分组条件,直接返回默认的查询条件 if len(gcond) == 0 { @@ -165,9 +167,9 @@ func getCurrentOp(value string) string { return currentOperator } -func buildQueryCondMap[T any](columnCondMap map[string][]*Condition, dbConnName string) map[string]*QueryCond[T] { +func buildQueryCondMap[T any](columnCondMap map[string][]*Condition, namingStrategy schema.Namer) map[string]*QueryCond[T] { var queryCondMap = make(map[string]*QueryCond[T]) - columnTypeMap := getColumnTypeMap[T](dbConnName) + columnTypeMap := getColumnTypeMap[T](namingStrategy) for key, conditions := range columnCondMap { query := &QueryCond[any]{} query.columnTypeMap = columnTypeMap @@ -279,7 +281,7 @@ func buildGroupQuery[T any](gcond string, queryMaps map[string]*QueryCond[T], qu return query } -func getColumnTypeMap[T any](dbConnName string) map[string]reflect.Type { +func getColumnTypeMap[T any](namingStrategy schema.Namer) map[string]reflect.Type { modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String() if model, ok := columnTypeCache.Load(modelTypeStr); ok { if columnNameMap, isOk := model.(map[string]reflect.Type); isOk { @@ -291,19 +293,19 @@ func getColumnTypeMap[T any](dbConnName string) map[string]reflect.Type { for i := 0; i < typeOf.NumField(); i++ { field := typeOf.Field(i) if field.Anonymous { - nestedFields := getSubFieldColumnTypeMap(field, dbConnName) + nestedFields := getSubFieldColumnTypeMap(field, namingStrategy) for key, value := range nestedFields { columnTypeMap[key] = value } } - columnName := parseColumnName(field, dbConnName) + columnName := parseColumnName(field, namingStrategy) columnTypeMap[columnName] = field.Type } columnTypeCache.Store(modelTypeStr, columnTypeMap) return columnTypeMap } -func getSubFieldColumnTypeMap(field reflect.StructField, dbConnName string) map[string]reflect.Type { +func getSubFieldColumnTypeMap(field reflect.StructField, namingStrategy schema.Namer) map[string]reflect.Type { columnTypeMap := make(map[string]reflect.Type) modelType := field.Type if modelType.Kind() == reflect.Ptr { @@ -312,12 +314,12 @@ func getSubFieldColumnTypeMap(field reflect.StructField, dbConnName string) map[ for j := 0; j < modelType.NumField(); j++ { subField := modelType.Field(j) if subField.Anonymous { - nestedFields := getSubFieldColumnTypeMap(subField, dbConnName) + nestedFields := getSubFieldColumnTypeMap(subField, namingStrategy) for key, value := range nestedFields { columnTypeMap[key] = value } } else { - columnName := parseColumnName(subField, dbConnName) + columnName := parseColumnName(subField, namingStrategy) columnTypeMap[columnName] = subField.Type } } diff --git a/tests/dao_test.go b/tests/dao_test.go index af2c86a..cee9e8f 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -725,6 +725,7 @@ func TestSelectGeneric6BaseDb(t *testing.T) { for _, user := range users { userMap[user.Dept] += user.Score } + //测试NewQueryBaseDb和GetModelBaseDb query, u := gplus.NewQueryBaseDb[User](opt) uvo := gplus.GetModelBaseDb[UserVo](opt) query.Select(&u.Dept, gplus.Sum(&u.Score).As(&uvo.Score)).Group(&u.Dept) @@ -740,6 +741,18 @@ func TestSelectGeneric6BaseDb(t *testing.T) { t.Errorf("errors happened when SelectGeneric") } } + + //测试NewQueryModelBaseDb + type UserV1 struct { + Name string + Age int64 + } + query, user, userV1 := gplus.NewQueryModelBaseDb[User, UserV1](opt) + query.Eq(&user.Username, "afumu").And(func(q *gplus.QueryCond[User]) { + q.Eq(&user.Address, "北京").Or().Eq(&user.Age, 20) + }).Select(gplus.As(&user.Username, &userV1.Name), &user.Age) + gplus.SelectGeneric[User, []UserV1](query, opt) + //如果还是使用旧有的方法测试 query, u = gplus.NewQuery[User]() uvo = gplus.GetModel[UserVo]() From 76ff0d13190860fdd830a525e8b9031b45770f1b Mon Sep 17 00:00:00 2001 From: ydy <517697206@qq.com> Date: Sun, 16 Feb 2025 15:03:55 +0800 Subject: [PATCH 6/7] feat: 1.merge original methos use Functional Options Pattern for compatible 2.add InitDbMany --- gplus/cache.go | 11 +++-------- gplus/dao.go | 32 +++++++++++++++++++++----------- gplus/query.go | 18 ++++-------------- gplus/tool.go | 8 ++------ tests/dao_test.go | 16 ++++++++-------- 5 files changed, 38 insertions(+), 47 deletions(-) diff --git a/gplus/cache.go b/gplus/cache.go index d237533..bb3b67b 100644 --- a/gplus/cache.go +++ b/gplus/cache.go @@ -31,7 +31,7 @@ var columnNameCache sync.Map var modelInstanceCache sync.Map // Cache 缓存实体对象所有的字段名 -func Cache(opt OptionFunc, models ...any) { +func Cache(opt Option, models ...any) { db, _, _ := getDefaultDbByOpt(opt) for _, model := range models { columnNameMap := getColumnNameMap(model, db.Config.NamingStrategy) @@ -68,13 +68,8 @@ func getColumnNameMap(model any, namingStrategy schema.Namer) map[uintptr]string } // GetModel 获取 -func GetModel[T any]() *T { - dbConnName := getDefaultDbConnName() - return GetModelBaseDb[T](DbConnName(dbConnName)) -} - -// GetModelBaseDb 获取根据数据库连接名 -func GetModelBaseDb[T any](opt OptionFunc) *T { +func GetModel[T any](opts ...OptionFunc) *T { + opt := getDefaultOptionInfo(opts...) //兼容设计 modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String() if model, ok := modelInstanceCache.Load(modelTypeStr); ok { m, isReal := model.(*T) diff --git a/gplus/dao.go b/gplus/dao.go index 4193999..3d9f308 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -51,6 +51,16 @@ func InitDb(db *gorm.DB, dbConnName string) error { return errors.New("InitMultiple have same name:" + dbConnName + ",please check") } +func InitDbMany(dic map[string]*gorm.DB) []error { + var errs []error + for k, v := range dic { + if err := InitDb(v, k); err != nil { + errs = append(errs, err) + } + } + return errs +} + // GetDb 获取数据库连接 func GetDb(dbConnName string) (*gorm.DB, error) { db, exists := globalDbMap[dbConnName] @@ -70,12 +80,8 @@ type Page[T any] struct { type Dao[T any] struct{} -func (dao Dao[T]) NewQuery() (*QueryCond[T], *T) { - return NewQuery[T]() -} - -func (dao Dao[T]) NewQueryBaseDb(opt OptionFunc) (*QueryCond[T], *T) { - return NewQueryBaseDb[T](opt) +func (dao Dao[T]) NewQuery(opts ...OptionFunc) (*QueryCond[T], *T) { + return NewQuery[T](opts...) } func NewPage[T any](current, size int) *Page[T] { @@ -583,13 +589,17 @@ func getDefaultDbConnName() string { return dbConnName } -func getDefaultDbByOpt(opt OptionFunc) (*gorm.DB, string, error) { - option := getOneOption(opt) +// 获取如果连接名为空则默认填充的option数据 +func getDefaultOptionInfo(opts ...OptionFunc) Option { + option := getOption(opts) if len(option.DbConnName) == 0 { - option.DbConnName = getDefaultDbConnName() + option.DbConnName = getDefaultDbConnName() //兼容之前设计 } - db, err := GetDb(option.DbConnName) - return db, option.DbConnName, err + return option +} + +func getDefaultDbByOpt(opt Option) (*gorm.DB, string, error) { + return getDefaultDbByName(opt.DbConnName) } func getDefaultDbByName(dbConnName string) (*gorm.DB, string, error) { diff --git a/gplus/query.go b/gplus/query.go index bf8d69a..b13063f 100644 --- a/gplus/query.go +++ b/gplus/query.go @@ -46,13 +46,8 @@ func (q *QueryCond[T]) getSqlSegment() string { } // NewQuery 构建查询条件 -func NewQuery[T any]() (*QueryCond[T], *T) { - dbConnName := getDefaultDbConnName() - return NewQueryBaseDb[T](DbConnName(dbConnName)) -} - -// NewQueryBaseDb 构建查询条件 -func NewQueryBaseDb[T any](opt OptionFunc) (*QueryCond[T], *T) { +func NewQuery[T any](opts ...OptionFunc) (*QueryCond[T], *T) { + opt := getDefaultOptionInfo(opts...) //兼容设计 q := &QueryCond[T]{} modelTypeStr := reflect.TypeOf((*T)(nil)).Elem().String() if model, ok := modelInstanceCache.Load(modelTypeStr); ok { @@ -67,13 +62,8 @@ func NewQueryBaseDb[T any](opt OptionFunc) (*QueryCond[T], *T) { } // NewQueryModel 构建查询条件 -func NewQueryModel[T any, R any]() (*QueryCond[T], *T, *R) { - dbConnName := getDefaultDbConnName() //兼容之前设计 - return NewQueryModelBaseDb[T, R](DbConnName(dbConnName)) -} - -// NewQueryModelBaseDb 构建查询条件 -func NewQueryModelBaseDb[T any, R any](opt OptionFunc) (*QueryCond[T], *T, *R) { +func NewQueryModel[T any, R any](opts ...OptionFunc) (*QueryCond[T], *T, *R) { + opt := getDefaultOptionInfo(opts...) //兼容设计 q := &QueryCond[T]{} var t *T var r *R diff --git a/gplus/tool.go b/gplus/tool.go index 9ea546c..661f2f2 100644 --- a/gplus/tool.go +++ b/gplus/tool.go @@ -56,12 +56,8 @@ var builders = map[string]func(query *QueryCond[any], name string, value any){ "<": lt, } -func BuildQuery[T any](queryParams url.Values) *QueryCond[T] { - dbConnName := getDefaultDbConnName() - return BuildQueryBaseDb[T](queryParams, DbConnName(dbConnName)) -} - -func BuildQueryBaseDb[T any](queryParams url.Values, opt OptionFunc) *QueryCond[T] { +func BuildQuery[T any](queryParams url.Values, opts ...OptionFunc) *QueryCond[T] { + opt := getDefaultOptionInfo(opts...) //兼容设计 columnCondMap, conditionMap, gcond := parseParams(queryParams) diff --git a/tests/dao_test.go b/tests/dao_test.go index cee9e8f..b4554fa 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -666,7 +666,7 @@ func TestDeleteBaseDb(t *testing.T) { opt := gplus.DbConnName(gormDbConnName) gplus.InsertBatch[User](users, opt) - query, u := gplus.NewQueryBaseDb[User](opt) + query, u := gplus.NewQuery[User](opt) query.Eq(&u.Username, "afumu1") if res := gplus.Delete[User](query, gplus.DbConnName(gormDbConnName)); res.Error != nil || res.RowsAffected != 1 { t.Errorf("errors happened when Delete: %v, affected: %v", res.Error, res.RowsAffected) @@ -725,9 +725,9 @@ func TestSelectGeneric6BaseDb(t *testing.T) { for _, user := range users { userMap[user.Dept] += user.Score } - //测试NewQueryBaseDb和GetModelBaseDb - query, u := gplus.NewQueryBaseDb[User](opt) - uvo := gplus.GetModelBaseDb[UserVo](opt) + //测试NewQuery和GetModel + query, u := gplus.NewQuery[User](opt) + uvo := gplus.GetModel[UserVo](opt) query.Select(&u.Dept, gplus.Sum(&u.Score).As(&uvo.Score)).Group(&u.Dept) UserVos, resultDb := gplus.SelectGeneric[User, []UserVo](query, opt) @@ -742,12 +742,12 @@ func TestSelectGeneric6BaseDb(t *testing.T) { } } - //测试NewQueryModelBaseDb + //测试NewQueryModel type UserV1 struct { Name string Age int64 } - query, user, userV1 := gplus.NewQueryModelBaseDb[User, UserV1](opt) + query, user, userV1 := gplus.NewQueryModel[User, UserV1](opt) query.Eq(&user.Username, "afumu").And(func(q *gplus.QueryCond[User]) { q.Eq(&user.Address, "北京").Or().Eq(&user.Age, 20) }).Select(gplus.As(&user.Username, &userV1.Name), &user.Age) @@ -768,7 +768,7 @@ func TestQueryByIdBaseDb(t *testing.T) { opt := gplus.DbConnName(gormDbConnName) values := url.Values{} values["q"] = []string{"id=1"} - query := gplus.BuildQueryBaseDb[User](values, opt) + query := gplus.BuildQuery[User](values, opt) gplus.SelectList[User](query, opt) } @@ -780,7 +780,7 @@ func deleteOldData() { func deleteOldDataBaseDb() { opt := gplus.DbConnName(gormDbConnName) - q, u := gplus.NewQueryBaseDb[User](opt) + q, u := gplus.NewQuery[User](opt) q.IsNotNull(&u.ID) gplus.Delete(q, gplus.DbConnName(gormDbConnName)) } From 01027fb9c34ca8d964749649fba1046ff7cadfb0 Mon Sep 17 00:00:00 2001 From: ydy <517697206@qq.com> Date: Mon, 17 Feb 2025 14:00:47 +0800 Subject: [PATCH 7/7] feat:Init for compatibility --- gplus/cache.go | 5 ++++- gplus/dao.go | 55 +++++++++++++++++++++++++++-------------------- tests/dao_test.go | 2 +- 3 files changed, 37 insertions(+), 25 deletions(-) diff --git a/gplus/cache.go b/gplus/cache.go index bb3b67b..7c1f041 100644 --- a/gplus/cache.go +++ b/gplus/cache.go @@ -106,7 +106,10 @@ func getSubFieldColumnNameMap(valueOf reflect.Value, field reflect.StructField, return result } -// 解析字段名称 兼容多数据库切换 +// 解析字段名称 兼容多数据库切换, +// 如果用户使用Option的GetDb而没有传数据库连接名这边获取的namingStrategy 是默认的一个可能会有问题, +// 所以建议用户多数据库的时候弃用Option里的Db,并且重新改写初始化,给与每个db连接有连接名 +// 并且改造下多数据使用NewQuery和GetModel和NewQueryModel相关方法传入数据库连接名 func parseColumnName(field reflect.StructField, namingStrategy schema.Namer) string { tagSetting := schema.ParseTagSetting(field.Tag.Get("gorm"), ";") name, ok := tagSetting["COLUMN"] diff --git a/gplus/dao.go b/gplus/dao.go index 3d9f308..fa55cc4 100644 --- a/gplus/dao.go +++ b/gplus/dao.go @@ -33,28 +33,21 @@ var globalDbMap = make(map[string]*gorm.DB) var globalDbKeys []string var defaultBatchSize = 1000 -func Init(db *gorm.DB) { - InitDb(db, constants.DefaultGormPlusConnName) -} - -func InitDb(db *gorm.DB, dbConnName string) error { - if len(dbConnName) == 0 { - return errors.New("InitMultiple dbConnName is empty please check") - } - _, exists := globalDbMap[dbConnName] - if !exists { - // db instance register to global variable - globalDbMap[dbConnName] = db - globalDbKeys = append(globalDbKeys, dbConnName) - return nil +// Init 可选参数dbConnNameArr 代表数据库连接名,只需要传一个就行, +// 主要为了兼容之前用户只传一个db无需修改 +func Init(db *gorm.DB, dbConnNameArr ...string) error { + var dbConnName = "" + if len(dbConnNameArr) > 0 { + dbConnName = dbConnNameArr[0] } - return errors.New("InitMultiple have same name:" + dbConnName + ",please check") + return setGlobalInfo(db, dbConnName) } -func InitDbMany(dic map[string]*gorm.DB) []error { +// InitMany 初始化多个 +func InitMany(dic map[string]*gorm.DB) []error { var errs []error for k, v := range dic { - if err := InitDb(v, k); err != nil { + if err := setGlobalInfo(v, k); err != nil { errs = append(errs, err) } } @@ -516,12 +509,6 @@ func getOption(opts []OptionFunc) Option { return config } -func getOneOption(opt OptionFunc) Option { - var config Option - opt(&config) - return config -} - func setSelectIfNeed(option Option, db *gorm.DB) { if len(option.Selects) > 0 { var columnNames []string @@ -609,3 +596,25 @@ func getDefaultDbByName(dbConnName string) (*gorm.DB, string, error) { db, err := GetDb(dbConnName) return db, dbConnName, err } + +func setGlobalInfo(db *gorm.DB, dbConnName string) error { + if len(dbConnName) == 0 { + //return errors.New("InitMultiple dbConnName is empty please check") + //如果字典里不包含了默认名则使用默认名,兼容之前单库 + _, exists := globalDbMap[constants.DefaultGormPlusConnName] + if exists { + //根据db指针地址获取作为连接名,因为GORM 本身不提供直接获取数据库连接地址的方法,也不推荐使用反射来获取dsn + dbConnName = fmt.Sprintf("%p", db) + } else { + dbConnName = constants.DefaultGormPlusConnName + } + } + _, exists := globalDbMap[dbConnName] + if !exists { + // db instance register to global variable + globalDbMap[dbConnName] = db + globalDbKeys = append(globalDbKeys, dbConnName) + return nil + } + return errors.New("InitMultiple have same name:" + dbConnName + ",please check") +} diff --git a/tests/dao_test.go b/tests/dao_test.go index b4554fa..818dc60 100644 --- a/tests/dao_test.go +++ b/tests/dao_test.go @@ -63,7 +63,7 @@ func initDb() { } var u User gormDb1.AutoMigrate(u) - gplus.InitDb(gormDb1, gormDbConnName) + gplus.Init(gormDb1, gormDbConnName) } func TestInsert(t *testing.T) {