From 2b083709b5a85b3ae9bec33480adefe1bf6ecd57 Mon Sep 17 00:00:00 2001 From: John Date: Mon, 14 Oct 2019 23:27:48 +0800 Subject: [PATCH] add gmap.ListMap/TreeMap support for gdb.Where --- database/gdb/gdb_base.go | 123 +++++++++++--------- database/gdb/gdb_unit_z_mysql_model_test.go | 42 +++++++ frame/g/g.go | 23 ++-- 3 files changed, 125 insertions(+), 63 deletions(-) diff --git a/database/gdb/gdb_base.go b/database/gdb/gdb_base.go index 18de865e6..389f4b2c5 100644 --- a/database/gdb/gdb_base.go +++ b/database/gdb/gdb_base.go @@ -12,6 +12,7 @@ import ( "database/sql" "errors" "fmt" + "github.com/gogf/gf/container/gmap" "reflect" "regexp" "strings" @@ -31,8 +32,8 @@ const ( ) var ( - wordReg = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`) - lastOperatorReg = regexp.MustCompile(`[<>=]+\s*$`) + wordReg = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`) // Regular expression object for a word. + lastOperatorReg = regexp.MustCompile(`[<>=]+\s*$`) // Regular expression object for a string which has operator at its tail. ) // 获取最近一条执行的sql @@ -645,9 +646,7 @@ func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) { // 格式化Where查询条件。 func (bs *dbBase) formatWhere(where interface{}, args []interface{}) (newWhere string, newArgs []interface{}) { - // 条件字符串处理 buffer := bytes.NewBuffer(nil) - // 使用反射进行类型判断 rv := reflect.ValueOf(where) kind := rv.Kind() if kind == reflect.Ptr { @@ -655,68 +654,42 @@ func (bs *dbBase) formatWhere(where interface{}, args []interface{}) (newWhere s kind = rv.Kind() } switch kind { - // map/struct类型 case reflect.Map: - fallthrough - case reflect.Struct: for key, value := range structToMap(where) { - // 字段安全符号判断 - key = bs.db.quoteWord(key) - if buffer.Len() > 0 { - buffer.WriteString(" AND ") - } - // 支持slice键值/属性,如果只有一个?占位符号,那么作为IN查询,否则打散作为多个查询参数 - rv := reflect.ValueOf(value) - switch rv.Kind() { - case reflect.Slice: - fallthrough - case reflect.Array: - count := gstr.Count(key, "?") - if count == 0 { - buffer.WriteString(key + " IN(?)") - newArgs = append(newArgs, value) - } else if count != rv.Len() { - buffer.WriteString(key) - newArgs = append(newArgs, value) - } else { - buffer.WriteString(key) - // 如果键名/属性名称中带有多个?占位符号,那么将参数打散 - newArgs = append(newArgs, gconv.Interfaces(value)...) - } - default: - if value == nil { - buffer.WriteString(key) - } else { - // 支持key带操作符号,注意like也算是操作符号 - key = gstr.Trim(key) - if gstr.Pos(key, "?") == -1 { - like := " like" - if len(key) > len(like) && gstr.Equal(key[len(key)-len(like):], like) { - buffer.WriteString(key + " ?") - } else if key[len(key)-1] != '<' && key[len(key)-1] != '>' && key[len(key)-1] != '=' { - buffer.WriteString(key + "=?") - } else { - buffer.WriteString(key + " ?") - } - } else { - buffer.WriteString(key) - } - newArgs = append(newArgs, value) - } + newArgs = bs.formatWhereKeyValue(buffer, newArgs, key, value) + } + + case reflect.Struct: + // ListMap and TreeMap are ordered map, + // which are index-friendly for where conditions. + switch m := where.(type) { + case *gmap.ListMap: + m.Iterator(func(key, value interface{}) bool { + newArgs = bs.formatWhereKeyValue(buffer, newArgs, gconv.String(key), value) + return true + }) + case *gmap.TreeMap: + m.Iterator(func(key, value interface{}) bool { + newArgs = bs.formatWhereKeyValue(buffer, newArgs, gconv.String(key), value) + return true + }) + default: + for key, value := range structToMap(where) { + newArgs = bs.formatWhereKeyValue(buffer, newArgs, key, value) } } default: buffer.WriteString(gconv.String(where)) } - // 没有任何条件查询参数,直接返回 + if buffer.Len() == 0 { return "", args } newArgs = append(newArgs, args...) newWhere = buffer.String() if len(newArgs) > 0 { - // 支持例如 Where/And/Or("uid", 1) , Where/And/Or("uid>=", 1) 这种格式 + // It supports formats like: Where/And/Or("uid", 1) , Where/And/Or("uid>=", 1) if gstr.Pos(newWhere, "?") == -1 { if lastOperatorReg.MatchString(newWhere) { newWhere += "?" @@ -728,6 +701,52 @@ func (bs *dbBase) formatWhere(where interface{}, args []interface{}) (newWhere s return handlerSliceArguments(newWhere, newArgs) } +// formatWhereKeyValue handles each key-value pair of the param map. +func (bs *dbBase) formatWhereKeyValue(buffer *bytes.Buffer, newArgs []interface{}, key string, value interface{}) []interface{} { + key = bs.db.quoteWord(key) + if buffer.Len() > 0 { + buffer.WriteString(" AND ") + } + // 支持slice键值/属性,如果只有一个?占位符号,那么作为IN查询,否则打散作为多个查询参数 + rv := reflect.ValueOf(value) + switch rv.Kind() { + case reflect.Slice, reflect.Array: + count := gstr.Count(key, "?") + if count == 0 { + buffer.WriteString(key + " IN(?)") + newArgs = append(newArgs, value) + } else if count != rv.Len() { + buffer.WriteString(key) + newArgs = append(newArgs, value) + } else { + buffer.WriteString(key) + // 如果键名/属性名称中带有多个?占位符号,那么将参数打散 + newArgs = append(newArgs, gconv.Interfaces(value)...) + } + default: + if value == nil { + buffer.WriteString(key) + } else { + // 支持key带操作符号,注意like也算是操作符号 + key = gstr.Trim(key) + if gstr.Pos(key, "?") == -1 { + like := " like" + if len(key) > len(like) && gstr.Equal(key[len(key)-len(like):], like) { + buffer.WriteString(key + " ?") + } else if lastOperatorReg.MatchString(key) { + buffer.WriteString(key + " ?") + } else { + buffer.WriteString(key + "=?") + } + } else { + buffer.WriteString(key) + } + newArgs = append(newArgs, value) + } + } + return newArgs +} + // 使用关键字操作符转义给定字符串。 // 如果给定的字符串不为单词,那么不转义,直接返回该字符串。 func (bs *dbBase) quoteWord(s string) string { diff --git a/database/gdb/gdb_unit_z_mysql_model_test.go b/database/gdb/gdb_unit_z_mysql_model_test.go index 979baf127..e1200e16f 100644 --- a/database/gdb/gdb_unit_z_mysql_model_test.go +++ b/database/gdb/gdb_unit_z_mysql_model_test.go @@ -9,6 +9,8 @@ package gdb_test import ( "database/sql" "fmt" + "github.com/gogf/gf/container/gmap" + "github.com/gogf/gf/util/gutil" "testing" "github.com/gogf/gf/database/gdb" @@ -756,6 +758,46 @@ func Test_Model_Where(t *testing.T) { gtest.Assert(err, nil) gtest.Assert(result["id"].Int(), 2) }) + + // gmap.Map + gtest.Case(t, func() { + result, err := db.Table(table).Where(gmap.NewFrom(g.MapAnyAny{"id": 3, "nickname": "name_3"})).One() + gtest.Assert(err, nil) + gtest.Assert(result["id"].Int(), 3) + }) + // gmap.Map key operator + gtest.Case(t, func() { + result, err := db.Table(table).Where(gmap.NewFrom(g.MapAnyAny{"id>": 1, "id<": 3})).One() + gtest.Assert(err, nil) + gtest.Assert(result["id"].Int(), 2) + }) + + // list map + gtest.Case(t, func() { + result, err := db.Table(table).Where(gmap.NewListMapFrom(g.MapAnyAny{"id": 3, "nickname": "name_3"})).One() + gtest.Assert(err, nil) + gtest.Assert(result["id"].Int(), 3) + }) + // list map key operator + gtest.Case(t, func() { + result, err := db.Table(table).Where(gmap.NewListMapFrom(g.MapAnyAny{"id>": 1, "id<": 3})).One() + gtest.Assert(err, nil) + gtest.Assert(result["id"].Int(), 2) + }) + + // tree map + gtest.Case(t, func() { + result, err := db.Table(table).Where(gmap.NewTreeMapFrom(gutil.ComparatorString, g.MapAnyAny{"id": 3, "nickname": "name_3"})).One() + gtest.Assert(err, nil) + gtest.Assert(result["id"].Int(), 3) + }) + // tree map key operator + gtest.Case(t, func() { + result, err := db.Table(table).Where(gmap.NewTreeMapFrom(gutil.ComparatorString, g.MapAnyAny{"id>": 1, "id<": 3})).One() + gtest.Assert(err, nil) + gtest.Assert(result["id"].Int(), 2) + }) + // complicated where 1 gtest.Case(t, func() { //db.SetDebug(true) diff --git a/frame/g/g.go b/frame/g/g.go index 5ed4481d4..06a774c7b 100644 --- a/frame/g/g.go +++ b/frame/g/g.go @@ -30,17 +30,18 @@ type MapIntBool = map[int]bool // Frequently-used slice type alias. type List = []Map -type ListAnyStr = []map[interface{}]string -type ListAnyInt = []map[interface{}]int -type ListStrAny = []map[string]interface{} -type ListStrStr = []map[string]string -type ListStrInt = []map[string]int -type ListIntAny = []map[int]interface{} -type ListIntStr = []map[int]string -type ListIntInt = []map[int]int -type ListAnyBool = []map[interface{}]bool -type ListStrBool = []map[string]bool -type ListIntBool = []map[int]bool +type ListAnyAny = []Map +type ListAnyStr = []MapAnyStr +type ListAnyInt = []MapAnyInt +type ListStrAny = []MapStrAny +type ListStrStr = []MapStrStr +type ListStrInt = []MapStrInt +type ListIntAny = []MapIntAny +type ListIntStr = []MapIntStr +type ListIntInt = []MapIntInt +type ListAnyBool = []MapAnyBool +type ListStrBool = []MapStrBool +type ListIntBool = []MapIntBool // Frequently-used slice type alias. type Slice = []interface{}