From 597f7468e97948b142d7b3984801ede4afbccedd Mon Sep 17 00:00:00 2001 From: John Date: Mon, 23 Dec 2019 23:14:54 +0800 Subject: [PATCH] fix issue in prefix feature for method operations of gdb --- .example/database/gdb/mysql/config.toml | 2 +- .example/database/gdb/mysql/gdb_debug1.go | 7 +- .example/database/gdb/mysql/gdb_debug2.go | 1 + database/gdb/gdb.go | 27 +++-- database/gdb/gdb_base.go | 23 ++-- database/gdb/gdb_func.go | 42 ++++---- database/gdb/gdb_model.go | 18 +--- database/gdb/gdb_unit_z_func_test.go | 28 ++--- database/gdb/gdb_unit_z_mysql_method_test.go | 104 +++++++++++++++++++ 9 files changed, 180 insertions(+), 72 deletions(-) diff --git a/.example/database/gdb/mysql/config.toml b/.example/database/gdb/mysql/config.toml index b93aa25e9..d63c06afc 100644 --- a/.example/database/gdb/mysql/config.toml +++ b/.example/database/gdb/mysql/config.toml @@ -1,7 +1,7 @@ # MySQL数据库配置 [database] - debug = true +# debug = true link = "mysql:root:12345678@tcp(127.0.0.1:3306)/test?parseTime=true&loc=Local" #[database] diff --git a/.example/database/gdb/mysql/gdb_debug1.go b/.example/database/gdb/mysql/gdb_debug1.go index fa517dc0c..9a3bc10e3 100644 --- a/.example/database/gdb/mysql/gdb_debug1.go +++ b/.example/database/gdb/mysql/gdb_debug1.go @@ -1,8 +1,6 @@ package main import ( - "fmt" - "github.com/gogf/gf/database/gdb" "github.com/gogf/gf/frame/g" "github.com/gogf/gf/os/glog" @@ -23,7 +21,7 @@ func main() { if err != nil { panic(err) } - db.SetDebug(true) + //db.SetDebug(false) glog.SetPath("/tmp") @@ -36,7 +34,4 @@ func main() { db.Table("user").Data(g.Map{"name": "smith"}).Where("uid=?", 1).Save() - db.PrintQueriedSqls() - - fmt.Println(db.GetLastSql()) } diff --git a/.example/database/gdb/mysql/gdb_debug2.go b/.example/database/gdb/mysql/gdb_debug2.go index 720e04a25..0ed29c2a7 100644 --- a/.example/database/gdb/mysql/gdb_debug2.go +++ b/.example/database/gdb/mysql/gdb_debug2.go @@ -6,6 +6,7 @@ import ( func main() { db := g.DB() + // 执行3条SQL查询 for i := 1; i <= 3; i++ { db.Table("user").Where("id=?", i).One() diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index f07f1b106..cb214eeb4 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -22,17 +22,18 @@ import ( "github.com/gogf/gf/util/grand" ) -// 数据库操作接口 +// DB is the interface for ORM operations. type DB interface { - // 建立数据库连接方法(开发者一般不需要直接调用) + // Open creates a raw connection object for database with given node configuration. + // Note that it is not recommended using the this function manually. Open(config *ConfigNode) (*sql.DB, error) - // SQL操作方法 API + // Query APIs. Query(query string, args ...interface{}) (*sql.Rows, error) Exec(sql string, args ...interface{}) (sql.Result, error) Prepare(sql string, execOnMaster ...bool) (*sql.Stmt, error) - // 内部实现API的方法(不同数据库可覆盖这些方法实现自定义的操作) + // Internal APIs for CURD, which can be overwrote for custom CURD implements. doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) doGetAll(link dbLink, query string, args ...interface{}) (result Result, err error) doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) @@ -42,7 +43,7 @@ type DB interface { doUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) doDelete(link dbLink, table string, condition string, args ...interface{}) (result sql.Result, err error) - // 数据库查询 + // Query APIs for convenience purpose. GetAll(query string, args ...interface{}) (Result, error) GetOne(query string, args ...interface{}) (Record, error) GetValue(query string, args ...interface{}) (Value, error) @@ -51,36 +52,33 @@ type DB interface { GetStructs(objPointerSlice interface{}, query string, args ...interface{}) error GetScan(objPointer interface{}, query string, args ...interface{}) error - // 创建底层数据库master/slave链接对象 + // Master/Slave support. Master() (*sql.DB, error) Slave() (*sql.DB, error) - // Ping + // Ping. PingMaster() error PingSlave() error - // 开启事务操作 + // Transaction. Begin() (*TX, error) - // 数据表插入/更新/保存操作 Insert(table string, data interface{}, batch ...int) (sql.Result, error) Replace(table string, data interface{}, batch ...int) (sql.Result, error) Save(table string, data interface{}, batch ...int) (sql.Result, error) - // 数据表插入/更新/保存操作(批量) BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) - // 数据修改/删除 Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) - // 创建链式操作对象 + // Create model. From(tables string) *Model Table(tables string) *Model - // 设置管理 + // Configuration methods. SetDebug(debug bool) SetSchema(schema string) SetLogger(logger *glog.Logger) @@ -91,13 +89,14 @@ type DB interface { Tables() (tables []string, err error) TableFields(table string) (map[string]*TableField, error) - // 内部方法接口 + // Internal methods. getCache() *gcache.Cache getChars() (charLeft string, charRight string) getDebug() bool getPrefix() string quoteWord(s string) string quoteString(s string) string + handleTableName(table string) string doSetSchema(sqlDb *sql.DB, schema string) error filterFields(table string, data map[string]interface{}) map[string]interface{} convertValue(fieldValue []byte, fieldType string) interface{} diff --git a/database/gdb/gdb_base.go b/database/gdb/gdb_base.go index b734bff0b..0b1a1a6d9 100644 --- a/database/gdb/gdb_base.go +++ b/database/gdb/gdb_base.go @@ -56,9 +56,9 @@ func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows query, args = formatQuery(query, args) query = bs.db.handleSqlBeforeExec(query) if bs.db.getDebug() { - mTime1 := gtime.Millisecond() + mTime1 := gtime.TimestampMicro() rows, err = link.Query(query, args...) - mTime2 := gtime.Millisecond() + mTime2 := gtime.TimestampMicro() s := &Sql{ Sql: query, Args: args, @@ -302,7 +302,7 @@ func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option i var values []string var params []interface{} var dataMap Map - table = bs.db.quoteWord(table) + table = bs.db.handleTableName(table) // 使用反射判断data数据类型,如果为slice类型,那么自动转为批量操作 rv := reflect.ValueOf(data) kind := rv.Kind() @@ -371,7 +371,7 @@ func (bs *dbBase) BatchSave(table string, list interface{}, batch ...int) (sql.R func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) { var keys, values []string var params []interface{} - table = bs.db.quoteWord(table) + table = bs.db.handleTableName(table) listMap := (List)(nil) switch v := list.(type) { case Result: @@ -491,7 +491,7 @@ func (bs *dbBase) Update(table string, data interface{}, condition interface{}, // CURD操作:数据更新,统一采用sql预处理。 // data参数支持string/map/struct/*struct类型类型。 func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) { - table = bs.db.quoteWord(table) + table = bs.db.handleTableName(table) updates := "" // 使用反射进行类型判断 rv := reflect.ValueOf(data) @@ -543,7 +543,7 @@ func (bs *dbBase) doDelete(link dbLink, table string, condition string, args ... return nil, err } } - table = bs.db.quoteWord(table) + table = bs.db.handleTableName(table) return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...) } @@ -605,6 +605,17 @@ func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) { return records, nil } +// handleTableName adds prefix string and quote chars for the table. It handles table string like: +// "user", "user u", "user,user_detail", "user u, user_detail ut", "user as u, user_detail as ut". +// +// Note that, this will automatically checks the table prefix whether already added, if true it does +// nothing to the table name, or else adds the prefix to the table name. +func (bs *dbBase) handleTableName(table string) string { + charLeft, charRight := bs.db.getChars() + prefix := bs.db.getPrefix() + return doHandleTableName(table, prefix, charLeft, charRight) +} + // quoteWord checks given string a word, if true quotes it with security chars of the database // and returns the quoted string; or else return without any change. func (bs *dbBase) quoteWord(s string) string { diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index 9a0028d91..0aef26d34 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -51,8 +51,31 @@ var ( quoteWordReg = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`) ) +// handleTableName adds prefix string and quote chars for the table. It handles table string like: +// "user", "user u", "user,user_detail", "user u, user_detail ut", "user as u, user_detail as ut". +// +// Note that, this will automatically checks the table prefix whether already added, if true it does +// nothing to the table name, or else adds the prefix to the table name. +func doHandleTableName(table, prefix, charLeft, charRight string) string { + array1 := gstr.SplitAndTrim(table, ",") + for k1, v1 := range array1 { + array2 := gstr.SplitAndTrim(v1, " ") + // Trim the security chars. + array2[0] = gstr.TrimLeftStr(array2[0], charLeft) + array2[0] = gstr.TrimRightStr(array2[0], charRight) + // If the table name already has the prefix, skips the prefix adding. + if len(array2[0]) <= len(prefix) || array2[0][:len(prefix)] != prefix { + array2[0] = prefix + array2[0] + } + // Add the security chars. + array2[0] = doQuoteWord(array2[0], charLeft, charRight) + array1[k1] = gstr.Join(array2, " ") + } + return gstr.Join(array1, ",") +} + // doQuoteWord checks given string a word, if true quotes it with and -// and returns the quoted string; or else return without any change. +// and returns the quoted string; or else returns without any change. func doQuoteWord(s, charLeft, charRight string) string { if quoteWordReg.MatchString(s) && !gstr.ContainsAny(s, charLeft+charRight) { return charLeft + s + charRight @@ -78,23 +101,6 @@ func doQuoteString(s, charLeft, charRight string) string { return gstr.Join(array1, ",") } -// addTablePrefix adds prefix string to the table. It handles table string like: -// "user", "user u", "user,user_detail", "user u, user_detail ut", "user as u, user_detail as ut". -// -// Note that, this should be used before any quoting function calls. -func addTablePrefix(table, prefix string) string { - if prefix == "" { - return table - } - array1 := gstr.SplitAndTrim(table, ",") - for k1, v1 := range array1 { - array2 := gstr.SplitAndTrim(v1, " ") - array2[0] = prefix + array2[0] - array1[k1] = gstr.Join(array2, " ") - } - return gstr.Join(array1, ",") -} - // GetWhereConditionOfStruct returns the where condition sql and arguments by given struct pointer. // This function automatically retrieves primary or unique field and its attribute value as condition. func GetWhereConditionOfStruct(pointer interface{}) (where string, args []interface{}) { diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index d2b7b58eb..4f92e0d2b 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -67,8 +67,7 @@ const ( // The parameter can be more than one table names, like : // "user", "user u", "user, user_detail", "user u, user_detail ud" func (bs *dbBase) Table(table string) *Model { - table = addTablePrefix(table, bs.db.getPrefix()) - table = bs.db.quoteString(table) + table = bs.db.handleTableName(table) return &Model{ db: bs.db, tablesInit: table, @@ -90,8 +89,7 @@ func (bs *dbBase) From(tables string) *Model { // Table acts like dbBase.Table except it operates on transaction. // See dbBase.Table. func (tx *TX) Table(table string) *Model { - table = addTablePrefix(table, tx.db.getPrefix()) - table = tx.db.quoteString(table) + table = tx.db.handleTableName(table) return &Model{ db: tx.db, tx: tx, @@ -186,27 +184,21 @@ func (m *Model) getModel() *Model { // LeftJoin does "LEFT JOIN ... ON ..." statement on the model. func (m *Model) LeftJoin(table string, on string) *Model { model := m.getModel() - table = addTablePrefix(table, m.db.getPrefix()) - table = m.db.quoteString(table) - model.tables += fmt.Sprintf(" LEFT JOIN %s ON (%s)", table, on) + model.tables += fmt.Sprintf(" LEFT JOIN %s ON (%s)", m.db.handleTableName(table), on) return model } // RightJoin does "RIGHT JOIN ... ON ..." statement on the model. func (m *Model) RightJoin(table string, on string) *Model { model := m.getModel() - table = addTablePrefix(table, m.db.getPrefix()) - table = m.db.quoteString(table) - model.tables += fmt.Sprintf(" RIGHT JOIN %s ON (%s)", table, on) + model.tables += fmt.Sprintf(" RIGHT JOIN %s ON (%s)", m.db.handleTableName(table), on) return model } // InnerJoin does "INNER JOIN ... ON ..." statement on the model. func (m *Model) InnerJoin(table string, on string) *Model { model := m.getModel() - table = addTablePrefix(table, m.db.getPrefix()) - table = m.db.quoteString(table) - model.tables += fmt.Sprintf(" INNER JOIN %s ON (%s)", table, on) + model.tables += fmt.Sprintf(" INNER JOIN %s ON (%s)", m.db.handleTableName(table), on) return model } diff --git a/database/gdb/gdb_unit_z_func_test.go b/database/gdb/gdb_unit_z_func_test.go index c4af09232..645b5de85 100644 --- a/database/gdb/gdb_unit_z_func_test.go +++ b/database/gdb/gdb_unit_z_func_test.go @@ -49,29 +49,29 @@ func Test_Func_addTablePrefix(t *testing.T) { gtest.Case(t, func() { prefix := "" array := map[string]string{ - "user": "user", - "user u": "user u", - "user as u": "user as u", - "user,user_detail": "user,user_detail", - "user u, user_detail ut": "user u, user_detail ut", - "user as u, user_detail as ut": "user as u, user_detail as ut", + "user": "`user`", + "user u": "`user` u", + "user as u": "`user` as u", + "user,user_detail": "`user`,`user_detail`", + "user u, user_detail ut": "`user` u,`user_detail` ut", + "user as u, user_detail as ut": "`user` as u,`user_detail` as ut", } for k, v := range array { - gtest.Assert(addTablePrefix(k, prefix), v) + gtest.Assert(doHandleTableName(k, prefix, "`", "`"), v) } }) gtest.Case(t, func() { prefix := "gf_" array := map[string]string{ - "user": "gf_user", - "user u": "gf_user u", - "user as u": "gf_user as u", - "user,user_detail": "gf_user,gf_user_detail", - "user u, user_detail ut": "gf_user u,gf_user_detail ut", - "user as u, user_detail as ut": "gf_user as u,gf_user_detail as ut", + "user": "`gf_user`", + "user u": "`gf_user` u", + "user as u": "`gf_user` as u", + "user,user_detail": "`gf_user`,`gf_user_detail`", + "user u, user_detail ut": "`gf_user` u,`gf_user_detail` ut", + "user as u, user_detail as ut": "`gf_user` as u,`gf_user_detail` as ut", } for k, v := range array { - gtest.Assert(addTablePrefix(k, prefix), v) + gtest.Assert(doHandleTableName(k, prefix, "`", "`"), v) } }) } diff --git a/database/gdb/gdb_unit_z_mysql_method_test.go b/database/gdb/gdb_unit_z_mysql_method_test.go index 721892786..b150c847b 100644 --- a/database/gdb/gdb_unit_z_mysql_method_test.go +++ b/database/gdb/gdb_unit_z_mysql_method_test.go @@ -8,6 +8,7 @@ package gdb_test import ( "fmt" + "github.com/gogf/gf/container/garray" "testing" "time" @@ -1052,6 +1053,109 @@ func Test_DB_TableField(t *testing.T) { gtest.Assert(result[0], data) } +func Test_DB_Prefix(t *testing.T) { + db := dbPrefix + name := fmt.Sprintf(`%s_%d`, TABLE, gtime.TimestampNano()) + table := PREFIX1 + name + createTableWithDb(db, table) + defer dropTable(table) + + gtest.Case(t, func() { + id := 10000 + result, err := db.Insert(name, g.Map{ + "id": id, + "passport": fmt.Sprintf(`user_%d`, id), + "password": fmt.Sprintf(`pass_%d`, id), + "nickname": fmt.Sprintf(`name_%d`, id), + "create_time": gtime.NewFromStr("2018-10-24 10:00:00").String(), + }) + gtest.Assert(err, nil) + + n, e := result.RowsAffected() + gtest.Assert(e, nil) + gtest.Assert(n, 1) + }) + + gtest.Case(t, func() { + id := 10000 + result, err := db.Replace(name, g.Map{ + "id": id, + "passport": fmt.Sprintf(`user_%d`, id), + "password": fmt.Sprintf(`pass_%d`, id), + "nickname": fmt.Sprintf(`name_%d`, id), + "create_time": gtime.NewFromStr("2018-10-24 10:00:01").String(), + }) + gtest.Assert(err, nil) + + n, e := result.RowsAffected() + gtest.Assert(e, nil) + gtest.Assert(n, 2) + }) + + gtest.Case(t, func() { + id := 10000 + result, err := db.Save(name, g.Map{ + "id": id, + "passport": fmt.Sprintf(`user_%d`, id), + "password": fmt.Sprintf(`pass_%d`, id), + "nickname": fmt.Sprintf(`name_%d`, id), + "create_time": gtime.NewFromStr("2018-10-24 10:00:02").String(), + }) + gtest.Assert(err, nil) + + n, e := result.RowsAffected() + gtest.Assert(e, nil) + gtest.Assert(n, 2) + }) + + gtest.Case(t, func() { + id := 10000 + result, err := db.Update(name, g.Map{ + "id": id, + "passport": fmt.Sprintf(`user_%d`, id), + "password": fmt.Sprintf(`pass_%d`, id), + "nickname": fmt.Sprintf(`name_%d`, id), + "create_time": gtime.NewFromStr("2018-10-24 10:00:03").String(), + }, "id=?", id) + gtest.Assert(err, nil) + + n, e := result.RowsAffected() + gtest.Assert(e, nil) + gtest.Assert(n, 1) + }) + + gtest.Case(t, func() { + id := 10000 + result, err := db.Delete(name, "id=?", id) + gtest.Assert(err, nil) + + n, e := result.RowsAffected() + gtest.Assert(e, nil) + gtest.Assert(n, 1) + }) + + gtest.Case(t, func() { + array := garray.New(true) + for i := 1; i <= INIT_DATA_SIZE; i++ { + array.Append(g.Map{ + "id": i, + "passport": fmt.Sprintf(`user_%d`, i), + "password": fmt.Sprintf(`pass_%d`, i), + "nickname": fmt.Sprintf(`name_%d`, i), + "create_time": gtime.NewFromStr("2018-10-24 10:00:00").String(), + }) + } + + result, err := db.BatchInsert(name, array.Slice()) + gtest.Assert(err, nil) + + n, e := result.RowsAffected() + gtest.Assert(e, nil) + gtest.Assert(n, INIT_DATA_SIZE) + }) + +} + func Test_Model_InnerJoin(t *testing.T) { gtest.Case(t, func() { table1 := createInitTable("user1")