mirror of
https://gitee.com/johng/gf
synced 2026-06-06 16:21:40 +08:00
add quotes for fields and table name for gdb
This commit is contained in:
@ -95,8 +95,10 @@ type DB interface {
|
||||
getCache() *gcache.Cache
|
||||
getChars() (charLeft string, charRight string)
|
||||
getDebug() bool
|
||||
quoteWord(s string) string
|
||||
setSchema(sqlDb *sql.DB, schema string) error
|
||||
filterFields(table string, data map[string]interface{}) map[string]interface{}
|
||||
formatWhere(where interface{}, args []interface{}) (newWhere string, newArgs []interface{})
|
||||
convertValue(fieldValue []byte, fieldType string) interface{}
|
||||
getTableFields(table string) (map[string]string, error)
|
||||
rowsToResult(rows *sql.Rows) (Result, error)
|
||||
|
||||
@ -8,10 +8,13 @@
|
||||
package gdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gogf/gf/g/text/gstr"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/g/container/gvar"
|
||||
@ -26,6 +29,11 @@ const (
|
||||
gDEFAULT_DEBUG_SQL_LENGTH = 1000
|
||||
)
|
||||
|
||||
var (
|
||||
// 用于可转义的单词的识别正则对象
|
||||
wordReg = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`)
|
||||
)
|
||||
|
||||
// 获取最近一条执行的sql
|
||||
func (bs *dbBase) GetLastSql() *Sql {
|
||||
if bs.sqls == nil {
|
||||
@ -311,6 +319,7 @@ func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option i
|
||||
var values []string
|
||||
var params []interface{}
|
||||
var dataMap Map
|
||||
table = bs.db.quoteWord(table)
|
||||
// 使用反射判断data数据类型,如果为slice类型,那么自动转为批量操作
|
||||
rv := reflect.ValueOf(data)
|
||||
kind := rv.Kind()
|
||||
@ -339,16 +348,16 @@ func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option i
|
||||
operation := getInsertOperationByOption(option)
|
||||
updateStr := ""
|
||||
if option == OPTION_SAVE {
|
||||
var updates []string
|
||||
for k, _ := range dataMap {
|
||||
updates = append(updates,
|
||||
fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
|
||||
charL, k, charR,
|
||||
charL, k, charR,
|
||||
),
|
||||
if len(updateStr) > 0 {
|
||||
updateStr += ","
|
||||
}
|
||||
updateStr += fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
|
||||
charL, k, charR,
|
||||
charL, k, charR,
|
||||
)
|
||||
}
|
||||
updateStr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
|
||||
updateStr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", updateStr)
|
||||
}
|
||||
if link == nil {
|
||||
if link, err = bs.db.Master(); err != nil {
|
||||
@ -381,6 +390,7 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
|
||||
var keys []string
|
||||
var values []string
|
||||
var params []interface{}
|
||||
table = bs.db.quoteWord(table)
|
||||
listMap := (List)(nil)
|
||||
switch v := list.(type) {
|
||||
case Result:
|
||||
@ -432,22 +442,22 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
|
||||
}
|
||||
batchResult := new(batchSqlResult)
|
||||
charL, charR := bs.db.getChars()
|
||||
keyStr := charL + strings.Join(keys, charL+","+charR) + charR
|
||||
keyStr := charL + strings.Join(keys, charR+","+charL) + charR
|
||||
valueHolderStr := "(" + strings.Join(holders, ",") + ")"
|
||||
// 操作判断
|
||||
operation := getInsertOperationByOption(option)
|
||||
updateStr := ""
|
||||
if option == OPTION_SAVE {
|
||||
var updates []string
|
||||
for _, k := range keys {
|
||||
updates = append(updates,
|
||||
fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
|
||||
charL, k, charR,
|
||||
charL, k, charR,
|
||||
),
|
||||
if len(updateStr) > 0 {
|
||||
updateStr += ","
|
||||
}
|
||||
updateStr += fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
|
||||
charL, k, charR,
|
||||
charL, k, charR,
|
||||
)
|
||||
}
|
||||
updateStr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
|
||||
updateStr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", updateStr)
|
||||
}
|
||||
// 构造批量写入数据格式(注意map的遍历是无序的)
|
||||
batchNum := gDEFAULT_BATCH_NUM
|
||||
@ -499,7 +509,7 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
|
||||
// CURD操作:数据更新,统一采用sql预处理。
|
||||
// data参数支持string/map/struct/*struct类型。
|
||||
func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
newWhere, newArgs := formatWhere(condition, args)
|
||||
newWhere, newArgs := bs.db.formatWhere(condition, args)
|
||||
if newWhere != "" {
|
||||
newWhere = " WHERE " + newWhere
|
||||
}
|
||||
@ -509,8 +519,8 @@ func (bs *dbBase) Update(table string, data interface{}, condition interface{},
|
||||
// CURD操作:数据更新,统一采用sql预处理。
|
||||
// data参数支持string/map/struct/*struct类型类型。
|
||||
func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) {
|
||||
table = bs.db.quoteWord(table)
|
||||
updates := ""
|
||||
charL, charR := bs.db.getChars()
|
||||
// 使用反射进行类型判断
|
||||
rv := reflect.ValueOf(data)
|
||||
kind := rv.Kind()
|
||||
@ -525,7 +535,7 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio
|
||||
case reflect.Struct:
|
||||
var fields []string
|
||||
for k, v := range structToMap(data) {
|
||||
fields = append(fields, fmt.Sprintf("%s%s%s=?", charL, k, charR))
|
||||
fields = append(fields, bs.db.quoteWord(k)+"=?")
|
||||
params = append(params, convertParam(v))
|
||||
}
|
||||
updates = strings.Join(fields, ",")
|
||||
@ -546,7 +556,7 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio
|
||||
|
||||
// CURD操作:删除数据
|
||||
func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
|
||||
newWhere, newArgs := formatWhere(condition, args)
|
||||
newWhere, newArgs := bs.db.formatWhere(condition, args)
|
||||
if newWhere != "" {
|
||||
newWhere = " WHERE " + newWhere
|
||||
}
|
||||
@ -560,6 +570,7 @@ func (bs *dbBase) doDelete(link dbLink, table string, condition string, args ...
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
table = bs.db.quoteWord(table)
|
||||
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...)
|
||||
}
|
||||
|
||||
@ -617,6 +628,98 @@ func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) {
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// 格式化Where查询条件。
|
||||
func (bs *dbBase) formatWhere(where interface{}, args []interface{}) (newWhere string, newArgs []interface{}) {
|
||||
// 条件字符串处理
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
// 使用反射进行类型判断
|
||||
rv := reflect.ValueOf(where)
|
||||
kind := rv.Kind()
|
||||
if kind == reflect.Ptr {
|
||||
rv = rv.Elem()
|
||||
kind = rv.Kind()
|
||||
}
|
||||
switch kind {
|
||||
// map/struct类型
|
||||
case reflect.Map:
|
||||
fallthrough
|
||||
case reflect.Struct:
|
||||
for key, value := range structToMap(where) {
|
||||
// 字段安全符号判断
|
||||
key = bs.db.quoteWord(key)
|
||||
if buffer.Len() > 0 {
|
||||
buffer.WriteString(" AND ")
|
||||
}
|
||||
// 支持slice键值/属性,如果只有一个?占位符号,那么作为IN查询,否则打散作为多个查询参数
|
||||
rv := reflect.ValueOf(value)
|
||||
switch rv.Kind() {
|
||||
case reflect.Slice:
|
||||
fallthrough
|
||||
case reflect.Array:
|
||||
count := gstr.Count(key, "?")
|
||||
if count == 0 {
|
||||
buffer.WriteString(key + " IN(?)")
|
||||
newArgs = append(newArgs, value)
|
||||
} else if count != rv.Len() {
|
||||
buffer.WriteString(key)
|
||||
newArgs = append(newArgs, value)
|
||||
} else {
|
||||
buffer.WriteString(key)
|
||||
// 如果键名/属性名称中带有多个?占位符号,那么将参数打散
|
||||
newArgs = append(newArgs, gconv.Interfaces(value)...)
|
||||
}
|
||||
default:
|
||||
if value == nil {
|
||||
buffer.WriteString(key)
|
||||
} else {
|
||||
// 支持key带操作符号
|
||||
if gstr.Pos(key, "?") == -1 {
|
||||
if gstr.Pos(key, "<") == -1 && gstr.Pos(key, ">") == -1 && gstr.Pos(key, "=") == -1 {
|
||||
buffer.WriteString(key + "=?")
|
||||
} else {
|
||||
buffer.WriteString(key + "?")
|
||||
}
|
||||
} else {
|
||||
buffer.WriteString(key)
|
||||
}
|
||||
newArgs = append(newArgs, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
buffer.WriteString(gconv.String(where))
|
||||
}
|
||||
// 没有任何条件查询参数,直接返回
|
||||
if buffer.Len() == 0 {
|
||||
return "", args
|
||||
}
|
||||
newArgs = append(newArgs, args...)
|
||||
newWhere = buffer.String()
|
||||
// 查询条件参数处理,主要处理slice参数类型
|
||||
if len(newArgs) > 0 {
|
||||
// 支持例如 Where/And/Or("uid", 1) 这种格式
|
||||
if gstr.Pos(newWhere, "?") == -1 {
|
||||
if gstr.Pos(newWhere, "<") == -1 && gstr.Pos(newWhere, ">") == -1 && gstr.Pos(newWhere, "=") == -1 {
|
||||
newWhere += "=?"
|
||||
} else {
|
||||
newWhere += "?"
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 使用关键字操作符转义给定字符串。
|
||||
// 如果给定的字符串不为单词,那么不转义,直接返回该字符串。
|
||||
func (bs *dbBase) quoteWord(s string) string {
|
||||
charLeft, charRight := bs.db.getChars()
|
||||
if wordReg.MatchString(s) && !gstr.ContainsAny(s, charLeft+charRight) {
|
||||
return charLeft + s + charRight
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// 动态切换数据库
|
||||
func (bs *dbBase) setSchema(sqlDb *sql.DB, schema string) error {
|
||||
_, err := sqlDb.Exec("USE " + schema)
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
package gdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -74,86 +73,6 @@ func formatQuery(query string, args []interface{}) (newQuery string, newArgs []i
|
||||
return
|
||||
}
|
||||
|
||||
// 格式化Where查询条件。
|
||||
func formatWhere(where interface{}, args []interface{}) (newWhere string, newArgs []interface{}) {
|
||||
// 条件字符串处理
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
// 使用反射进行类型判断
|
||||
rv := reflect.ValueOf(where)
|
||||
kind := rv.Kind()
|
||||
if kind == reflect.Ptr {
|
||||
rv = rv.Elem()
|
||||
kind = rv.Kind()
|
||||
}
|
||||
switch kind {
|
||||
// map/struct类型
|
||||
case reflect.Map:
|
||||
fallthrough
|
||||
case reflect.Struct:
|
||||
for key, value := range structToMap(where) {
|
||||
if buffer.Len() > 0 {
|
||||
buffer.WriteString(" AND ")
|
||||
}
|
||||
// 支持slice键值/属性,如果只有一个?占位符号,那么作为IN查询,否则打散作为多个查询参数
|
||||
rv := reflect.ValueOf(value)
|
||||
switch rv.Kind() {
|
||||
case reflect.Slice:
|
||||
fallthrough
|
||||
case reflect.Array:
|
||||
count := gstr.Count(key, "?")
|
||||
if count == 0 {
|
||||
buffer.WriteString(key + " IN(?)")
|
||||
newArgs = append(newArgs, value)
|
||||
} else if count != rv.Len() {
|
||||
buffer.WriteString(key)
|
||||
newArgs = append(newArgs, value)
|
||||
} else {
|
||||
buffer.WriteString(key)
|
||||
// 如果键名/属性名称中带有多个?占位符号,那么将参数打散
|
||||
newArgs = append(newArgs, gconv.Interfaces(value)...)
|
||||
}
|
||||
default:
|
||||
if value == nil {
|
||||
buffer.WriteString(key)
|
||||
} else {
|
||||
// 支持key带操作符号
|
||||
if gstr.Pos(key, "?") == -1 {
|
||||
if gstr.Pos(key, "<") == -1 && gstr.Pos(key, ">") == -1 && gstr.Pos(key, "=") == -1 {
|
||||
buffer.WriteString(key + "=?")
|
||||
} else {
|
||||
buffer.WriteString(key + "?")
|
||||
}
|
||||
} else {
|
||||
buffer.WriteString(key)
|
||||
}
|
||||
newArgs = append(newArgs, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
buffer.WriteString(gconv.String(where))
|
||||
}
|
||||
// 没有任何条件查询参数,直接返回
|
||||
if buffer.Len() == 0 {
|
||||
return "", args
|
||||
}
|
||||
newArgs = append(newArgs, args...)
|
||||
newWhere = buffer.String()
|
||||
// 查询条件参数处理,主要处理slice参数类型
|
||||
if len(newArgs) > 0 {
|
||||
// 支持例如 Where/And/Or("uid", 1) 这种格式
|
||||
if gstr.Pos(newWhere, "?") == -1 {
|
||||
if gstr.Pos(newWhere, "<") == -1 && gstr.Pos(newWhere, ">") == -1 && gstr.Pos(newWhere, "=") == -1 {
|
||||
newWhere += "=?"
|
||||
} else {
|
||||
newWhere += "?"
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 将预处理参数转换为底层数据库引擎支持的格式。
|
||||
// 主要是判断参数是否为复杂数据类型,如果是,那么转换为基础类型。
|
||||
func convertParam(value interface{}) interface{} {
|
||||
|
||||
@ -43,7 +43,7 @@ func (bs *dbBase) Table(tables string) *Model {
|
||||
return &Model{
|
||||
db: bs.db,
|
||||
tablesInit: tables,
|
||||
tables: tables,
|
||||
tables: bs.db.quoteWord(tables),
|
||||
fields: "*",
|
||||
start: -1,
|
||||
offset: -1,
|
||||
@ -62,7 +62,7 @@ func (tx *TX) Table(tables string) *Model {
|
||||
db: tx.db,
|
||||
tx: tx,
|
||||
tablesInit: tables,
|
||||
tables: tables,
|
||||
tables: tx.db.quoteWord(tables),
|
||||
fields: "*",
|
||||
start: -1,
|
||||
offset: -1,
|
||||
@ -154,7 +154,7 @@ func (md *Model) Where(where interface{}, args ...interface{}) *Model {
|
||||
if model.where != "" {
|
||||
return md.And(where, args...)
|
||||
}
|
||||
newWhere, newArgs := formatWhere(where, args)
|
||||
newWhere, newArgs := md.db.formatWhere(where, args)
|
||||
model.where = newWhere
|
||||
model.whereArgs = newArgs
|
||||
return model
|
||||
@ -163,7 +163,7 @@ func (md *Model) Where(where interface{}, args ...interface{}) *Model {
|
||||
// 链式操作,添加AND条件到Where中
|
||||
func (md *Model) And(where interface{}, args ...interface{}) *Model {
|
||||
model := md.getModel()
|
||||
newWhere, newArgs := formatWhere(where, args)
|
||||
newWhere, newArgs := md.db.formatWhere(where, args)
|
||||
if len(model.where) > 0 && model.where[0] == '(' {
|
||||
model.where = fmt.Sprintf(`%s AND (%s)`, model.where, newWhere)
|
||||
} else {
|
||||
@ -176,7 +176,7 @@ func (md *Model) And(where interface{}, args ...interface{}) *Model {
|
||||
// 链式操作,添加OR条件到Where中
|
||||
func (md *Model) Or(where interface{}, args ...interface{}) *Model {
|
||||
model := md.getModel()
|
||||
newWhere, newArgs := formatWhere(where, args)
|
||||
newWhere, newArgs := md.db.formatWhere(where, args)
|
||||
if len(model.where) > 0 && model.where[0] == '(' {
|
||||
model.where = fmt.Sprintf(`%s OR (%s)`, model.where, newWhere)
|
||||
} else {
|
||||
@ -474,7 +474,7 @@ func (md *Model) Select() (Result, error) {
|
||||
|
||||
// 链式操作,查询所有记录
|
||||
func (md *Model) All() (Result, error) {
|
||||
return md.getAll(fmt.Sprintf("SELECT %s FROM %s %s", md.fields, md.tables, md.getConditionSql()), md.whereArgs...)
|
||||
return md.getAll(fmt.Sprintf("SELECT %s FROM %s%s", md.fields, md.tables, md.getConditionSql()), md.whereArgs...)
|
||||
}
|
||||
|
||||
// 链式操作,查询单条记录
|
||||
|
||||
@ -96,8 +96,7 @@ func (bs *dbBase) getTableFields(table string) (fields map[string]string, err er
|
||||
// 缓存不存在时会查询数据表结构,缓存后不过期,直至程序重启(重新部署)
|
||||
v := bs.cache.GetOrSetFunc("table_fields_"+table, func() interface{} {
|
||||
result := (Result)(nil)
|
||||
charL, charR := bs.db.getChars()
|
||||
result, err = bs.GetAll(fmt.Sprintf(`SHOW COLUMNS FROM %s%s%s`, charL, table, charR))
|
||||
result, err = bs.GetAll(fmt.Sprintf(`SHOW COLUMNS FROM %s`, bs.db.quoteWord(table)))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user