add slice parameter support for method operations for gdb

This commit is contained in:
John
2019-07-09 12:04:24 +08:00
parent 0bba2092af
commit 835a971f89
4 changed files with 135 additions and 66 deletions

View File

@ -79,6 +79,7 @@ func (bs *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err
// 数据库sql查询操作主要执行查询
func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) {
query, args = formatQuery(query, args)
query = bs.db.handleSqlBeforeExec(query)
if bs.db.getDebug() {
mTime1 := gtime.Millisecond()
@ -115,6 +116,7 @@ func (bs *dbBase) Exec(query string, args ...interface{}) (result sql.Result, er
// 执行一条sql并返回执行情况主要用于非查询操作
func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) {
query, args = formatQuery(query, args)
query = bs.db.handleSqlBeforeExec(query)
if bs.db.getDebug() {
mTime1 := gtime.Millisecond()
@ -179,28 +181,28 @@ func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) {
}
// 数据库查询查询单条记录自动映射数据到给定的struct对象中
func (bs *dbBase) GetStruct(objPointer interface{}, query string, args ...interface{}) error {
func (bs *dbBase) GetStruct(pointer interface{}, query string, args ...interface{}) error {
one, err := bs.GetOne(query, args...)
if err != nil {
return err
}
return one.ToStruct(objPointer)
return one.ToStruct(pointer)
}
// 数据库查询查询多条记录并自动转换为指定的slice对象, 如: []struct/[]*struct。
func (bs *dbBase) GetStructs(objPointerSlice interface{}, query string, args ...interface{}) error {
func (bs *dbBase) GetStructs(pointer interface{}, query string, args ...interface{}) error {
all, err := bs.GetAll(query, args...)
if err != nil {
return err
}
return all.ToStructs(objPointerSlice)
return all.ToStructs(pointer)
}
// 将结果转换为指定的struct/*struct/[]struct/[]*struct,
// 参数应该为指针类型,否则返回失败。
// 该方法自动识别参数类型调用Struct/Structs方法。
func (bs *dbBase) GetScan(objPointer interface{}, query string, args ...interface{}) error {
t := reflect.TypeOf(objPointer)
func (bs *dbBase) GetScan(pointer interface{}, query string, args ...interface{}) error {
t := reflect.TypeOf(pointer)
k := t.Kind()
if k != reflect.Ptr {
return fmt.Errorf("params should be type of pointer, but got: %v", k)
@ -208,9 +210,9 @@ func (bs *dbBase) GetScan(objPointer interface{}, query string, args ...interfac
k = t.Elem().Kind()
switch k {
case reflect.Array, reflect.Slice:
return bs.db.GetStructs(objPointer, query, args...)
return bs.db.GetStructs(pointer, query, args...)
case reflect.Struct:
return bs.db.GetStruct(objPointer, query, args...)
return bs.db.GetStruct(pointer, query, args...)
}
return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k)
}

View File

@ -27,7 +27,49 @@ type apiString interface {
String() string
}
// 格式化Where查询条件
// 格式化SQL语句。
// 1. 支持参数只传一个slice
// 2. 支持占位符号数量自动扩展;
func formatQuery(query string, args []interface{}) (newQuery string, newArgs []interface{}) {
newQuery = query
// 查询条件参数处理主要处理slice参数类型
if len(args) > 0 {
for index, arg := range args {
rv := reflect.ValueOf(arg)
kind := rv.Kind()
if kind == reflect.Ptr {
rv = rv.Elem()
kind = rv.Kind()
}
switch kind {
// '?'占位符支持slice类型, 这里会将slice参数拆散并更新原有占位符'?'为多个'?',使用','符号连接。
case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ {
newArgs = append(newArgs, rv.Index(i).Interface())
}
// 如果参数直接传递slice并且占位符数量与slice长度相等
// 那么不用替换扩展占位符数量直接使用该slice作为查询参数
if len(args) == 1 && gstr.Count(newQuery, "?") == rv.Len() {
break
}
// counter用于匹配该参数的位置(与index对应)
counter := 0
newQuery, _ = gregex.ReplaceStringFunc(`\?`, newQuery, func(s string) string {
counter++
if counter == index+1 {
return "?" + strings.Repeat(",?", rv.Len()-1)
}
return s
})
default:
newArgs = append(newArgs, arg)
}
}
}
return
}
// 格式化Where查询条件。
func formatWhere(where interface{}, args []interface{}) (newWhere string, newArgs []interface{}) {
// 条件字符串处理
buffer := bytes.NewBuffer(nil)
@ -38,7 +80,6 @@ func formatWhere(where interface{}, args []interface{}) (newWhere string, newArg
rv = rv.Elem()
kind = rv.Kind()
}
tmpArgs := []interface{}(nil)
switch kind {
// map/struct类型
case reflect.Map:
@ -57,19 +98,20 @@ func formatWhere(where interface{}, args []interface{}) (newWhere string, newArg
count := gstr.Count(key, "?")
if count == 0 {
buffer.WriteString(key + " IN(?)")
tmpArgs = append(tmpArgs, value)
newArgs = append(newArgs, value)
} else if count != rv.Len() {
buffer.WriteString(key)
tmpArgs = append(tmpArgs, value)
newArgs = append(newArgs, value)
} else {
buffer.WriteString(key)
// 如果键名/属性名称中带有多个?占位符号,那么将参数打散
tmpArgs = append(tmpArgs, gconv.Interfaces(value)...)
newArgs = append(newArgs, gconv.Interfaces(value)...)
}
default:
if value == nil {
buffer.WriteString(key)
} else {
// 支持key带操作符号
if gstr.Pos(key, "?") == -1 {
if gstr.Pos(key, "<") == -1 && gstr.Pos(key, ">") == -1 && gstr.Pos(key, "=") == -1 {
buffer.WriteString(key + "=?")
@ -79,7 +121,7 @@ func formatWhere(where interface{}, args []interface{}) (newWhere string, newArg
} else {
buffer.WriteString(key)
}
tmpArgs = append(tmpArgs, value)
newArgs = append(newArgs, value)
}
}
}
@ -91,47 +133,16 @@ func formatWhere(where interface{}, args []interface{}) (newWhere string, newArg
if buffer.Len() == 0 {
return "", args
}
tmpArgs = append(tmpArgs, args...)
newArgs = append(newArgs, args...)
newWhere = buffer.String()
// 查询条件参数处理主要处理slice参数类型
if len(tmpArgs) > 0 {
for index, arg := range tmpArgs {
rv := reflect.ValueOf(arg)
kind := rv.Kind()
if kind == reflect.Ptr {
rv = rv.Elem()
kind = rv.Kind()
}
switch kind {
// '?'占位符支持slice类型, 这里会将slice参数拆散并更新原有占位符'?'为多个'?',使用','符号连接。
case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ {
newArgs = append(newArgs, rv.Index(i).Interface())
}
// 如果参数直接传递slice并且占位符数量与slice长度相等
// 那么不用替换扩展占位符数量直接使用该slice作为查询参数
if len(args) == 1 && gstr.Count(newWhere, "?") == rv.Len() {
break
}
// counter用于匹配该参数的位置(与index对应)
counter := 0
newWhere, _ = gregex.ReplaceStringFunc(`\?`, newWhere, func(s string) string {
counter++
if counter == index+1 {
return "?" + strings.Repeat(",?", rv.Len()-1)
}
return s
})
default:
// 支持例如 Where/And/Or("uid", 1) 这种格式
if gstr.Pos(newWhere, "?") == -1 {
if gstr.Pos(newWhere, "<") == -1 && gstr.Pos(newWhere, ">") == -1 && gstr.Pos(newWhere, "=") == -1 {
newWhere += "=?"
} else {
newWhere += "?"
}
}
newArgs = append(newArgs, arg)
if len(newArgs) > 0 {
// 支持例如 Where/And/Or("uid", 1) 这种格式
if gstr.Pos(newWhere, "?") == -1 {
if gstr.Pos(newWhere, "<") == -1 && gstr.Pos(newWhere, ">") == -1 && gstr.Pos(newWhere, "=") == -1 {
newWhere += "=?"
} else {
newWhere += "?"
}
}
}

View File

@ -7,11 +7,12 @@
package gdb_test
import (
"testing"
"time"
"github.com/gogf/gf/g"
"github.com/gogf/gf/g/os/gtime"
"github.com/gogf/gf/g/test/gtest"
"testing"
"time"
)
func TestDbBase_Ping(t *testing.T) {
@ -24,12 +25,17 @@ func TestDbBase_Ping(t *testing.T) {
}
func TestDbBase_Query(t *testing.T) {
if _, err := db.Query("SELECT ?", 1); err != nil {
gtest.Fatal(err)
}
if _, err := db.Query("ERROR"); err == nil {
gtest.Fatal("FAIL")
}
gtest.Case(t, func() {
_, err := db.Query("SELECT ?", 1)
gtest.Assert(err, nil)
_, err = db.Query("SELECT ?+?", 1, 2)
gtest.Assert(err, nil)
_, err = db.Query("SELECT ?+?", g.Slice{1, 2})
gtest.Assert(err, nil)
_, err = db.Query("ERROR")
gtest.AssertNE(err, nil)
})
}
func TestDbBase_Exec(t *testing.T) {
@ -290,11 +296,44 @@ func TestDbBase_Update(t *testing.T) {
}
func TestDbBase_GetAll(t *testing.T) {
if result, err := db.GetAll("SELECT * FROM user WHERE id=?", 1); err != nil {
gtest.Fatal(err)
} else {
gtest.Case(t, func() {
result, err := db.GetAll("SELECT * FROM user WHERE id=?", 1)
gtest.Assert(err, nil)
gtest.Assert(len(result), 1)
}
gtest.Assert(result[0]["id"].Int(), 1)
})
gtest.Case(t, func() {
result, err := db.GetAll("SELECT * FROM user WHERE id in(?)", g.Slice{1, 2, 3})
gtest.Assert(err, nil)
gtest.Assert(len(result), 3)
gtest.Assert(result[0]["id"].Int(), 1)
gtest.Assert(result[1]["id"].Int(), 2)
gtest.Assert(result[2]["id"].Int(), 3)
})
gtest.Case(t, func() {
result, err := db.GetAll("SELECT * FROM user WHERE id in(?,?,?)", g.Slice{1, 2, 3})
gtest.Assert(err, nil)
gtest.Assert(len(result), 3)
gtest.Assert(result[0]["id"].Int(), 1)
gtest.Assert(result[1]["id"].Int(), 2)
gtest.Assert(result[2]["id"].Int(), 3)
})
gtest.Case(t, func() {
result, err := db.GetAll("SELECT * FROM user WHERE id in(?,?,?)", g.Slice{1, 2, 3}...)
gtest.Assert(err, nil)
gtest.Assert(len(result), 3)
gtest.Assert(result[0]["id"].Int(), 1)
gtest.Assert(result[1]["id"].Int(), 2)
gtest.Assert(result[2]["id"].Int(), 3)
})
gtest.Case(t, func() {
result, err := db.GetAll("SELECT * FROM user WHERE id>=? AND id <=?", g.Slice{1, 3})
gtest.Assert(err, nil)
gtest.Assert(len(result), 3)
gtest.Assert(result[0]["id"].Int(), 1)
gtest.Assert(result[1]["id"].Int(), 2)
gtest.Assert(result[2]["id"].Int(), 3)
})
}
func TestDbBase_GetOne(t *testing.T) {

View File

@ -7,10 +7,11 @@
package gdb_test
import (
"testing"
"github.com/gogf/gf/g"
"github.com/gogf/gf/g/os/gtime"
"github.com/gogf/gf/g/test/gtest"
"testing"
)
func TestTX_Query(t *testing.T) {
@ -23,6 +24,16 @@ func TestTX_Query(t *testing.T) {
} else {
rows.Close()
}
if rows, err := tx.Query("SELECT ?+?", 1, 2); err != nil {
gtest.Fatal(err)
} else {
rows.Close()
}
if rows, err := tx.Query("SELECT ?+?", g.Slice{1, 2}); err != nil {
gtest.Fatal(err)
} else {
rows.Close()
}
if _, err := tx.Query("ERROR"); err == nil {
gtest.Fatal("FAIL")
}
@ -39,6 +50,12 @@ func TestTX_Exec(t *testing.T) {
if _, err := tx.Exec("SELECT ?", 1); err != nil {
gtest.Fatal(err)
}
if _, err := tx.Exec("SELECT ?+?", 1, 2); err != nil {
gtest.Fatal(err)
}
if _, err := tx.Exec("SELECT ?+?", g.Slice{1, 2}); err != nil {
gtest.Fatal(err)
}
if _, err := tx.Exec("ERROR"); err == nil {
gtest.Fatal("FAIL")
}