diff --git a/contrib/drivers/clickhouse/clickhouse.go b/contrib/drivers/clickhouse/clickhouse.go index 1b68b5def..195febcac 100644 --- a/contrib/drivers/clickhouse/clickhouse.go +++ b/contrib/drivers/clickhouse/clickhouse.go @@ -143,11 +143,11 @@ func (d *Driver) TableFields( ctx context.Context, table string, schema ...string, ) (fields map[string]*gdb.TableField, err error) { var ( - result gdb.Result - link gdb.Link - useSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) + result gdb.Result + link gdb.Link + usedSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) ) - if link, err = d.SlaveLink(useSchema); err != nil { + if link, err = d.SlaveLink(usedSchema); err != nil { return nil, err } var ( diff --git a/contrib/drivers/dm/dm.go b/contrib/drivers/dm/dm.go index 6e3678065..71e0a85f1 100644 --- a/contrib/drivers/dm/dm.go +++ b/contrib/drivers/dm/dm.go @@ -30,6 +30,10 @@ type Driver struct { *gdb.Core } +const ( + quoteChar = `"` +) + func init() { var ( err error @@ -92,7 +96,7 @@ func (d *Driver) Open(config *gdb.ConfigNode) (db *sql.DB, err error) { } func (d *Driver) GetChars() (charLeft string, charRight string) { - return `"`, `"` + return quoteChar, quoteChar } func (d *Driver) Tables(ctx context.Context, schema ...string) (tables []string, err error) { diff --git a/contrib/drivers/mssql/mssql.go b/contrib/drivers/mssql/mssql.go index 536ce64a1..04952f4d2 100644 --- a/contrib/drivers/mssql/mssql.go +++ b/contrib/drivers/mssql/mssql.go @@ -33,6 +33,10 @@ type Driver struct { *gdb.Core } +const ( + quoteChar = `"` +) + func init() { if err := gdb.Register(`mssql`, New()); err != nil { panic(err) @@ -95,7 +99,7 @@ func (d *Driver) Open(config *gdb.ConfigNode) (db *sql.DB, err error) { // GetChars returns the security char for this type of database. func (d *Driver) GetChars() (charLeft string, charRight string) { - return `"`, `"` + return quoteChar, quoteChar } // DoFilter deals with the sql string before commits it to underlying sql driver. @@ -237,11 +241,11 @@ func (d *Driver) Tables(ctx context.Context, schema ...string) (tables []string, // Also see DriverMysql.TableFields. func (d *Driver) TableFields(ctx context.Context, table string, schema ...string) (fields map[string]*gdb.TableField, err error) { var ( - result gdb.Result - link gdb.Link - useSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) + result gdb.Result + link gdb.Link + usedSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) ) - if link, err = d.SlaveLink(useSchema); err != nil { + if link, err = d.SlaveLink(usedSchema); err != nil { return nil, err } structureSql := fmt.Sprintf(` diff --git a/contrib/drivers/mysql/mysql.go b/contrib/drivers/mysql/mysql.go index 39112245c..2393d679a 100644 --- a/contrib/drivers/mysql/mysql.go +++ b/contrib/drivers/mysql/mysql.go @@ -28,6 +28,10 @@ type Driver struct { *gdb.Core } +const ( + quoteChar = "`" +) + func init() { var ( err error @@ -98,7 +102,7 @@ func (d *Driver) Open(config *gdb.ConfigNode) (db *sql.DB, err error) { // GetChars returns the security char for this type of database. func (d *Driver) GetChars() (charLeft string, charRight string) { - return "`", "`" + return quoteChar, quoteChar } // DoFilter handles the sql before posts it to database. @@ -142,11 +146,11 @@ func (d *Driver) TableFields( ctx context.Context, table string, schema ...string, ) (fields map[string]*gdb.TableField, err error) { var ( - result gdb.Result - link gdb.Link - useSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) + result gdb.Result + link gdb.Link + usedSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) ) - if link, err = d.SlaveLink(useSchema); err != nil { + if link, err = d.SlaveLink(usedSchema); err != nil { return nil, err } result, err = d.DoSelect( diff --git a/contrib/drivers/mysql/mysql_issue_test.go b/contrib/drivers/mysql/mysql_issue_test.go index 966f7c3b9..9b19dbec2 100644 --- a/contrib/drivers/mysql/mysql_issue_test.go +++ b/contrib/drivers/mysql/mysql_issue_test.go @@ -522,3 +522,114 @@ func Test_Issue2356(t *testing.T) { t.AssertEQ(one["id"].Val(), uint64(18446744073709551615)) }) } + +// https://github.com/gogf/gf/issues/2338 +func Test_Issue2338(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + table1 := "demo_" + guid.S() + table2 := "demo_" + guid.S() + if _, err := db.Schema(TestSchema1).Exec(ctx, fmt.Sprintf(` +CREATE TABLE %s ( + id int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'User ID', + nickname varchar(45) DEFAULT NULL COMMENT 'User Nickname', + create_at datetime DEFAULT NULL COMMENT 'Created Time', + update_at datetime DEFAULT NULL COMMENT 'Updated Time', + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, table1, + )); err != nil { + t.AssertNil(err) + } + if _, err := db.Schema(TestSchema2).Exec(ctx, fmt.Sprintf(` +CREATE TABLE %s ( + id int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'User ID', + nickname varchar(45) DEFAULT NULL COMMENT 'User Nickname', + create_at datetime DEFAULT NULL COMMENT 'Created Time', + update_at datetime DEFAULT NULL COMMENT 'Updated Time', + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, table2, + )); err != nil { + t.AssertNil(err) + } + defer dropTableWithDb(db.Schema(TestSchema1), table1) + defer dropTableWithDb(db.Schema(TestSchema2), table2) + + var err error + _, err = db.Schema(TestSchema1).Model(table1).Insert(g.Map{ + "id": 1, + "nickname": "name_1", + }) + t.AssertNil(err) + + _, err = db.Schema(TestSchema2).Model(table2).Insert(g.Map{ + "id": 1, + "nickname": "name_2", + }) + t.AssertNil(err) + + tableName1 := fmt.Sprintf(`%s.%s`, TestSchema1, table1) + tableName2 := fmt.Sprintf(`%s.%s`, TestSchema2, table2) + all, err := db.Model(tableName1).As(`a`). + LeftJoin(tableName2+" b", `a.id=b.id`). + Fields(`a.id`, `b.nickname`).All() + t.AssertNil(err) + t.Assert(len(all), 1) + t.Assert(all[0]["nickname"], "name_2") + }) + + gtest.C(t, func(t *gtest.T) { + table1 := "demo_" + guid.S() + table2 := "demo_" + guid.S() + if _, err := db.Schema(TestSchema1).Exec(ctx, fmt.Sprintf(` +CREATE TABLE %s ( + id int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'User ID', + nickname varchar(45) DEFAULT NULL COMMENT 'User Nickname', + create_at datetime DEFAULT NULL COMMENT 'Created Time', + update_at datetime DEFAULT NULL COMMENT 'Updated Time', + deleted_at datetime DEFAULT NULL COMMENT 'Deleted Time', + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, table1, + )); err != nil { + t.AssertNil(err) + } + if _, err := db.Schema(TestSchema2).Exec(ctx, fmt.Sprintf(` +CREATE TABLE %s ( + id int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT 'User ID', + nickname varchar(45) DEFAULT NULL COMMENT 'User Nickname', + create_at datetime DEFAULT NULL COMMENT 'Created Time', + update_at datetime DEFAULT NULL COMMENT 'Updated Time', + deleted_at datetime DEFAULT NULL COMMENT 'Deleted Time', + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, table2, + )); err != nil { + t.AssertNil(err) + } + defer dropTableWithDb(db.Schema(TestSchema1), table1) + defer dropTableWithDb(db.Schema(TestSchema2), table2) + + var err error + _, err = db.Schema(TestSchema1).Model(table1).Insert(g.Map{ + "id": 1, + "nickname": "name_1", + }) + t.AssertNil(err) + + _, err = db.Schema(TestSchema2).Model(table2).Insert(g.Map{ + "id": 1, + "nickname": "name_2", + }) + t.AssertNil(err) + + tableName1 := fmt.Sprintf(`%s.%s`, TestSchema1, table1) + tableName2 := fmt.Sprintf(`%s.%s`, TestSchema2, table2) + all, err := db.Model(tableName1).As(`a`). + LeftJoin(tableName2+" b", `a.id=b.id`). + Fields(`a.id`, `b.nickname`).All() + t.AssertNil(err) + t.Assert(len(all), 1) + t.Assert(all[0]["nickname"], "name_2") + }) +} diff --git a/contrib/drivers/oracle/oracle.go b/contrib/drivers/oracle/oracle.go index 2e2cc0777..c7dae1690 100644 --- a/contrib/drivers/oracle/oracle.go +++ b/contrib/drivers/oracle/oracle.go @@ -35,6 +35,10 @@ type Driver struct { *gdb.Core } +const ( + quoteChar = `"` +) + func init() { if err := gdb.Register(`oracle`, New()); err != nil { panic(err) @@ -106,7 +110,7 @@ func (d *Driver) Open(config *gdb.ConfigNode) (db *sql.DB, err error) { // GetChars returns the security char for this type of database. func (d *Driver) GetChars() (charLeft string, charRight string) { - return `"`, `"` + return quoteChar, quoteChar } // DoFilter deals with the sql string before commits it to underlying sql driver. @@ -217,7 +221,7 @@ func (d *Driver) TableFields( var ( result gdb.Result link gdb.Link - useSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) + usedSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) structureSql = fmt.Sprintf(` SELECT COLUMN_NAME AS FIELD, @@ -230,7 +234,7 @@ FROM USER_TAB_COLUMNS WHERE TABLE_NAME = '%s' ORDER BY COLUMN_ID`, strings.ToUpper(table), ) ) - if link, err = d.SlaveLink(useSchema); err != nil { + if link, err = d.SlaveLink(usedSchema); err != nil { return nil, err } structureSql, _ = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(structureSql)) diff --git a/contrib/drivers/pgsql/pgsql.go b/contrib/drivers/pgsql/pgsql.go index 255d35955..375337470 100644 --- a/contrib/drivers/pgsql/pgsql.go +++ b/contrib/drivers/pgsql/pgsql.go @@ -36,6 +36,7 @@ type Driver struct { const ( internalPrimaryKeyInCtx gctx.StrKey = "primary_key" defaultSchema = "public" + quoteChar = `"` ) func init() { @@ -113,7 +114,7 @@ func (d *Driver) Open(config *gdb.ConfigNode) (db *sql.DB, err error) { // GetChars returns the security char for this type of database. func (d *Driver) GetChars() (charLeft string, charRight string) { - return `"`, `"` + return quoteChar, quoteChar } // CheckLocalTypeForField checks and returns corresponding local golang type for given db type. @@ -270,7 +271,7 @@ func (d *Driver) TableFields(ctx context.Context, table string, schema ...string var ( result gdb.Result link gdb.Link - useSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) + usedSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) structureSql = fmt.Sprintf(` SELECT a.attname AS field, t.typname AS type,a.attnotnull as null, (case when d.contype is not null then 'pri' else '' end) as key @@ -288,7 +289,7 @@ ORDER BY a.attnum`, table, ) ) - if link, err = d.SlaveLink(useSchema); err != nil { + if link, err = d.SlaveLink(usedSchema); err != nil { return nil, err } structureSql, _ = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(structureSql)) diff --git a/contrib/drivers/sqlite/sqlite.go b/contrib/drivers/sqlite/sqlite.go index 0f4ddae76..e19b7b10d 100644 --- a/contrib/drivers/sqlite/sqlite.go +++ b/contrib/drivers/sqlite/sqlite.go @@ -33,6 +33,10 @@ type Driver struct { *gdb.Core } +const ( + quoteChar = "`" +) + func init() { if err := gdb.Register(`sqlite`, New()); err != nil { panic(err) @@ -105,7 +109,7 @@ func (d *Driver) Open(config *gdb.ConfigNode) (db *sql.DB, err error) { // GetChars returns the security char for this type of database. func (d *Driver) GetChars() (charLeft string, charRight string) { - return "`", "`" + return quoteChar, quoteChar } // DoFilter deals with the sql string before commits it to underlying sql driver. @@ -141,11 +145,11 @@ func (d *Driver) TableFields( ctx context.Context, table string, schema ...string, ) (fields map[string]*gdb.TableField, err error) { var ( - result gdb.Result - link gdb.Link - useSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) + result gdb.Result + link gdb.Link + usedSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) ) - if link, err = d.SlaveLink(useSchema); err != nil { + if link, err = d.SlaveLink(usedSchema); err != nil { return nil, err } result, err = d.DoSelect(ctx, link, fmt.Sprintf(`PRAGMA TABLE_INFO(%s)`, table)) diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 62354143d..b86dd9f61 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -125,13 +125,21 @@ func (c *Core) Close(ctx context.Context) (err error) { // Master creates and returns a connection from master node if master-slave configured. // It returns the default connection if master-slave not configured. func (c *Core) Master(schema ...string) (*sql.DB, error) { - return c.getSqlDb(true, gutil.GetOrDefaultStr(c.schema, schema...)) + var ( + usedSchema = gutil.GetOrDefaultStr(c.schema, schema...) + charL, charR = c.db.GetChars() + ) + return c.getSqlDb(true, gstr.Trim(usedSchema, charL+charR)) } // Slave creates and returns a connection from slave node if master-slave configured. // It returns the default connection if master-slave not configured. func (c *Core) Slave(schema ...string) (*sql.DB, error) { - return c.getSqlDb(false, gutil.GetOrDefaultStr(c.schema, schema...)) + var ( + usedSchema = gutil.GetOrDefaultStr(c.schema, schema...) + charL, charR = c.db.GetChars() + ) + return c.getSqlDb(false, gstr.Trim(usedSchema, charL+charR)) } // GetAll queries and returns data records from database. diff --git a/database/gdb/gdb_model_delete.go b/database/gdb/gdb_model_delete.go index 50e8f08df..32ccd0cab 100644 --- a/database/gdb/gdb_model_delete.go +++ b/database/gdb/gdb_model_delete.go @@ -30,7 +30,7 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) { } }() var ( - fieldNameDelete = m.getSoftFieldNameDeleted() + fieldNameDelete = m.getSoftFieldNameDeleted("", m.tablesInit) conditionWhere, conditionExtra, conditionArgs = m.formatCondition(ctx, false, false) ) // Soft deleting. diff --git a/database/gdb/gdb_model_insert.go b/database/gdb/gdb_model_insert.go index 4414c616c..09fb488b5 100644 --- a/database/gdb/gdb_model_insert.go +++ b/database/gdb/gdb_model_insert.go @@ -255,8 +255,8 @@ func (m *Model) doInsertWithOption(ctx context.Context, insertOption int) (resul var ( list List nowString = gtime.Now().String() - fieldNameCreate = m.getSoftFieldNameCreated() - fieldNameUpdate = m.getSoftFieldNameUpdated() + fieldNameCreate = m.getSoftFieldNameCreated("", m.tablesInit) + fieldNameUpdate = m.getSoftFieldNameUpdated("", m.tablesInit) ) newData, err := m.filterDataForInsertOrUpdate(m.data) if err != nil { diff --git a/database/gdb/gdb_model_time.go b/database/gdb/gdb_model_time.go index 41511eef2..cd3671fcc 100644 --- a/database/gdb/gdb_model_time.go +++ b/database/gdb/gdb_model_time.go @@ -32,70 +32,70 @@ func (m *Model) Unscoped() *Model { // getSoftFieldNameCreate checks and returns the field name for record creating time. // If there's no field name for storing creating time, it returns an empty string. // It checks the key with or without cases or chars '-'/'_'/'.'/' '. -func (m *Model) getSoftFieldNameCreated(table ...string) string { +func (m *Model) getSoftFieldNameCreated(schema string, table string) string { // It checks whether this feature disabled. if m.db.GetConfig().TimeMaintainDisabled { return "" } tableName := "" - if len(table) > 0 { - tableName = table[0] + if table != "" { + tableName = table } else { tableName = m.tablesInit } config := m.db.GetConfig() if config.CreatedAt != "" { - return m.getSoftFieldName(tableName, []string{config.CreatedAt}) + return m.getSoftFieldName(schema, tableName, []string{config.CreatedAt}) } - return m.getSoftFieldName(tableName, createdFiledNames) + return m.getSoftFieldName(schema, tableName, createdFiledNames) } // getSoftFieldNameUpdate checks and returns the field name for record updating time. // If there's no field name for storing updating time, it returns an empty string. // It checks the key with or without cases or chars '-'/'_'/'.'/' '. -func (m *Model) getSoftFieldNameUpdated(table ...string) (field string) { +func (m *Model) getSoftFieldNameUpdated(schema string, table string) (field string) { // It checks whether this feature disabled. if m.db.GetConfig().TimeMaintainDisabled { return "" } tableName := "" - if len(table) > 0 { - tableName = table[0] + if table != "" { + tableName = table } else { tableName = m.tablesInit } config := m.db.GetConfig() if config.UpdatedAt != "" { - return m.getSoftFieldName(tableName, []string{config.UpdatedAt}) + return m.getSoftFieldName(schema, tableName, []string{config.UpdatedAt}) } - return m.getSoftFieldName(tableName, updatedFiledNames) + return m.getSoftFieldName(schema, tableName, updatedFiledNames) } // getSoftFieldNameDelete checks and returns the field name for record deleting time. // If there's no field name for storing deleting time, it returns an empty string. // It checks the key with or without cases or chars '-'/'_'/'.'/' '. -func (m *Model) getSoftFieldNameDeleted(table ...string) (field string) { +func (m *Model) getSoftFieldNameDeleted(schema string, table string) (field string) { // It checks whether this feature disabled. if m.db.GetConfig().TimeMaintainDisabled { return "" } tableName := "" - if len(table) > 0 { - tableName = table[0] + if table != "" { + tableName = table } else { tableName = m.tablesInit } config := m.db.GetConfig() if config.DeletedAt != "" { - return m.getSoftFieldName(tableName, []string{config.DeletedAt}) + return m.getSoftFieldName(schema, tableName, []string{config.DeletedAt}) } - return m.getSoftFieldName(tableName, deletedFiledNames) + return m.getSoftFieldName(schema, tableName, deletedFiledNames) } // getSoftFieldName retrieves and returns the field name of the table for possible key. -func (m *Model) getSoftFieldName(table string, keys []string) (field string) { +func (m *Model) getSoftFieldName(schema string, table string, keys []string) (field string) { // Ignore the error from TableFields. - fieldsMap, _ := m.TableFields(table) + fieldsMap, _ := m.TableFields(table, schema) if len(fieldsMap) > 0 { for _, key := range keys { field, _ = gutil.MapPossibleItemByKey( @@ -141,26 +141,33 @@ func (m *Model) getConditionForSoftDeleting() string { return conditionArray.Join(" AND ") } // Only one table. - if fieldName := m.getSoftFieldNameDeleted(); fieldName != "" { + if fieldName := m.getSoftFieldNameDeleted("", m.tablesInit); fieldName != "" { return fmt.Sprintf(`%s IS NULL`, m.db.GetCore().QuoteWord(fieldName)) } return "" } // getConditionOfTableStringForSoftDeleting does something as its name describes. +// Examples for `s`: +// - `test`.`demo` as b +// - `test`.`demo` b +// - `demo` +// - demo func (m *Model) getConditionOfTableStringForSoftDeleting(s string) string { var ( field string table string + schema string array1 = gstr.SplitAndTrim(s, " ") array2 = gstr.SplitAndTrim(array1[0], ".") ) if len(array2) >= 2 { table = array2[1] + schema = array2[0] } else { table = array2[0] } - field = m.getSoftFieldNameDeleted(table) + field = m.getSoftFieldNameDeleted(schema, table) if field == "" { return "" } diff --git a/database/gdb/gdb_model_update.go b/database/gdb/gdb_model_update.go index 59fb7a3cb..3e27ba30e 100644 --- a/database/gdb/gdb_model_update.go +++ b/database/gdb/gdb_model_update.go @@ -46,7 +46,7 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro var ( updateData = m.data reflectInfo = reflection.OriginTypeAndKind(updateData) - fieldNameUpdate = m.getSoftFieldNameUpdated() + fieldNameUpdate = m.getSoftFieldNameUpdated("", m.tablesInit) conditionWhere, conditionExtra, conditionArgs = m.formatCondition(ctx, false, false) ) switch reflectInfo.OriginKind {