diff --git a/g/database/gdb/gdb_base.go b/g/database/gdb/gdb_base.go index dfc17d577..8d1e82db8 100644 --- a/g/database/gdb/gdb_base.go +++ b/g/database/gdb/gdb_base.go @@ -79,6 +79,7 @@ func (bs *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err // 数据库sql查询操作,主要执行查询 func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) { + query, args = formatQuery(query, args) query = bs.db.handleSqlBeforeExec(query) if bs.db.getDebug() { mTime1 := gtime.Millisecond() @@ -115,6 +116,7 @@ func (bs *dbBase) Exec(query string, args ...interface{}) (result sql.Result, er // 执行一条sql,并返回执行情况,主要用于非查询操作 func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) { + query, args = formatQuery(query, args) query = bs.db.handleSqlBeforeExec(query) if bs.db.getDebug() { mTime1 := gtime.Millisecond() @@ -179,28 +181,28 @@ func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) { } // 数据库查询,查询单条记录,自动映射数据到给定的struct对象中 -func (bs *dbBase) GetStruct(objPointer interface{}, query string, args ...interface{}) error { +func (bs *dbBase) GetStruct(pointer interface{}, query string, args ...interface{}) error { one, err := bs.GetOne(query, args...) if err != nil { return err } - return one.ToStruct(objPointer) + return one.ToStruct(pointer) } // 数据库查询,查询多条记录,并自动转换为指定的slice对象, 如: []struct/[]*struct。 -func (bs *dbBase) GetStructs(objPointerSlice interface{}, query string, args ...interface{}) error { +func (bs *dbBase) GetStructs(pointer interface{}, query string, args ...interface{}) error { all, err := bs.GetAll(query, args...) if err != nil { return err } - return all.ToStructs(objPointerSlice) + return all.ToStructs(pointer) } // 将结果转换为指定的struct/*struct/[]struct/[]*struct, // 参数应该为指针类型,否则返回失败。 // 该方法自动识别参数类型,调用Struct/Structs方法。 -func (bs *dbBase) GetScan(objPointer interface{}, query string, args ...interface{}) error { - t := reflect.TypeOf(objPointer) +func (bs *dbBase) GetScan(pointer interface{}, query string, args ...interface{}) error { + t := reflect.TypeOf(pointer) k := t.Kind() if k != reflect.Ptr { return fmt.Errorf("params should be type of pointer, but got: %v", k) @@ -208,9 +210,9 @@ func (bs *dbBase) GetScan(objPointer interface{}, query string, args ...interfac k = t.Elem().Kind() switch k { case reflect.Array, reflect.Slice: - return bs.db.GetStructs(objPointer, query, args...) + return bs.db.GetStructs(pointer, query, args...) case reflect.Struct: - return bs.db.GetStruct(objPointer, query, args...) + return bs.db.GetStruct(pointer, query, args...) } return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k) } diff --git a/g/database/gdb/gdb_func.go b/g/database/gdb/gdb_func.go index 9fb964292..1dd68c58a 100644 --- a/g/database/gdb/gdb_func.go +++ b/g/database/gdb/gdb_func.go @@ -27,7 +27,49 @@ type apiString interface { String() string } -// 格式化Where查询条件 +// 格式化SQL语句。 +// 1. 支持参数只传一个slice; +// 2. 支持占位符号数量自动扩展; +func formatQuery(query string, args []interface{}) (newQuery string, newArgs []interface{}) { + newQuery = query + // 查询条件参数处理,主要处理slice参数类型 + if len(args) > 0 { + for index, arg := range args { + rv := reflect.ValueOf(arg) + kind := rv.Kind() + if kind == reflect.Ptr { + rv = rv.Elem() + kind = rv.Kind() + } + switch kind { + // '?'占位符支持slice类型, 这里会将slice参数拆散,并更新原有占位符'?'为多个'?',使用','符号连接。 + case reflect.Slice, reflect.Array: + for i := 0; i < rv.Len(); i++ { + newArgs = append(newArgs, rv.Index(i).Interface()) + } + // 如果参数直接传递slice,并且占位符数量与slice长度相等, + // 那么不用替换扩展占位符数量,直接使用该slice作为查询参数 + if len(args) == 1 && gstr.Count(newQuery, "?") == rv.Len() { + break + } + // counter用于匹配该参数的位置(与index对应) + counter := 0 + newQuery, _ = gregex.ReplaceStringFunc(`\?`, newQuery, func(s string) string { + counter++ + if counter == index+1 { + return "?" + strings.Repeat(",?", rv.Len()-1) + } + return s + }) + default: + newArgs = append(newArgs, arg) + } + } + } + return +} + +// 格式化Where查询条件。 func formatWhere(where interface{}, args []interface{}) (newWhere string, newArgs []interface{}) { // 条件字符串处理 buffer := bytes.NewBuffer(nil) @@ -38,7 +80,6 @@ func formatWhere(where interface{}, args []interface{}) (newWhere string, newArg rv = rv.Elem() kind = rv.Kind() } - tmpArgs := []interface{}(nil) switch kind { // map/struct类型 case reflect.Map: @@ -57,19 +98,20 @@ func formatWhere(where interface{}, args []interface{}) (newWhere string, newArg count := gstr.Count(key, "?") if count == 0 { buffer.WriteString(key + " IN(?)") - tmpArgs = append(tmpArgs, value) + newArgs = append(newArgs, value) } else if count != rv.Len() { buffer.WriteString(key) - tmpArgs = append(tmpArgs, value) + newArgs = append(newArgs, value) } else { buffer.WriteString(key) // 如果键名/属性名称中带有多个?占位符号,那么将参数打散 - tmpArgs = append(tmpArgs, gconv.Interfaces(value)...) + newArgs = append(newArgs, gconv.Interfaces(value)...) } default: if value == nil { buffer.WriteString(key) } else { + // 支持key带操作符号 if gstr.Pos(key, "?") == -1 { if gstr.Pos(key, "<") == -1 && gstr.Pos(key, ">") == -1 && gstr.Pos(key, "=") == -1 { buffer.WriteString(key + "=?") @@ -79,7 +121,7 @@ func formatWhere(where interface{}, args []interface{}) (newWhere string, newArg } else { buffer.WriteString(key) } - tmpArgs = append(tmpArgs, value) + newArgs = append(newArgs, value) } } } @@ -91,47 +133,16 @@ func formatWhere(where interface{}, args []interface{}) (newWhere string, newArg if buffer.Len() == 0 { return "", args } - tmpArgs = append(tmpArgs, args...) + newArgs = append(newArgs, args...) newWhere = buffer.String() // 查询条件参数处理,主要处理slice参数类型 - if len(tmpArgs) > 0 { - for index, arg := range tmpArgs { - rv := reflect.ValueOf(arg) - kind := rv.Kind() - if kind == reflect.Ptr { - rv = rv.Elem() - kind = rv.Kind() - } - switch kind { - // '?'占位符支持slice类型, 这里会将slice参数拆散,并更新原有占位符'?'为多个'?',使用','符号连接。 - case reflect.Slice, reflect.Array: - for i := 0; i < rv.Len(); i++ { - newArgs = append(newArgs, rv.Index(i).Interface()) - } - // 如果参数直接传递slice,并且占位符数量与slice长度相等, - // 那么不用替换扩展占位符数量,直接使用该slice作为查询参数 - if len(args) == 1 && gstr.Count(newWhere, "?") == rv.Len() { - break - } - // counter用于匹配该参数的位置(与index对应) - counter := 0 - newWhere, _ = gregex.ReplaceStringFunc(`\?`, newWhere, func(s string) string { - counter++ - if counter == index+1 { - return "?" + strings.Repeat(",?", rv.Len()-1) - } - return s - }) - default: - // 支持例如 Where/And/Or("uid", 1) 这种格式 - if gstr.Pos(newWhere, "?") == -1 { - if gstr.Pos(newWhere, "<") == -1 && gstr.Pos(newWhere, ">") == -1 && gstr.Pos(newWhere, "=") == -1 { - newWhere += "=?" - } else { - newWhere += "?" - } - } - newArgs = append(newArgs, arg) + if len(newArgs) > 0 { + // 支持例如 Where/And/Or("uid", 1) 这种格式 + if gstr.Pos(newWhere, "?") == -1 { + if gstr.Pos(newWhere, "<") == -1 && gstr.Pos(newWhere, ">") == -1 && gstr.Pos(newWhere, "=") == -1 { + newWhere += "=?" + } else { + newWhere += "?" } } } diff --git a/g/database/gdb/gdb_unit_method_test.go b/g/database/gdb/gdb_unit_method_test.go index 886277a24..21f72c1b7 100644 --- a/g/database/gdb/gdb_unit_method_test.go +++ b/g/database/gdb/gdb_unit_method_test.go @@ -7,11 +7,12 @@ package gdb_test import ( + "testing" + "time" + "github.com/gogf/gf/g" "github.com/gogf/gf/g/os/gtime" "github.com/gogf/gf/g/test/gtest" - "testing" - "time" ) func TestDbBase_Ping(t *testing.T) { @@ -24,12 +25,17 @@ func TestDbBase_Ping(t *testing.T) { } func TestDbBase_Query(t *testing.T) { - if _, err := db.Query("SELECT ?", 1); err != nil { - gtest.Fatal(err) - } - if _, err := db.Query("ERROR"); err == nil { - gtest.Fatal("FAIL") - } + gtest.Case(t, func() { + _, err := db.Query("SELECT ?", 1) + gtest.Assert(err, nil) + _, err = db.Query("SELECT ?+?", 1, 2) + gtest.Assert(err, nil) + _, err = db.Query("SELECT ?+?", g.Slice{1, 2}) + gtest.Assert(err, nil) + _, err = db.Query("ERROR") + gtest.AssertNE(err, nil) + }) + } func TestDbBase_Exec(t *testing.T) { @@ -290,11 +296,44 @@ func TestDbBase_Update(t *testing.T) { } func TestDbBase_GetAll(t *testing.T) { - if result, err := db.GetAll("SELECT * FROM user WHERE id=?", 1); err != nil { - gtest.Fatal(err) - } else { + gtest.Case(t, func() { + result, err := db.GetAll("SELECT * FROM user WHERE id=?", 1) + gtest.Assert(err, nil) gtest.Assert(len(result), 1) - } + gtest.Assert(result[0]["id"].Int(), 1) + }) + gtest.Case(t, func() { + result, err := db.GetAll("SELECT * FROM user WHERE id in(?)", g.Slice{1, 2, 3}) + gtest.Assert(err, nil) + gtest.Assert(len(result), 3) + gtest.Assert(result[0]["id"].Int(), 1) + gtest.Assert(result[1]["id"].Int(), 2) + gtest.Assert(result[2]["id"].Int(), 3) + }) + gtest.Case(t, func() { + result, err := db.GetAll("SELECT * FROM user WHERE id in(?,?,?)", g.Slice{1, 2, 3}) + gtest.Assert(err, nil) + gtest.Assert(len(result), 3) + gtest.Assert(result[0]["id"].Int(), 1) + gtest.Assert(result[1]["id"].Int(), 2) + gtest.Assert(result[2]["id"].Int(), 3) + }) + gtest.Case(t, func() { + result, err := db.GetAll("SELECT * FROM user WHERE id in(?,?,?)", g.Slice{1, 2, 3}...) + gtest.Assert(err, nil) + gtest.Assert(len(result), 3) + gtest.Assert(result[0]["id"].Int(), 1) + gtest.Assert(result[1]["id"].Int(), 2) + gtest.Assert(result[2]["id"].Int(), 3) + }) + gtest.Case(t, func() { + result, err := db.GetAll("SELECT * FROM user WHERE id>=? AND id <=?", g.Slice{1, 3}) + gtest.Assert(err, nil) + gtest.Assert(len(result), 3) + gtest.Assert(result[0]["id"].Int(), 1) + gtest.Assert(result[1]["id"].Int(), 2) + gtest.Assert(result[2]["id"].Int(), 3) + }) } func TestDbBase_GetOne(t *testing.T) { diff --git a/g/database/gdb/gdb_unit_transaction_test.go b/g/database/gdb/gdb_unit_transaction_test.go index 835004241..5ddedabd6 100644 --- a/g/database/gdb/gdb_unit_transaction_test.go +++ b/g/database/gdb/gdb_unit_transaction_test.go @@ -7,10 +7,11 @@ package gdb_test import ( + "testing" + "github.com/gogf/gf/g" "github.com/gogf/gf/g/os/gtime" "github.com/gogf/gf/g/test/gtest" - "testing" ) func TestTX_Query(t *testing.T) { @@ -23,6 +24,16 @@ func TestTX_Query(t *testing.T) { } else { rows.Close() } + if rows, err := tx.Query("SELECT ?+?", 1, 2); err != nil { + gtest.Fatal(err) + } else { + rows.Close() + } + if rows, err := tx.Query("SELECT ?+?", g.Slice{1, 2}); err != nil { + gtest.Fatal(err) + } else { + rows.Close() + } if _, err := tx.Query("ERROR"); err == nil { gtest.Fatal("FAIL") } @@ -39,6 +50,12 @@ func TestTX_Exec(t *testing.T) { if _, err := tx.Exec("SELECT ?", 1); err != nil { gtest.Fatal(err) } + if _, err := tx.Exec("SELECT ?+?", 1, 2); err != nil { + gtest.Fatal(err) + } + if _, err := tx.Exec("SELECT ?+?", g.Slice{1, 2}); err != nil { + gtest.Fatal(err) + } if _, err := tx.Exec("ERROR"); err == nil { gtest.Fatal("FAIL") }