diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 726f9ae43..d543bdb11 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -154,12 +154,12 @@ type DB interface { // Utility methods. // =========================================================================== - GetCtx() context.Context // See Core.GetCtx. - GetCore() *Core // See Core.GetCore - GetChars() (charLeft string, charRight string) // See Core.GetChars. - Tables(ctx context.Context, schema ...string) (tables []string, err error) // See Core.Tables. - TableFields(ctx context.Context, link Link, table string, schema ...string) (map[string]*TableField, error) // See Core.TableFields. - FilteredLinkInfo() string // See Core.FilteredLinkInfo. + GetCtx() context.Context // See Core.GetCtx. + GetCore() *Core // See Core.GetCore + GetChars() (charLeft string, charRight string) // See Core.GetChars. + Tables(ctx context.Context, schema ...string) (tables []string, err error) // See Core.Tables. + TableFields(ctx context.Context, table string, schema ...string) (map[string]*TableField, error) // See Core.TableFields. + FilteredLinkInfo() string // See Core.FilteredLinkInfo. // HandleSqlBeforeCommit is a hook function, which deals with the sql string before // it's committed to underlying driver. The parameter `link` specifies the current diff --git a/database/gdb/gdb_core_structure.go b/database/gdb/gdb_core_structure.go index be34cf062..e7aa76b67 100644 --- a/database/gdb/gdb_core_structure.go +++ b/database/gdb/gdb_core_structure.go @@ -151,7 +151,7 @@ func (c *Core) convertFieldValueToLocalValue(fieldValue interface{}, fieldType s // mappingAndFilterData automatically mappings the map key to table field and removes // all key-value pairs that are not the field of given table. func (c *Core) mappingAndFilterData(schema, table string, data map[string]interface{}, filter bool) (map[string]interface{}, error) { - if fieldsMap, err := c.db.TableFields(c.GetCtx(), nil, table, schema); err == nil { + if fieldsMap, err := c.db.TableFields(c.GetCtx(), table, schema); err == nil { fieldsKeyMap := make(map[string]interface{}, len(fieldsMap)) for k, _ := range fieldsMap { fieldsKeyMap[k] = nil diff --git a/database/gdb/gdb_driver_mssql.go b/database/gdb/gdb_driver_mssql.go index d0bb93cfb..33d3f8732 100644 --- a/database/gdb/gdb_driver_mssql.go +++ b/database/gdb/gdb_driver_mssql.go @@ -207,7 +207,7 @@ func (d *DriverMssql) Tables(ctx context.Context, schema ...string) (tables []st // TableFields retrieves and returns the fields information of specified table of current schema. // // Also see DriverMysql.TableFields. -func (d *DriverMssql) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverMssql) TableFields(ctx context.Context, table string, schema ...string) (fields map[string]*TableField, err error) { charL, charR := d.GetChars() table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { @@ -223,13 +223,11 @@ func (d *DriverMssql) TableFields(ctx context.Context, link Link, table string, ) v := tableFieldsMap.GetOrSetFuncLock(tableFieldsCacheKey, func() interface{} { var ( - result Result - ) - if link == nil { + result Result link, err = d.SlaveLink(useSchema) - if err != nil { - return nil - } + ) + if err != nil { + return nil } structureSql := fmt.Sprintf(` SELECT diff --git a/database/gdb/gdb_driver_mysql.go b/database/gdb/gdb_driver_mysql.go index 724b353ec..d39078c75 100644 --- a/database/gdb/gdb_driver_mysql.go +++ b/database/gdb/gdb_driver_mysql.go @@ -113,7 +113,7 @@ func (d *DriverMysql) Tables(ctx context.Context, schema ...string) (tables []st // // It's using cache feature to enhance the performance, which is never expired util the // process restarts. -func (d *DriverMysql) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverMysql) TableFields(ctx context.Context, table string, schema ...string) (fields map[string]*TableField, err error) { charL, charR := d.GetChars() table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { @@ -129,20 +129,16 @@ func (d *DriverMysql) TableFields(ctx context.Context, link Link, table string, ) v := tableFieldsMap.GetOrSetFuncLock(tableFieldsCacheKey, func() interface{} { var ( - result Result - ) - if link == nil { + result Result link, err = d.SlaveLink(useSchema) - if err != nil { - return nil - } - } - result, err = d.DoGetAll(ctx, link, - fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.QuoteWord(table)), ) if err != nil { return nil } + result, err = d.DoGetAll(ctx, link, fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.QuoteWord(table))) + if err != nil { + return nil + } fields = make(map[string]*TableField) for i, m := range result { fields[m["Field"].String()] = &TableField{ diff --git a/database/gdb/gdb_driver_oracle.go b/database/gdb/gdb_driver_oracle.go index be356f176..9d8543936 100644 --- a/database/gdb/gdb_driver_oracle.go +++ b/database/gdb/gdb_driver_oracle.go @@ -183,7 +183,7 @@ func (d *DriverOracle) Tables(ctx context.Context, schema ...string) (tables []s // TableFields retrieves and returns the fields information of specified table of current schema. // // Also see DriverMysql.TableFields. -func (d *DriverOracle) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverOracle) TableFields(ctx context.Context, table string, schema ...string) (fields map[string]*TableField, err error) { charL, charR := d.GetChars() table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { @@ -200,6 +200,7 @@ func (d *DriverOracle) TableFields(ctx context.Context, link Link, table string, v := tableFieldsMap.GetOrSetFuncLock(tableFieldsCacheKey, func() interface{} { var ( result Result + link, err = d.SlaveLink(useSchema) structureSql = fmt.Sprintf(` SELECT COLUMN_NAME AS FIELD, @@ -211,13 +212,10 @@ FROM USER_TAB_COLUMNS WHERE TABLE_NAME = '%s' ORDER BY COLUMN_ID`, strings.ToUpper(table), ) ) - structureSql, _ = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(structureSql)) - if link == nil { - link, err = d.SlaveLink(useSchema) - if err != nil { - return nil - } + if err != nil { + return nil } + structureSql, _ = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(structureSql)) result, err = d.DoGetAll(ctx, link, structureSql) if err != nil { return nil diff --git a/database/gdb/gdb_driver_pgsql.go b/database/gdb/gdb_driver_pgsql.go index 8fb6ad243..b0695f8c0 100644 --- a/database/gdb/gdb_driver_pgsql.go +++ b/database/gdb/gdb_driver_pgsql.go @@ -115,7 +115,7 @@ func (d *DriverPgsql) Tables(ctx context.Context, schema ...string) (tables []st // TableFields retrieves and returns the fields information of specified table of current schema. // // Also see DriverMysql.TableFields. -func (d *DriverPgsql) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverPgsql) TableFields(ctx context.Context, table string, schema ...string) (fields map[string]*TableField, err error) { charL, charR := d.GetChars() table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { @@ -133,6 +133,7 @@ func (d *DriverPgsql) TableFields(ctx context.Context, link Link, table string, v := tableFieldsMap.GetOrSetFuncLock(tableFieldsCacheKey, func() interface{} { var ( result Result + link, err = d.SlaveLink(useSchema) structureSql = fmt.Sprintf(` SELECT a.attname AS field, t.typname AS type FROM pg_class c, pg_attribute a LEFT OUTER JOIN pg_description b ON a.attrelid=b.objoid AND a.attnum = b.objsubid,pg_type t @@ -141,13 +142,10 @@ ORDER BY a.attnum`, strings.ToLower(table), ) ) - structureSql, _ = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(structureSql)) - if link == nil { - link, err = d.SlaveLink(useSchema) - if err != nil { - return nil - } + if err != nil { + return nil } + structureSql, _ = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(structureSql)) result, err = d.DoGetAll(ctx, link, structureSql) if err != nil { return nil diff --git a/database/gdb/gdb_driver_sqlite.go b/database/gdb/gdb_driver_sqlite.go index 69f4a42a5..99ae55fcf 100644 --- a/database/gdb/gdb_driver_sqlite.go +++ b/database/gdb/gdb_driver_sqlite.go @@ -97,7 +97,7 @@ func (d *DriverSqlite) Tables(ctx context.Context, schema ...string) (tables []s // TableFields retrieves and returns the fields information of specified table of current schema. // // Also see DriverMysql.TableFields. -func (d *DriverSqlite) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverSqlite) TableFields(ctx context.Context, table string, schema ...string) (fields map[string]*TableField, err error) { charL, charR := d.GetChars() table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { @@ -113,13 +113,11 @@ func (d *DriverSqlite) TableFields(ctx context.Context, link Link, table string, ) v := tableFieldsMap.GetOrSetFuncLock(tableFieldsCacheKey, func() interface{} { var ( - result Result - ) - if link == nil { + result Result link, err = d.SlaveLink(useSchema) - if err != nil { - return nil - } + ) + if err != nil { + return nil } result, err = d.DoGetAll(ctx, link, fmt.Sprintf(`PRAGMA TABLE_INFO(%s)`, table)) if err != nil { diff --git a/database/gdb/gdb_model_utility.go b/database/gdb/gdb_model_utility.go index bf29b0a9c..7b1c36b60 100644 --- a/database/gdb/gdb_model_utility.go +++ b/database/gdb/gdb_model_utility.go @@ -29,7 +29,7 @@ func (m *Model) TableFields(table string, schema ...string) (fields map[string]* if !gregex.IsMatchString(regularFieldNameRegPattern, table) { return nil, nil } - return m.db.TableFields(m.GetCtx(), m.getLink(false), table, schema...) + return m.db.TableFields(m.GetCtx(), table, schema...) } // getModel creates and returns a cloned model of current model if `safe` is true, or else it returns diff --git a/database/gdb/gdb_z_mysql_subquery_test.go b/database/gdb/gdb_z_mysql_subquery_test.go index 2f292085b..6cb76abf2 100644 --- a/database/gdb/gdb_z_mysql_subquery_test.go +++ b/database/gdb/gdb_z_mysql_subquery_test.go @@ -53,7 +53,7 @@ func Test_Model_SubQuery_Having(t *testing.T) { func Test_Model_SubQuery_Model(t *testing.T) { table := createInitTable() defer dropTable(table) - db.SetDebug(true) + gtest.C(t, func(t *gtest.T) { subQuery1 := db.Model(table).Where("id", g.Slice{1, 3, 5}) subQuery2 := db.Model(table).Where("id", g.Slice{5, 7, 9})