diff --git a/g/database/gdb/gdb_base.go b/g/database/gdb/gdb_base.go index d4a80ed45..17509944e 100644 --- a/g/database/gdb/gdb_base.go +++ b/g/database/gdb/gdb_base.go @@ -410,7 +410,8 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio return nil, err } } - return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, formatCondition(condition)), params...) + newWhere, newArgs := formatCondition(condition, params) + return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, newWhere), newArgs...) } // CURD操作:删除数据 @@ -424,7 +425,8 @@ func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{ // CURD操作:删除数据 func (bs *dbBase) doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error) { - return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, formatCondition(condition)), args...) + newWhere, newArgs := formatCondition(condition, args) + return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, newWhere), newArgs...) } // 获得缓存对象 diff --git a/g/database/gdb/gdb_func.go b/g/database/gdb/gdb_func.go index 42210e115..d16402808 100644 --- a/g/database/gdb/gdb_func.go +++ b/g/database/gdb/gdb_func.go @@ -7,6 +7,7 @@ package gdb import ( + "bytes" "database/sql" "errors" "fmt" @@ -14,9 +15,11 @@ import ( "gitee.com/johng/gf/g/os/glog" "gitee.com/johng/gf/g/os/gtime" "gitee.com/johng/gf/g/util/gconv" + "gitee.com/johng/gf/g/util/gregex" "gitee.com/johng/gf/g/util/gstr" _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql" "reflect" + "strings" ) // 将数据查询的列表数据*sql.Rows转换为Result类型 @@ -55,30 +58,61 @@ func rowsToResult(rows *sql.Rows) (Result, error) { } // 格式化SQL查询条件 -func formatCondition(condition interface{}) (where string) { - if reflect.ValueOf(condition).Kind() == reflect.Map { - ks := reflect.ValueOf(condition).MapKeys() - vs := reflect.ValueOf(condition) +func formatCondition(where interface{}, args []interface{}) (string, []interface{}) { + // 条件字符串处理 + buffer := bytes.NewBuffer(nil) + if reflect.ValueOf(where).Kind() == reflect.Map { + ks := reflect.ValueOf(where).MapKeys() + vs := reflect.ValueOf(where) for _, k := range ks { key := gconv.String(k.Interface()) value := gconv.String(vs.MapIndex(k).Interface()) - isNum := gstr.IsNumeric(value) - if len(where) > 0 { - where += " AND " + if buffer.Len() > 0 { + buffer.WriteString(" AND ") } - if isNum || value == "?" { - where += key + "=" + value + if gstr.IsNumeric(value) || value == "?" { + buffer.WriteString(key + "=" + value) } else { - where += key + "='" + value + "'" + buffer.WriteString(key + "='" + value + "'") } } } else { - where += gconv.String(condition) + buffer.Write(gconv.Bytes(where)) } - if len(where) == 0 { - where = "1" + if buffer.Len() == 0 { + buffer.WriteString("1") } - return + // 查询条件处理 + newWhere := buffer.String() + newArgs := make([]interface{}, 0) + if len(args) > 0 { + for index, arg := range args { + rv := reflect.ValueOf(arg) + kind := rv.Kind() + if kind == reflect.Ptr { + rv = rv.Elem() + kind = rv.Kind() + } + switch kind { + case reflect.Slice: fallthrough + case reflect.Array: + for i := 0; i < rv.Len(); i++ { + newArgs = append(newArgs, rv.Index(i).Interface()) + } + counter := 0 + newWhere, _ = gregex.ReplaceStringFunc(`\?`, newWhere, func(s string) string { + counter++ + if counter == index + 1 { + return "?" + strings.Repeat(",?", rv.Len() - 1) + } + return s + }) + default: + newArgs = append(newArgs, arg) + } + } + } + return newWhere, newArgs } // 打印SQL对象(仅在debug=true时有效) diff --git a/g/database/gdb/gdb_model.go b/g/database/gdb/gdb_model.go index 95780f6d1..64a6c8ab9 100644 --- a/g/database/gdb/gdb_model.go +++ b/g/database/gdb/gdb_model.go @@ -108,8 +108,9 @@ func (md *Model) Filter() (*Model) { // 链式操作,condition,支持string & gdb.Map func (md *Model) Where(where interface{}, args ...interface{}) (*Model) { - md.where = formatCondition(where) - md.whereArgs = append(md.whereArgs, args...) + newWhere, newArgs := formatCondition(where, args) + md.where = newWhere + md.whereArgs = append(md.whereArgs, newArgs...) // 支持 Where("uid", 1)这种格式 if len(args) == 1 && strings.Index(md.where , "?") < 0 { md.where += "=?" @@ -119,15 +120,17 @@ func (md *Model) Where(where interface{}, args ...interface{}) (*Model) { // 链式操作,添加AND条件到Where中 func (md *Model) And(where interface{}, args ...interface{}) (*Model) { - md.where += " AND " + formatCondition(where) - md.whereArgs = append(md.whereArgs, args...) + newWhere, newArgs := formatCondition(where, args) + md.where += " AND " + newWhere + md.whereArgs = append(md.whereArgs, newArgs...) return md } // 链式操作,添加OR条件到Where中 func (md *Model) Or(where interface{}, args ...interface{}) (*Model) { - md.where += " OR " + formatCondition(where) - md.whereArgs = append(md.whereArgs, args...) + newWhere, newArgs := formatCondition(where, args) + md.where += " OR " + newWhere + md.whereArgs = append(md.whereArgs, newArgs...) return md } diff --git a/g/database/gdb/gdb_unit_2_test.go b/g/database/gdb/gdb_unit_2_test.go index 469db382f..f6b2b40ee 100644 --- a/g/database/gdb/gdb_unit_2_test.go +++ b/g/database/gdb/gdb_unit_2_test.go @@ -168,6 +168,25 @@ func TestModel_GroupBy(t *testing.T) { gtest.Assert(result[0]["nickname"].String(), "T111") } +func TestModel_Where1(t *testing.T) { + result, err := db.Table("user").Where("id IN(?)", g.Slice{1,3}).OrderBy("id ASC").All() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(len(result), 2) + gtest.Assert(result[0]["id"].Int(), 1) + gtest.Assert(result[1]["id"].Int(), 3) +} + +func TestModel_Where2(t *testing.T) { + result, err := db.Table("user").Where("nickname=? AND id IN(?)", "T3", g.Slice{1,3}).OrderBy("id ASC").All() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(len(result), 1) + gtest.Assert(result[0]["id"].Int(), 3) +} + func TestModel_Delete(t *testing.T) { result, err := db.Table("user").Delete() if err != nil { diff --git a/geg/database/orm/mysql/gdb_value.go b/geg/database/orm/mysql/gdb_value.go index 31747b2ec..dddaed10e 100644 --- a/geg/database/orm/mysql/gdb_value.go +++ b/geg/database/orm/mysql/gdb_value.go @@ -10,7 +10,7 @@ func main() { // 开启调试模式,以便于记录所有执行的SQL db.SetDebug(true) - r, _ := db.Table("test").Where("id IN (?,?)", 1,2).All() + r, _ := db.Table("test").Where("id IN (?)", []interface{}{1, 2}).All() if r != nil { fmt.Println(r.ToList()) } diff --git a/geg/other/test2.go b/geg/other/test2.go index f6a556927..3e0c1a0ec 100644 --- a/geg/other/test2.go +++ b/geg/other/test2.go @@ -3,13 +3,20 @@ package main import ( "fmt" "gitee.com/johng/gf/g/util/gregex" + "strings" ) func main() { - query := "select * from user" - q, err := gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query) - fmt.Println(err) - fmt.Println(q) + newWhere := "?????" + counter := 0 + newWhere, _ = gregex.ReplaceStringFunc(`\?`, newWhere, func(s string) string { + counter++ + if counter == 4 { + return "?" + strings.Repeat(",!", 5 - 1) + } + return s + }) + fmt.Println(newWhere) } \ No newline at end of file