diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 3f3e14e38..c8fe0657d 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -202,8 +202,10 @@ const ( var ( // ErrNoRows is alias of sql.ErrNoRows. ErrNoRows = sql.ErrNoRows + // instances is the management map for instances. instances = gmap.NewStrAnyMap(true) + // driverMap manages all custom registered driver. driverMap = map[string]Driver{ "mysql": &DriverMysql{}, @@ -212,6 +214,14 @@ var ( "oracle": &DriverOracle{}, "sqlite": &DriverSqlite{}, } + + // lastOperatorRegPattern is the regular expression pattern for a string + // which has operator at its tail. + lastOperatorRegPattern = `[<>=]+\s*$` + + // regularFieldNameRegPattern is the regular expression pattern for a string + // which is a regular field name of table. + regularFieldNameRegPattern = `^[\w\.\-]+$` ) // Register registers custom database driver to gdb. diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index ecdf9a3f7..2fe885a8a 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -13,7 +13,6 @@ import ( "fmt" "github.com/gogf/gf/internal/utils" "reflect" - "regexp" "strings" "github.com/gogf/gf/container/gvar" @@ -26,12 +25,6 @@ const ( gPATH_FILTER_KEY = "/database/gdb/gdb" ) -var ( - // lastOperatorReg is the regular expression object for a string - // which has operator at its tail. - lastOperatorReg = regexp.MustCompile(`[<>=]+\s*$`) -) - // Master creates and returns a connection from master node if master-slave configured. // It returns the default connection if master-slave not configured. func (c *Core) Master() (*sql.DB, error) { diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index 5b35fd25c..032ab1b5c 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -257,9 +257,11 @@ func formatSql(sql string, args []interface{}) (newQuery string, newArgs []inter // formatWhere formats where statement and its arguments. // TODO []interface{} type support for parameter does not completed yet. func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) (newWhere string, newArgs []interface{}) { - buffer := bytes.NewBuffer(nil) - rv := reflect.ValueOf(where) - kind := rv.Kind() + var ( + buffer = bytes.NewBuffer(nil) + rv = reflect.ValueOf(where) + kind = rv.Kind() + ) if kind == reflect.Ptr { rv = rv.Elem() kind = rv.Kind() @@ -270,7 +272,7 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) ( case reflect.Map: for key, value := range DataToMapDeep(where) { - if omitEmpty && empty.IsEmpty(value) { + if gregex.IsMatchString(regularFieldNameRegPattern, key) && omitEmpty && empty.IsEmpty(value) { continue } newArgs = formatWhereKeyValue(db, buffer, newArgs, key, value) @@ -283,10 +285,11 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) ( // which implement apiIterator interface and are index-friendly for where conditions. if iterator, ok := where.(apiIterator); ok { iterator.Iterator(func(key, value interface{}) bool { - if omitEmpty && empty.IsEmpty(value) { + ketStr := gconv.String(key) + if gregex.IsMatchString(regularFieldNameRegPattern, ketStr) && omitEmpty && empty.IsEmpty(value) { return true } - newArgs = formatWhereKeyValue(db, buffer, newArgs, gconv.String(key), value) + newArgs = formatWhereKeyValue(db, buffer, newArgs, ketStr, value) return true }) break @@ -309,10 +312,10 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) ( newWhere = buffer.String() if len(newArgs) > 0 { if gstr.Pos(newWhere, "?") == -1 { - if lastOperatorReg.MatchString(newWhere) { + if gregex.IsMatchString(lastOperatorRegPattern, newWhere) { // Eg: Where/And/Or("uid>=", 1) newWhere += "?" - } else if gregex.IsMatchString(`^[\w\.\-]+$`, newWhere) { + } else if gregex.IsMatchString(regularFieldNameRegPattern, newWhere) { newWhere = db.QuoteString(newWhere) if len(newArgs) > 0 { if utils.IsArray(newArgs[0]) { @@ -362,8 +365,10 @@ func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key // If the value is type of slice, and there's only one '?' holder in // the key string, it automatically adds '?' holder chars according to its arguments count // and converts it to "IN" statement. - rv := reflect.ValueOf(value) - kind := rv.Kind() + var ( + rv = reflect.ValueOf(value) + kind = rv.Kind() + ) switch kind { case reflect.Slice, reflect.Array: count := gstr.Count(quotedKey, "?") @@ -379,7 +384,7 @@ func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key } default: if value == nil || empty.IsNil(rv) { - if gregex.IsMatchString(`^[\w\.\-]+$`, key) { + if gregex.IsMatchString(regularFieldNameRegPattern, key) { // The key is a single field name. buffer.WriteString(quotedKey + " IS NULL") } else { @@ -392,11 +397,24 @@ func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key if gstr.Pos(quotedKey, "?") == -1 { like := " like" if len(quotedKey) > len(like) && gstr.Equal(quotedKey[len(quotedKey)-len(like):], like) { + // Eg: Where(g.Map{"name like": "john%"}) buffer.WriteString(quotedKey + " ?") - } else if lastOperatorReg.MatchString(quotedKey) { + } else if gregex.IsMatchString(lastOperatorRegPattern, quotedKey) { + // Eg: Where(g.Map{"age > ": 16}) buffer.WriteString(quotedKey + " ?") - } else { + } else if gregex.IsMatchString(regularFieldNameRegPattern, key) { + // The key is a regular field name. buffer.WriteString(quotedKey + "=?") + } else { + // The key is not a regular field name. + // Eg: Where(g.Map{"age > 16": nil}) + // Issue: https://github.com/gogf/gf/issues/765 + if empty.IsEmpty(value) { + buffer.WriteString(quotedKey) + break + } else { + buffer.WriteString(quotedKey + "=?") + } } } else { buffer.WriteString(quotedKey) diff --git a/database/gdb/gdb_z_mysql_model_test.go b/database/gdb/gdb_z_mysql_model_test.go index dcaea9cb9..f95975d9e 100644 --- a/database/gdb/gdb_z_mysql_model_test.go +++ b/database/gdb/gdb_z_mysql_model_test.go @@ -1286,6 +1286,29 @@ func Test_Model_Where_ISNULL_2(t *testing.T) { }) } +func Test_Model_Where_OmitEmpty(t *testing.T) { + table := createInitTable() + defer dropTable(table) + gtest.C(t, func(t *gtest.T) { + conditions := g.Map{ + "id < 4": "", + } + result, err := db.Table(table).WherePri(conditions).Order("id asc").All() + t.Assert(err, nil) + t.Assert(len(result), 3) + t.Assert(result[0]["id"].Int(), 1) + }) + gtest.C(t, func(t *gtest.T) { + conditions := g.Map{ + "id < 4": "", + } + result, err := db.Table(table).WherePri(conditions).OmitEmpty().Order("id asc").All() + t.Assert(err, nil) + t.Assert(len(result), 3) + t.Assert(result[0]["id"].Int(), 1) + }) +} + func Test_Model_Where_GTime(t *testing.T) { table := createInitTable() defer dropTable(table) diff --git a/database/gdb/gdb_z_mysql_soft_time_test.go b/database/gdb/gdb_z_mysql_soft_time_test.go index 4daef93a1..70f064b5f 100644 --- a/database/gdb/gdb_z_mysql_soft_time_test.go +++ b/database/gdb/gdb_z_mysql_soft_time_test.go @@ -163,7 +163,7 @@ CREATE TABLE %s ( gtest.Error(err) } defer dropTable(table) - db.SetDebug(true) + // db.SetDebug(true) gtest.C(t, func(t *gtest.T) { for i := 1; i <= 10; i++ { data := g.Map{ diff --git a/go.mod b/go.mod index bc4d3995c..1afa28a7d 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,9 @@ require ( github.com/gqcn/structs v1.1.1 github.com/grokify/html-strip-tags-go v0.0.0-20190921062105-daaa06bf1aaf github.com/json-iterator/go v1.1.10 + github.com/mattn/go-runewidth v0.0.9 // indirect github.com/olekukonko/tablewriter v0.0.1 + golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae // indirect golang.org/x/text v0.3.2 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c )