add quotes for fields and table name for gdb

This commit is contained in:
john
2019-07-22 23:48:39 +08:00
parent 697dbdc604
commit 6a76725d64
5 changed files with 131 additions and 108 deletions

View File

@ -95,8 +95,10 @@ type DB interface {
getCache() *gcache.Cache
getChars() (charLeft string, charRight string)
getDebug() bool
quoteWord(s string) string
setSchema(sqlDb *sql.DB, schema string) error
filterFields(table string, data map[string]interface{}) map[string]interface{}
formatWhere(where interface{}, args []interface{}) (newWhere string, newArgs []interface{})
convertValue(fieldValue []byte, fieldType string) interface{}
getTableFields(table string) (map[string]string, error)
rowsToResult(rows *sql.Rows) (Result, error)

View File

@ -8,10 +8,13 @@
package gdb
import (
"bytes"
"database/sql"
"errors"
"fmt"
"github.com/gogf/gf/g/text/gstr"
"reflect"
"regexp"
"strings"
"github.com/gogf/gf/g/container/gvar"
@ -26,6 +29,11 @@ const (
gDEFAULT_DEBUG_SQL_LENGTH = 1000
)
var (
// 用于可转义的单词的识别正则对象
wordReg = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`)
)
// 获取最近一条执行的sql
func (bs *dbBase) GetLastSql() *Sql {
if bs.sqls == nil {
@ -311,6 +319,7 @@ func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option i
var values []string
var params []interface{}
var dataMap Map
table = bs.db.quoteWord(table)
// 使用反射判断data数据类型如果为slice类型那么自动转为批量操作
rv := reflect.ValueOf(data)
kind := rv.Kind()
@ -339,16 +348,16 @@ func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option i
operation := getInsertOperationByOption(option)
updateStr := ""
if option == OPTION_SAVE {
var updates []string
for k, _ := range dataMap {
updates = append(updates,
fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
charL, k, charR,
charL, k, charR,
),
if len(updateStr) > 0 {
updateStr += ","
}
updateStr += fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
charL, k, charR,
charL, k, charR,
)
}
updateStr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
updateStr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", updateStr)
}
if link == nil {
if link, err = bs.db.Master(); err != nil {
@ -381,6 +390,7 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
var keys []string
var values []string
var params []interface{}
table = bs.db.quoteWord(table)
listMap := (List)(nil)
switch v := list.(type) {
case Result:
@ -432,22 +442,22 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
}
batchResult := new(batchSqlResult)
charL, charR := bs.db.getChars()
keyStr := charL + strings.Join(keys, charL+","+charR) + charR
keyStr := charL + strings.Join(keys, charR+","+charL) + charR
valueHolderStr := "(" + strings.Join(holders, ",") + ")"
// 操作判断
operation := getInsertOperationByOption(option)
updateStr := ""
if option == OPTION_SAVE {
var updates []string
for _, k := range keys {
updates = append(updates,
fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
charL, k, charR,
charL, k, charR,
),
if len(updateStr) > 0 {
updateStr += ","
}
updateStr += fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
charL, k, charR,
charL, k, charR,
)
}
updateStr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
updateStr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", updateStr)
}
// 构造批量写入数据格式(注意map的遍历是无序的)
batchNum := gDEFAULT_BATCH_NUM
@ -499,7 +509,7 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
// CURD操作:数据更新统一采用sql预处理。
// data参数支持string/map/struct/*struct类型。
func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
newWhere, newArgs := formatWhere(condition, args)
newWhere, newArgs := bs.db.formatWhere(condition, args)
if newWhere != "" {
newWhere = " WHERE " + newWhere
}
@ -509,8 +519,8 @@ func (bs *dbBase) Update(table string, data interface{}, condition interface{},
// CURD操作:数据更新统一采用sql预处理。
// data参数支持string/map/struct/*struct类型类型。
func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) {
table = bs.db.quoteWord(table)
updates := ""
charL, charR := bs.db.getChars()
// 使用反射进行类型判断
rv := reflect.ValueOf(data)
kind := rv.Kind()
@ -525,7 +535,7 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio
case reflect.Struct:
var fields []string
for k, v := range structToMap(data) {
fields = append(fields, fmt.Sprintf("%s%s%s=?", charL, k, charR))
fields = append(fields, bs.db.quoteWord(k)+"=?")
params = append(params, convertParam(v))
}
updates = strings.Join(fields, ",")
@ -546,7 +556,7 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio
// CURD操作:删除数据
func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
newWhere, newArgs := formatWhere(condition, args)
newWhere, newArgs := bs.db.formatWhere(condition, args)
if newWhere != "" {
newWhere = " WHERE " + newWhere
}
@ -560,6 +570,7 @@ func (bs *dbBase) doDelete(link dbLink, table string, condition string, args ...
return nil, err
}
}
table = bs.db.quoteWord(table)
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...)
}
@ -617,6 +628,98 @@ func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) {
return records, nil
}
// 格式化Where查询条件。
func (bs *dbBase) formatWhere(where interface{}, args []interface{}) (newWhere string, newArgs []interface{}) {
// 条件字符串处理
buffer := bytes.NewBuffer(nil)
// 使用反射进行类型判断
rv := reflect.ValueOf(where)
kind := rv.Kind()
if kind == reflect.Ptr {
rv = rv.Elem()
kind = rv.Kind()
}
switch kind {
// map/struct类型
case reflect.Map:
fallthrough
case reflect.Struct:
for key, value := range structToMap(where) {
// 字段安全符号判断
key = bs.db.quoteWord(key)
if buffer.Len() > 0 {
buffer.WriteString(" AND ")
}
// 支持slice键值/属性,如果只有一个?占位符号那么作为IN查询否则打散作为多个查询参数
rv := reflect.ValueOf(value)
switch rv.Kind() {
case reflect.Slice:
fallthrough
case reflect.Array:
count := gstr.Count(key, "?")
if count == 0 {
buffer.WriteString(key + " IN(?)")
newArgs = append(newArgs, value)
} else if count != rv.Len() {
buffer.WriteString(key)
newArgs = append(newArgs, value)
} else {
buffer.WriteString(key)
// 如果键名/属性名称中带有多个?占位符号,那么将参数打散
newArgs = append(newArgs, gconv.Interfaces(value)...)
}
default:
if value == nil {
buffer.WriteString(key)
} else {
// 支持key带操作符号
if gstr.Pos(key, "?") == -1 {
if gstr.Pos(key, "<") == -1 && gstr.Pos(key, ">") == -1 && gstr.Pos(key, "=") == -1 {
buffer.WriteString(key + "=?")
} else {
buffer.WriteString(key + "?")
}
} else {
buffer.WriteString(key)
}
newArgs = append(newArgs, value)
}
}
}
default:
buffer.WriteString(gconv.String(where))
}
// 没有任何条件查询参数,直接返回
if buffer.Len() == 0 {
return "", args
}
newArgs = append(newArgs, args...)
newWhere = buffer.String()
// 查询条件参数处理主要处理slice参数类型
if len(newArgs) > 0 {
// 支持例如 Where/And/Or("uid", 1) 这种格式
if gstr.Pos(newWhere, "?") == -1 {
if gstr.Pos(newWhere, "<") == -1 && gstr.Pos(newWhere, ">") == -1 && gstr.Pos(newWhere, "=") == -1 {
newWhere += "=?"
} else {
newWhere += "?"
}
}
}
return
}
// 使用关键字操作符转义给定字符串。
// 如果给定的字符串不为单词,那么不转义,直接返回该字符串。
func (bs *dbBase) quoteWord(s string) string {
charLeft, charRight := bs.db.getChars()
if wordReg.MatchString(s) && !gstr.ContainsAny(s, charLeft+charRight) {
return charLeft + s + charRight
}
return s
}
// 动态切换数据库
func (bs *dbBase) setSchema(sqlDb *sql.DB, schema string) error {
_, err := sqlDb.Exec("USE " + schema)

View File

@ -7,7 +7,6 @@
package gdb
import (
"bytes"
"database/sql"
"errors"
"fmt"
@ -74,86 +73,6 @@ func formatQuery(query string, args []interface{}) (newQuery string, newArgs []i
return
}
// 格式化Where查询条件。
func formatWhere(where interface{}, args []interface{}) (newWhere string, newArgs []interface{}) {
// 条件字符串处理
buffer := bytes.NewBuffer(nil)
// 使用反射进行类型判断
rv := reflect.ValueOf(where)
kind := rv.Kind()
if kind == reflect.Ptr {
rv = rv.Elem()
kind = rv.Kind()
}
switch kind {
// map/struct类型
case reflect.Map:
fallthrough
case reflect.Struct:
for key, value := range structToMap(where) {
if buffer.Len() > 0 {
buffer.WriteString(" AND ")
}
// 支持slice键值/属性,如果只有一个?占位符号那么作为IN查询否则打散作为多个查询参数
rv := reflect.ValueOf(value)
switch rv.Kind() {
case reflect.Slice:
fallthrough
case reflect.Array:
count := gstr.Count(key, "?")
if count == 0 {
buffer.WriteString(key + " IN(?)")
newArgs = append(newArgs, value)
} else if count != rv.Len() {
buffer.WriteString(key)
newArgs = append(newArgs, value)
} else {
buffer.WriteString(key)
// 如果键名/属性名称中带有多个?占位符号,那么将参数打散
newArgs = append(newArgs, gconv.Interfaces(value)...)
}
default:
if value == nil {
buffer.WriteString(key)
} else {
// 支持key带操作符号
if gstr.Pos(key, "?") == -1 {
if gstr.Pos(key, "<") == -1 && gstr.Pos(key, ">") == -1 && gstr.Pos(key, "=") == -1 {
buffer.WriteString(key + "=?")
} else {
buffer.WriteString(key + "?")
}
} else {
buffer.WriteString(key)
}
newArgs = append(newArgs, value)
}
}
}
default:
buffer.WriteString(gconv.String(where))
}
// 没有任何条件查询参数,直接返回
if buffer.Len() == 0 {
return "", args
}
newArgs = append(newArgs, args...)
newWhere = buffer.String()
// 查询条件参数处理主要处理slice参数类型
if len(newArgs) > 0 {
// 支持例如 Where/And/Or("uid", 1) 这种格式
if gstr.Pos(newWhere, "?") == -1 {
if gstr.Pos(newWhere, "<") == -1 && gstr.Pos(newWhere, ">") == -1 && gstr.Pos(newWhere, "=") == -1 {
newWhere += "=?"
} else {
newWhere += "?"
}
}
}
return
}
// 将预处理参数转换为底层数据库引擎支持的格式。
// 主要是判断参数是否为复杂数据类型,如果是,那么转换为基础类型。
func convertParam(value interface{}) interface{} {

View File

@ -43,7 +43,7 @@ func (bs *dbBase) Table(tables string) *Model {
return &Model{
db: bs.db,
tablesInit: tables,
tables: tables,
tables: bs.db.quoteWord(tables),
fields: "*",
start: -1,
offset: -1,
@ -62,7 +62,7 @@ func (tx *TX) Table(tables string) *Model {
db: tx.db,
tx: tx,
tablesInit: tables,
tables: tables,
tables: tx.db.quoteWord(tables),
fields: "*",
start: -1,
offset: -1,
@ -154,7 +154,7 @@ func (md *Model) Where(where interface{}, args ...interface{}) *Model {
if model.where != "" {
return md.And(where, args...)
}
newWhere, newArgs := formatWhere(where, args)
newWhere, newArgs := md.db.formatWhere(where, args)
model.where = newWhere
model.whereArgs = newArgs
return model
@ -163,7 +163,7 @@ func (md *Model) Where(where interface{}, args ...interface{}) *Model {
// 链式操作添加AND条件到Where中
func (md *Model) And(where interface{}, args ...interface{}) *Model {
model := md.getModel()
newWhere, newArgs := formatWhere(where, args)
newWhere, newArgs := md.db.formatWhere(where, args)
if len(model.where) > 0 && model.where[0] == '(' {
model.where = fmt.Sprintf(`%s AND (%s)`, model.where, newWhere)
} else {
@ -176,7 +176,7 @@ func (md *Model) And(where interface{}, args ...interface{}) *Model {
// 链式操作添加OR条件到Where中
func (md *Model) Or(where interface{}, args ...interface{}) *Model {
model := md.getModel()
newWhere, newArgs := formatWhere(where, args)
newWhere, newArgs := md.db.formatWhere(where, args)
if len(model.where) > 0 && model.where[0] == '(' {
model.where = fmt.Sprintf(`%s OR (%s)`, model.where, newWhere)
} else {
@ -474,7 +474,7 @@ func (md *Model) Select() (Result, error) {
// 链式操作,查询所有记录
func (md *Model) All() (Result, error) {
return md.getAll(fmt.Sprintf("SELECT %s FROM %s %s", md.fields, md.tables, md.getConditionSql()), md.whereArgs...)
return md.getAll(fmt.Sprintf("SELECT %s FROM %s%s", md.fields, md.tables, md.getConditionSql()), md.whereArgs...)
}
// 链式操作,查询单条记录

View File

@ -96,8 +96,7 @@ func (bs *dbBase) getTableFields(table string) (fields map[string]string, err er
// 缓存不存在时会查询数据表结构,缓存后不过期,直至程序重启(重新部署)
v := bs.cache.GetOrSetFunc("table_fields_"+table, func() interface{} {
result := (Result)(nil)
charL, charR := bs.db.getChars()
result, err = bs.GetAll(fmt.Sprintf(`SHOW COLUMNS FROM %s%s%s`, charL, table, charR))
result, err = bs.GetAll(fmt.Sprintf(`SHOW COLUMNS FROM %s`, bs.db.quoteWord(table)))
if err != nil {
return nil
}