fix issue in prefix feature for method operations of gdb

This commit is contained in:
John
2019-12-23 23:14:54 +08:00
parent 5db8851213
commit 597f7468e9
9 changed files with 180 additions and 72 deletions

View File

@ -1,7 +1,7 @@
# MySQL数据库配置
[database]
debug = true
# debug = true
link = "mysql:root:12345678@tcp(127.0.0.1:3306)/test?parseTime=true&loc=Local"
#[database]

View File

@ -1,8 +1,6 @@
package main
import (
"fmt"
"github.com/gogf/gf/database/gdb"
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/os/glog"
@ -23,7 +21,7 @@ func main() {
if err != nil {
panic(err)
}
db.SetDebug(true)
//db.SetDebug(false)
glog.SetPath("/tmp")
@ -36,7 +34,4 @@ func main() {
db.Table("user").Data(g.Map{"name": "smith"}).Where("uid=?", 1).Save()
db.PrintQueriedSqls()
fmt.Println(db.GetLastSql())
}

View File

@ -6,6 +6,7 @@ import (
func main() {
db := g.DB()
// 执行3条SQL查询
for i := 1; i <= 3; i++ {
db.Table("user").Where("id=?", i).One()

View File

@ -22,17 +22,18 @@ import (
"github.com/gogf/gf/util/grand"
)
// 数据库操作接口
// DB is the interface for ORM operations.
type DB interface {
// 建立数据库连接方法(开发者一般不需要直接调用)
// Open creates a raw connection object for database with given node configuration.
// Note that it is not recommended using the this function manually.
Open(config *ConfigNode) (*sql.DB, error)
// SQL操作方法 API
// Query APIs.
Query(query string, args ...interface{}) (*sql.Rows, error)
Exec(sql string, args ...interface{}) (sql.Result, error)
Prepare(sql string, execOnMaster ...bool) (*sql.Stmt, error)
// 内部实现API的方法(不同数据库可覆盖这些方法实现自定义的操作)
// Internal APIs for CURD, which can be overwrote for custom CURD implements.
doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error)
doGetAll(link dbLink, query string, args ...interface{}) (result Result, err error)
doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error)
@ -42,7 +43,7 @@ type DB interface {
doUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error)
doDelete(link dbLink, table string, condition string, args ...interface{}) (result sql.Result, err error)
// 数据库查询
// Query APIs for convenience purpose.
GetAll(query string, args ...interface{}) (Result, error)
GetOne(query string, args ...interface{}) (Record, error)
GetValue(query string, args ...interface{}) (Value, error)
@ -51,36 +52,33 @@ type DB interface {
GetStructs(objPointerSlice interface{}, query string, args ...interface{}) error
GetScan(objPointer interface{}, query string, args ...interface{}) error
// 创建底层数据库master/slave链接对象
// Master/Slave support.
Master() (*sql.DB, error)
Slave() (*sql.DB, error)
// Ping
// Ping.
PingMaster() error
PingSlave() error
// 开启事务操作
// Transaction.
Begin() (*TX, error)
// 数据表插入/更新/保存操作
Insert(table string, data interface{}, batch ...int) (sql.Result, error)
Replace(table string, data interface{}, batch ...int) (sql.Result, error)
Save(table string, data interface{}, batch ...int) (sql.Result, error)
// 数据表插入/更新/保存操作(批量)
BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error)
BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error)
BatchSave(table string, list interface{}, batch ...int) (sql.Result, error)
// 数据修改/删除
Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error)
Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error)
// 创建链式操作对象
// Create model.
From(tables string) *Model
Table(tables string) *Model
// 设置管理
// Configuration methods.
SetDebug(debug bool)
SetSchema(schema string)
SetLogger(logger *glog.Logger)
@ -91,13 +89,14 @@ type DB interface {
Tables() (tables []string, err error)
TableFields(table string) (map[string]*TableField, error)
// 内部方法接口
// Internal methods.
getCache() *gcache.Cache
getChars() (charLeft string, charRight string)
getDebug() bool
getPrefix() string
quoteWord(s string) string
quoteString(s string) string
handleTableName(table string) string
doSetSchema(sqlDb *sql.DB, schema string) error
filterFields(table string, data map[string]interface{}) map[string]interface{}
convertValue(fieldValue []byte, fieldType string) interface{}

View File

@ -56,9 +56,9 @@ func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows
query, args = formatQuery(query, args)
query = bs.db.handleSqlBeforeExec(query)
if bs.db.getDebug() {
mTime1 := gtime.Millisecond()
mTime1 := gtime.TimestampMicro()
rows, err = link.Query(query, args...)
mTime2 := gtime.Millisecond()
mTime2 := gtime.TimestampMicro()
s := &Sql{
Sql: query,
Args: args,
@ -302,7 +302,7 @@ func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option i
var values []string
var params []interface{}
var dataMap Map
table = bs.db.quoteWord(table)
table = bs.db.handleTableName(table)
// 使用反射判断data数据类型如果为slice类型那么自动转为批量操作
rv := reflect.ValueOf(data)
kind := rv.Kind()
@ -371,7 +371,7 @@ func (bs *dbBase) BatchSave(table string, list interface{}, batch ...int) (sql.R
func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) {
var keys, values []string
var params []interface{}
table = bs.db.quoteWord(table)
table = bs.db.handleTableName(table)
listMap := (List)(nil)
switch v := list.(type) {
case Result:
@ -491,7 +491,7 @@ func (bs *dbBase) Update(table string, data interface{}, condition interface{},
// CURD操作:数据更新统一采用sql预处理。
// data参数支持string/map/struct/*struct类型类型。
func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) {
table = bs.db.quoteWord(table)
table = bs.db.handleTableName(table)
updates := ""
// 使用反射进行类型判断
rv := reflect.ValueOf(data)
@ -543,7 +543,7 @@ func (bs *dbBase) doDelete(link dbLink, table string, condition string, args ...
return nil, err
}
}
table = bs.db.quoteWord(table)
table = bs.db.handleTableName(table)
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...)
}
@ -605,6 +605,17 @@ func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) {
return records, nil
}
// handleTableName adds prefix string and quote chars for the table. It handles table string like:
// "user", "user u", "user,user_detail", "user u, user_detail ut", "user as u, user_detail as ut".
//
// Note that, this will automatically checks the table prefix whether already added, if true it does
// nothing to the table name, or else adds the prefix to the table name.
func (bs *dbBase) handleTableName(table string) string {
charLeft, charRight := bs.db.getChars()
prefix := bs.db.getPrefix()
return doHandleTableName(table, prefix, charLeft, charRight)
}
// quoteWord checks given string <s> a word, if true quotes it with security chars of the database
// and returns the quoted string; or else return <s> without any change.
func (bs *dbBase) quoteWord(s string) string {

View File

@ -51,8 +51,31 @@ var (
quoteWordReg = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`)
)
// handleTableName adds prefix string and quote chars for the table. It handles table string like:
// "user", "user u", "user,user_detail", "user u, user_detail ut", "user as u, user_detail as ut".
//
// Note that, this will automatically checks the table prefix whether already added, if true it does
// nothing to the table name, or else adds the prefix to the table name.
func doHandleTableName(table, prefix, charLeft, charRight string) string {
array1 := gstr.SplitAndTrim(table, ",")
for k1, v1 := range array1 {
array2 := gstr.SplitAndTrim(v1, " ")
// Trim the security chars.
array2[0] = gstr.TrimLeftStr(array2[0], charLeft)
array2[0] = gstr.TrimRightStr(array2[0], charRight)
// If the table name already has the prefix, skips the prefix adding.
if len(array2[0]) <= len(prefix) || array2[0][:len(prefix)] != prefix {
array2[0] = prefix + array2[0]
}
// Add the security chars.
array2[0] = doQuoteWord(array2[0], charLeft, charRight)
array1[k1] = gstr.Join(array2, " ")
}
return gstr.Join(array1, ",")
}
// doQuoteWord checks given string <s> a word, if true quotes it with <charLeft> and <charRight>
// and returns the quoted string; or else return <s> without any change.
// and returns the quoted string; or else returns <s> without any change.
func doQuoteWord(s, charLeft, charRight string) string {
if quoteWordReg.MatchString(s) && !gstr.ContainsAny(s, charLeft+charRight) {
return charLeft + s + charRight
@ -78,23 +101,6 @@ func doQuoteString(s, charLeft, charRight string) string {
return gstr.Join(array1, ",")
}
// addTablePrefix adds prefix string to the table. It handles table string like:
// "user", "user u", "user,user_detail", "user u, user_detail ut", "user as u, user_detail as ut".
//
// Note that, this should be used before any quoting function calls.
func addTablePrefix(table, prefix string) string {
if prefix == "" {
return table
}
array1 := gstr.SplitAndTrim(table, ",")
for k1, v1 := range array1 {
array2 := gstr.SplitAndTrim(v1, " ")
array2[0] = prefix + array2[0]
array1[k1] = gstr.Join(array2, " ")
}
return gstr.Join(array1, ",")
}
// GetWhereConditionOfStruct returns the where condition sql and arguments by given struct pointer.
// This function automatically retrieves primary or unique field and its attribute value as condition.
func GetWhereConditionOfStruct(pointer interface{}) (where string, args []interface{}) {

View File

@ -67,8 +67,7 @@ const (
// The parameter <tables> can be more than one table names, like :
// "user", "user u", "user, user_detail", "user u, user_detail ud"
func (bs *dbBase) Table(table string) *Model {
table = addTablePrefix(table, bs.db.getPrefix())
table = bs.db.quoteString(table)
table = bs.db.handleTableName(table)
return &Model{
db: bs.db,
tablesInit: table,
@ -90,8 +89,7 @@ func (bs *dbBase) From(tables string) *Model {
// Table acts like dbBase.Table except it operates on transaction.
// See dbBase.Table.
func (tx *TX) Table(table string) *Model {
table = addTablePrefix(table, tx.db.getPrefix())
table = tx.db.quoteString(table)
table = tx.db.handleTableName(table)
return &Model{
db: tx.db,
tx: tx,
@ -186,27 +184,21 @@ func (m *Model) getModel() *Model {
// LeftJoin does "LEFT JOIN ... ON ..." statement on the model.
func (m *Model) LeftJoin(table string, on string) *Model {
model := m.getModel()
table = addTablePrefix(table, m.db.getPrefix())
table = m.db.quoteString(table)
model.tables += fmt.Sprintf(" LEFT JOIN %s ON (%s)", table, on)
model.tables += fmt.Sprintf(" LEFT JOIN %s ON (%s)", m.db.handleTableName(table), on)
return model
}
// RightJoin does "RIGHT JOIN ... ON ..." statement on the model.
func (m *Model) RightJoin(table string, on string) *Model {
model := m.getModel()
table = addTablePrefix(table, m.db.getPrefix())
table = m.db.quoteString(table)
model.tables += fmt.Sprintf(" RIGHT JOIN %s ON (%s)", table, on)
model.tables += fmt.Sprintf(" RIGHT JOIN %s ON (%s)", m.db.handleTableName(table), on)
return model
}
// InnerJoin does "INNER JOIN ... ON ..." statement on the model.
func (m *Model) InnerJoin(table string, on string) *Model {
model := m.getModel()
table = addTablePrefix(table, m.db.getPrefix())
table = m.db.quoteString(table)
model.tables += fmt.Sprintf(" INNER JOIN %s ON (%s)", table, on)
model.tables += fmt.Sprintf(" INNER JOIN %s ON (%s)", m.db.handleTableName(table), on)
return model
}

View File

@ -49,29 +49,29 @@ func Test_Func_addTablePrefix(t *testing.T) {
gtest.Case(t, func() {
prefix := ""
array := map[string]string{
"user": "user",
"user u": "user u",
"user as u": "user as u",
"user,user_detail": "user,user_detail",
"user u, user_detail ut": "user u, user_detail ut",
"user as u, user_detail as ut": "user as u, user_detail as ut",
"user": "`user`",
"user u": "`user` u",
"user as u": "`user` as u",
"user,user_detail": "`user`,`user_detail`",
"user u, user_detail ut": "`user` u,`user_detail` ut",
"user as u, user_detail as ut": "`user` as u,`user_detail` as ut",
}
for k, v := range array {
gtest.Assert(addTablePrefix(k, prefix), v)
gtest.Assert(doHandleTableName(k, prefix, "`", "`"), v)
}
})
gtest.Case(t, func() {
prefix := "gf_"
array := map[string]string{
"user": "gf_user",
"user u": "gf_user u",
"user as u": "gf_user as u",
"user,user_detail": "gf_user,gf_user_detail",
"user u, user_detail ut": "gf_user u,gf_user_detail ut",
"user as u, user_detail as ut": "gf_user as u,gf_user_detail as ut",
"user": "`gf_user`",
"user u": "`gf_user` u",
"user as u": "`gf_user` as u",
"user,user_detail": "`gf_user`,`gf_user_detail`",
"user u, user_detail ut": "`gf_user` u,`gf_user_detail` ut",
"user as u, user_detail as ut": "`gf_user` as u,`gf_user_detail` as ut",
}
for k, v := range array {
gtest.Assert(addTablePrefix(k, prefix), v)
gtest.Assert(doHandleTableName(k, prefix, "`", "`"), v)
}
})
}

View File

@ -8,6 +8,7 @@ package gdb_test
import (
"fmt"
"github.com/gogf/gf/container/garray"
"testing"
"time"
@ -1052,6 +1053,109 @@ func Test_DB_TableField(t *testing.T) {
gtest.Assert(result[0], data)
}
func Test_DB_Prefix(t *testing.T) {
db := dbPrefix
name := fmt.Sprintf(`%s_%d`, TABLE, gtime.TimestampNano())
table := PREFIX1 + name
createTableWithDb(db, table)
defer dropTable(table)
gtest.Case(t, func() {
id := 10000
result, err := db.Insert(name, g.Map{
"id": id,
"passport": fmt.Sprintf(`user_%d`, id),
"password": fmt.Sprintf(`pass_%d`, id),
"nickname": fmt.Sprintf(`name_%d`, id),
"create_time": gtime.NewFromStr("2018-10-24 10:00:00").String(),
})
gtest.Assert(err, nil)
n, e := result.RowsAffected()
gtest.Assert(e, nil)
gtest.Assert(n, 1)
})
gtest.Case(t, func() {
id := 10000
result, err := db.Replace(name, g.Map{
"id": id,
"passport": fmt.Sprintf(`user_%d`, id),
"password": fmt.Sprintf(`pass_%d`, id),
"nickname": fmt.Sprintf(`name_%d`, id),
"create_time": gtime.NewFromStr("2018-10-24 10:00:01").String(),
})
gtest.Assert(err, nil)
n, e := result.RowsAffected()
gtest.Assert(e, nil)
gtest.Assert(n, 2)
})
gtest.Case(t, func() {
id := 10000
result, err := db.Save(name, g.Map{
"id": id,
"passport": fmt.Sprintf(`user_%d`, id),
"password": fmt.Sprintf(`pass_%d`, id),
"nickname": fmt.Sprintf(`name_%d`, id),
"create_time": gtime.NewFromStr("2018-10-24 10:00:02").String(),
})
gtest.Assert(err, nil)
n, e := result.RowsAffected()
gtest.Assert(e, nil)
gtest.Assert(n, 2)
})
gtest.Case(t, func() {
id := 10000
result, err := db.Update(name, g.Map{
"id": id,
"passport": fmt.Sprintf(`user_%d`, id),
"password": fmt.Sprintf(`pass_%d`, id),
"nickname": fmt.Sprintf(`name_%d`, id),
"create_time": gtime.NewFromStr("2018-10-24 10:00:03").String(),
}, "id=?", id)
gtest.Assert(err, nil)
n, e := result.RowsAffected()
gtest.Assert(e, nil)
gtest.Assert(n, 1)
})
gtest.Case(t, func() {
id := 10000
result, err := db.Delete(name, "id=?", id)
gtest.Assert(err, nil)
n, e := result.RowsAffected()
gtest.Assert(e, nil)
gtest.Assert(n, 1)
})
gtest.Case(t, func() {
array := garray.New(true)
for i := 1; i <= INIT_DATA_SIZE; i++ {
array.Append(g.Map{
"id": i,
"passport": fmt.Sprintf(`user_%d`, i),
"password": fmt.Sprintf(`pass_%d`, i),
"nickname": fmt.Sprintf(`name_%d`, i),
"create_time": gtime.NewFromStr("2018-10-24 10:00:00").String(),
})
}
result, err := db.BatchInsert(name, array.Slice())
gtest.Assert(err, nil)
n, e := result.RowsAffected()
gtest.Assert(e, nil)
gtest.Assert(n, INIT_DATA_SIZE)
})
}
func Test_Model_InnerJoin(t *testing.T) {
gtest.Case(t, func() {
table1 := createInitTable("user1")