diff --git a/g/database/gdb/gdb.go b/g/database/gdb/gdb.go index d6bb3cedc..c9235e012 100644 --- a/g/database/gdb/gdb.go +++ b/g/database/gdb/gdb.go @@ -95,8 +95,10 @@ type DB interface { getCache() *gcache.Cache getChars() (charLeft string, charRight string) getDebug() bool + quoteWord(s string) string setSchema(sqlDb *sql.DB, schema string) error filterFields(table string, data map[string]interface{}) map[string]interface{} + formatWhere(where interface{}, args []interface{}) (newWhere string, newArgs []interface{}) convertValue(fieldValue []byte, fieldType string) interface{} getTableFields(table string) (map[string]string, error) rowsToResult(rows *sql.Rows) (Result, error) diff --git a/g/database/gdb/gdb_base.go b/g/database/gdb/gdb_base.go index d2df5dc12..4e6186570 100644 --- a/g/database/gdb/gdb_base.go +++ b/g/database/gdb/gdb_base.go @@ -8,10 +8,13 @@ package gdb import ( + "bytes" "database/sql" "errors" "fmt" + "github.com/gogf/gf/g/text/gstr" "reflect" + "regexp" "strings" "github.com/gogf/gf/g/container/gvar" @@ -26,6 +29,11 @@ const ( gDEFAULT_DEBUG_SQL_LENGTH = 1000 ) +var ( + // 用于可转义的单词的识别正则对象 + wordReg = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`) +) + // 获取最近一条执行的sql func (bs *dbBase) GetLastSql() *Sql { if bs.sqls == nil { @@ -311,6 +319,7 @@ func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option i var values []string var params []interface{} var dataMap Map + table = bs.db.quoteWord(table) // 使用反射判断data数据类型,如果为slice类型,那么自动转为批量操作 rv := reflect.ValueOf(data) kind := rv.Kind() @@ -339,16 +348,16 @@ func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option i operation := getInsertOperationByOption(option) updateStr := "" if option == OPTION_SAVE { - var updates []string for k, _ := range dataMap { - updates = append(updates, - fmt.Sprintf("%s%s%s=VALUES(%s%s%s)", - charL, k, charR, - charL, k, charR, - ), + if len(updateStr) > 0 { + updateStr += "," + } + updateStr += fmt.Sprintf("%s%s%s=VALUES(%s%s%s)", + charL, k, charR, + charL, k, charR, ) } - updateStr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ",")) + updateStr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", updateStr) } if link == nil { if link, err = bs.db.Master(); err != nil { @@ -381,6 +390,7 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt var keys []string var values []string var params []interface{} + table = bs.db.quoteWord(table) listMap := (List)(nil) switch v := list.(type) { case Result: @@ -432,22 +442,22 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt } batchResult := new(batchSqlResult) charL, charR := bs.db.getChars() - keyStr := charL + strings.Join(keys, charL+","+charR) + charR + keyStr := charL + strings.Join(keys, charR+","+charL) + charR valueHolderStr := "(" + strings.Join(holders, ",") + ")" // 操作判断 operation := getInsertOperationByOption(option) updateStr := "" if option == OPTION_SAVE { - var updates []string for _, k := range keys { - updates = append(updates, - fmt.Sprintf("%s%s%s=VALUES(%s%s%s)", - charL, k, charR, - charL, k, charR, - ), + if len(updateStr) > 0 { + updateStr += "," + } + updateStr += fmt.Sprintf("%s%s%s=VALUES(%s%s%s)", + charL, k, charR, + charL, k, charR, ) } - updateStr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ",")) + updateStr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", updateStr) } // 构造批量写入数据格式(注意map的遍历是无序的) batchNum := gDEFAULT_BATCH_NUM @@ -499,7 +509,7 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt // CURD操作:数据更新,统一采用sql预处理。 // data参数支持string/map/struct/*struct类型。 func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { - newWhere, newArgs := formatWhere(condition, args) + newWhere, newArgs := bs.db.formatWhere(condition, args) if newWhere != "" { newWhere = " WHERE " + newWhere } @@ -509,8 +519,8 @@ func (bs *dbBase) Update(table string, data interface{}, condition interface{}, // CURD操作:数据更新,统一采用sql预处理。 // data参数支持string/map/struct/*struct类型类型。 func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) { + table = bs.db.quoteWord(table) updates := "" - charL, charR := bs.db.getChars() // 使用反射进行类型判断 rv := reflect.ValueOf(data) kind := rv.Kind() @@ -525,7 +535,7 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio case reflect.Struct: var fields []string for k, v := range structToMap(data) { - fields = append(fields, fmt.Sprintf("%s%s%s=?", charL, k, charR)) + fields = append(fields, bs.db.quoteWord(k)+"=?") params = append(params, convertParam(v)) } updates = strings.Join(fields, ",") @@ -546,7 +556,7 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio // CURD操作:删除数据 func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) { - newWhere, newArgs := formatWhere(condition, args) + newWhere, newArgs := bs.db.formatWhere(condition, args) if newWhere != "" { newWhere = " WHERE " + newWhere } @@ -560,6 +570,7 @@ func (bs *dbBase) doDelete(link dbLink, table string, condition string, args ... return nil, err } } + table = bs.db.quoteWord(table) return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...) } @@ -617,6 +628,98 @@ func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) { return records, nil } +// 格式化Where查询条件。 +func (bs *dbBase) formatWhere(where interface{}, args []interface{}) (newWhere string, newArgs []interface{}) { + // 条件字符串处理 + buffer := bytes.NewBuffer(nil) + // 使用反射进行类型判断 + rv := reflect.ValueOf(where) + kind := rv.Kind() + if kind == reflect.Ptr { + rv = rv.Elem() + kind = rv.Kind() + } + switch kind { + // map/struct类型 + case reflect.Map: + fallthrough + case reflect.Struct: + for key, value := range structToMap(where) { + // 字段安全符号判断 + key = bs.db.quoteWord(key) + if buffer.Len() > 0 { + buffer.WriteString(" AND ") + } + // 支持slice键值/属性,如果只有一个?占位符号,那么作为IN查询,否则打散作为多个查询参数 + rv := reflect.ValueOf(value) + switch rv.Kind() { + case reflect.Slice: + fallthrough + case reflect.Array: + count := gstr.Count(key, "?") + if count == 0 { + buffer.WriteString(key + " IN(?)") + newArgs = append(newArgs, value) + } else if count != rv.Len() { + buffer.WriteString(key) + newArgs = append(newArgs, value) + } else { + buffer.WriteString(key) + // 如果键名/属性名称中带有多个?占位符号,那么将参数打散 + newArgs = append(newArgs, gconv.Interfaces(value)...) + } + default: + if value == nil { + buffer.WriteString(key) + } else { + // 支持key带操作符号 + if gstr.Pos(key, "?") == -1 { + if gstr.Pos(key, "<") == -1 && gstr.Pos(key, ">") == -1 && gstr.Pos(key, "=") == -1 { + buffer.WriteString(key + "=?") + } else { + buffer.WriteString(key + "?") + } + } else { + buffer.WriteString(key) + } + newArgs = append(newArgs, value) + } + } + } + + default: + buffer.WriteString(gconv.String(where)) + } + // 没有任何条件查询参数,直接返回 + if buffer.Len() == 0 { + return "", args + } + newArgs = append(newArgs, args...) + newWhere = buffer.String() + // 查询条件参数处理,主要处理slice参数类型 + if len(newArgs) > 0 { + // 支持例如 Where/And/Or("uid", 1) 这种格式 + if gstr.Pos(newWhere, "?") == -1 { + if gstr.Pos(newWhere, "<") == -1 && gstr.Pos(newWhere, ">") == -1 && gstr.Pos(newWhere, "=") == -1 { + newWhere += "=?" + } else { + newWhere += "?" + } + } + } + return +} + +// 使用关键字操作符转义给定字符串。 +// 如果给定的字符串不为单词,那么不转义,直接返回该字符串。 +func (bs *dbBase) quoteWord(s string) string { + charLeft, charRight := bs.db.getChars() + if wordReg.MatchString(s) && !gstr.ContainsAny(s, charLeft+charRight) { + return charLeft + s + charRight + } + return s +} + // 动态切换数据库 func (bs *dbBase) setSchema(sqlDb *sql.DB, schema string) error { _, err := sqlDb.Exec("USE " + schema) diff --git a/g/database/gdb/gdb_func.go b/g/database/gdb/gdb_func.go index 4bbfaa002..92a1ac078 100644 --- a/g/database/gdb/gdb_func.go +++ b/g/database/gdb/gdb_func.go @@ -7,7 +7,6 @@ package gdb import ( - "bytes" "database/sql" "errors" "fmt" @@ -74,86 +73,6 @@ func formatQuery(query string, args []interface{}) (newQuery string, newArgs []i return } -// 格式化Where查询条件。 -func formatWhere(where interface{}, args []interface{}) (newWhere string, newArgs []interface{}) { - // 条件字符串处理 - buffer := bytes.NewBuffer(nil) - // 使用反射进行类型判断 - rv := reflect.ValueOf(where) - kind := rv.Kind() - if kind == reflect.Ptr { - rv = rv.Elem() - kind = rv.Kind() - } - switch kind { - // map/struct类型 - case reflect.Map: - fallthrough - case reflect.Struct: - for key, value := range structToMap(where) { - if buffer.Len() > 0 { - buffer.WriteString(" AND ") - } - // 支持slice键值/属性,如果只有一个?占位符号,那么作为IN查询,否则打散作为多个查询参数 - rv := reflect.ValueOf(value) - switch rv.Kind() { - case reflect.Slice: - fallthrough - case reflect.Array: - count := gstr.Count(key, "?") - if count == 0 { - buffer.WriteString(key + " IN(?)") - newArgs = append(newArgs, value) - } else if count != rv.Len() { - buffer.WriteString(key) - newArgs = append(newArgs, value) - } else { - buffer.WriteString(key) - // 如果键名/属性名称中带有多个?占位符号,那么将参数打散 - newArgs = append(newArgs, gconv.Interfaces(value)...) - } - default: - if value == nil { - buffer.WriteString(key) - } else { - // 支持key带操作符号 - if gstr.Pos(key, "?") == -1 { - if gstr.Pos(key, "<") == -1 && gstr.Pos(key, ">") == -1 && gstr.Pos(key, "=") == -1 { - buffer.WriteString(key + "=?") - } else { - buffer.WriteString(key + "?") - } - } else { - buffer.WriteString(key) - } - newArgs = append(newArgs, value) - } - } - } - - default: - buffer.WriteString(gconv.String(where)) - } - // 没有任何条件查询参数,直接返回 - if buffer.Len() == 0 { - return "", args - } - newArgs = append(newArgs, args...) - newWhere = buffer.String() - // 查询条件参数处理,主要处理slice参数类型 - if len(newArgs) > 0 { - // 支持例如 Where/And/Or("uid", 1) 这种格式 - if gstr.Pos(newWhere, "?") == -1 { - if gstr.Pos(newWhere, "<") == -1 && gstr.Pos(newWhere, ">") == -1 && gstr.Pos(newWhere, "=") == -1 { - newWhere += "=?" - } else { - newWhere += "?" - } - } - } - return -} - // 将预处理参数转换为底层数据库引擎支持的格式。 // 主要是判断参数是否为复杂数据类型,如果是,那么转换为基础类型。 func convertParam(value interface{}) interface{} { diff --git a/g/database/gdb/gdb_model.go b/g/database/gdb/gdb_model.go index 1874cb7b9..1f373b58a 100644 --- a/g/database/gdb/gdb_model.go +++ b/g/database/gdb/gdb_model.go @@ -43,7 +43,7 @@ func (bs *dbBase) Table(tables string) *Model { return &Model{ db: bs.db, tablesInit: tables, - tables: tables, + tables: bs.db.quoteWord(tables), fields: "*", start: -1, offset: -1, @@ -62,7 +62,7 @@ func (tx *TX) Table(tables string) *Model { db: tx.db, tx: tx, tablesInit: tables, - tables: tables, + tables: tx.db.quoteWord(tables), fields: "*", start: -1, offset: -1, @@ -154,7 +154,7 @@ func (md *Model) Where(where interface{}, args ...interface{}) *Model { if model.where != "" { return md.And(where, args...) } - newWhere, newArgs := formatWhere(where, args) + newWhere, newArgs := md.db.formatWhere(where, args) model.where = newWhere model.whereArgs = newArgs return model @@ -163,7 +163,7 @@ func (md *Model) Where(where interface{}, args ...interface{}) *Model { // 链式操作,添加AND条件到Where中 func (md *Model) And(where interface{}, args ...interface{}) *Model { model := md.getModel() - newWhere, newArgs := formatWhere(where, args) + newWhere, newArgs := md.db.formatWhere(where, args) if len(model.where) > 0 && model.where[0] == '(' { model.where = fmt.Sprintf(`%s AND (%s)`, model.where, newWhere) } else { @@ -176,7 +176,7 @@ func (md *Model) And(where interface{}, args ...interface{}) *Model { // 链式操作,添加OR条件到Where中 func (md *Model) Or(where interface{}, args ...interface{}) *Model { model := md.getModel() - newWhere, newArgs := formatWhere(where, args) + newWhere, newArgs := md.db.formatWhere(where, args) if len(model.where) > 0 && model.where[0] == '(' { model.where = fmt.Sprintf(`%s OR (%s)`, model.where, newWhere) } else { @@ -474,7 +474,7 @@ func (md *Model) Select() (Result, error) { // 链式操作,查询所有记录 func (md *Model) All() (Result, error) { - return md.getAll(fmt.Sprintf("SELECT %s FROM %s %s", md.fields, md.tables, md.getConditionSql()), md.whereArgs...) + return md.getAll(fmt.Sprintf("SELECT %s FROM %s%s", md.fields, md.tables, md.getConditionSql()), md.whereArgs...) } // 链式操作,查询单条记录 diff --git a/g/database/gdb/gdb_structure.go b/g/database/gdb/gdb_structure.go index 99adf65a1..c18c2fd51 100644 --- a/g/database/gdb/gdb_structure.go +++ b/g/database/gdb/gdb_structure.go @@ -96,8 +96,7 @@ func (bs *dbBase) getTableFields(table string) (fields map[string]string, err er // 缓存不存在时会查询数据表结构,缓存后不过期,直至程序重启(重新部署) v := bs.cache.GetOrSetFunc("table_fields_"+table, func() interface{} { result := (Result)(nil) - charL, charR := bs.db.getChars() - result, err = bs.GetAll(fmt.Sprintf(`SHOW COLUMNS FROM %s%s%s`, charL, table, charR)) + result, err = bs.GetAll(fmt.Sprintf(`SHOW COLUMNS FROM %s`, bs.db.quoteWord(table))) if err != nil { return nil }