This commit is contained in:
John
2020-06-28 23:03:41 +08:00
parent 2f44721086
commit 0e884c78f5
6 changed files with 67 additions and 21 deletions

View File

@ -202,8 +202,10 @@ const (
var (
// ErrNoRows is alias of sql.ErrNoRows.
ErrNoRows = sql.ErrNoRows
// instances is the management map for instances.
instances = gmap.NewStrAnyMap(true)
// driverMap manages all custom registered driver.
driverMap = map[string]Driver{
"mysql": &DriverMysql{},
@ -212,6 +214,14 @@ var (
"oracle": &DriverOracle{},
"sqlite": &DriverSqlite{},
}
// lastOperatorRegPattern is the regular expression pattern for a string
// which has operator at its tail.
lastOperatorRegPattern = `[<>=]+\s*$`
// regularFieldNameRegPattern is the regular expression pattern for a string
// which is a regular field name of table.
regularFieldNameRegPattern = `^[\w\.\-]+$`
)
// Register registers custom database driver to gdb.

View File

@ -13,7 +13,6 @@ import (
"fmt"
"github.com/gogf/gf/internal/utils"
"reflect"
"regexp"
"strings"
"github.com/gogf/gf/container/gvar"
@ -26,12 +25,6 @@ const (
gPATH_FILTER_KEY = "/database/gdb/gdb"
)
var (
// lastOperatorReg is the regular expression object for a string
// which has operator at its tail.
lastOperatorReg = regexp.MustCompile(`[<>=]+\s*$`)
)
// Master creates and returns a connection from master node if master-slave configured.
// It returns the default connection if master-slave not configured.
func (c *Core) Master() (*sql.DB, error) {

View File

@ -257,9 +257,11 @@ func formatSql(sql string, args []interface{}) (newQuery string, newArgs []inter
// formatWhere formats where statement and its arguments.
// TODO []interface{} type support for parameter <where> does not completed yet.
func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) (newWhere string, newArgs []interface{}) {
buffer := bytes.NewBuffer(nil)
rv := reflect.ValueOf(where)
kind := rv.Kind()
var (
buffer = bytes.NewBuffer(nil)
rv = reflect.ValueOf(where)
kind = rv.Kind()
)
if kind == reflect.Ptr {
rv = rv.Elem()
kind = rv.Kind()
@ -270,7 +272,7 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) (
case reflect.Map:
for key, value := range DataToMapDeep(where) {
if omitEmpty && empty.IsEmpty(value) {
if gregex.IsMatchString(regularFieldNameRegPattern, key) && omitEmpty && empty.IsEmpty(value) {
continue
}
newArgs = formatWhereKeyValue(db, buffer, newArgs, key, value)
@ -283,10 +285,11 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) (
// which implement apiIterator interface and are index-friendly for where conditions.
if iterator, ok := where.(apiIterator); ok {
iterator.Iterator(func(key, value interface{}) bool {
if omitEmpty && empty.IsEmpty(value) {
ketStr := gconv.String(key)
if gregex.IsMatchString(regularFieldNameRegPattern, ketStr) && omitEmpty && empty.IsEmpty(value) {
return true
}
newArgs = formatWhereKeyValue(db, buffer, newArgs, gconv.String(key), value)
newArgs = formatWhereKeyValue(db, buffer, newArgs, ketStr, value)
return true
})
break
@ -309,10 +312,10 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) (
newWhere = buffer.String()
if len(newArgs) > 0 {
if gstr.Pos(newWhere, "?") == -1 {
if lastOperatorReg.MatchString(newWhere) {
if gregex.IsMatchString(lastOperatorRegPattern, newWhere) {
// Eg: Where/And/Or("uid>=", 1)
newWhere += "?"
} else if gregex.IsMatchString(`^[\w\.\-]+$`, newWhere) {
} else if gregex.IsMatchString(regularFieldNameRegPattern, newWhere) {
newWhere = db.QuoteString(newWhere)
if len(newArgs) > 0 {
if utils.IsArray(newArgs[0]) {
@ -362,8 +365,10 @@ func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key
// If the value is type of slice, and there's only one '?' holder in
// the key string, it automatically adds '?' holder chars according to its arguments count
// and converts it to "IN" statement.
rv := reflect.ValueOf(value)
kind := rv.Kind()
var (
rv = reflect.ValueOf(value)
kind = rv.Kind()
)
switch kind {
case reflect.Slice, reflect.Array:
count := gstr.Count(quotedKey, "?")
@ -379,7 +384,7 @@ func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key
}
default:
if value == nil || empty.IsNil(rv) {
if gregex.IsMatchString(`^[\w\.\-]+$`, key) {
if gregex.IsMatchString(regularFieldNameRegPattern, key) {
// The key is a single field name.
buffer.WriteString(quotedKey + " IS NULL")
} else {
@ -392,11 +397,24 @@ func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key
if gstr.Pos(quotedKey, "?") == -1 {
like := " like"
if len(quotedKey) > len(like) && gstr.Equal(quotedKey[len(quotedKey)-len(like):], like) {
// Eg: Where(g.Map{"name like": "john%"})
buffer.WriteString(quotedKey + " ?")
} else if lastOperatorReg.MatchString(quotedKey) {
} else if gregex.IsMatchString(lastOperatorRegPattern, quotedKey) {
// Eg: Where(g.Map{"age > ": 16})
buffer.WriteString(quotedKey + " ?")
} else {
} else if gregex.IsMatchString(regularFieldNameRegPattern, key) {
// The key is a regular field name.
buffer.WriteString(quotedKey + "=?")
} else {
// The key is not a regular field name.
// Eg: Where(g.Map{"age > 16": nil})
// Issue: https://github.com/gogf/gf/issues/765
if empty.IsEmpty(value) {
buffer.WriteString(quotedKey)
break
} else {
buffer.WriteString(quotedKey + "=?")
}
}
} else {
buffer.WriteString(quotedKey)

View File

@ -1286,6 +1286,29 @@ func Test_Model_Where_ISNULL_2(t *testing.T) {
})
}
func Test_Model_Where_OmitEmpty(t *testing.T) {
table := createInitTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
conditions := g.Map{
"id < 4": "",
}
result, err := db.Table(table).WherePri(conditions).Order("id asc").All()
t.Assert(err, nil)
t.Assert(len(result), 3)
t.Assert(result[0]["id"].Int(), 1)
})
gtest.C(t, func(t *gtest.T) {
conditions := g.Map{
"id < 4": "",
}
result, err := db.Table(table).WherePri(conditions).OmitEmpty().Order("id asc").All()
t.Assert(err, nil)
t.Assert(len(result), 3)
t.Assert(result[0]["id"].Int(), 1)
})
}
func Test_Model_Where_GTime(t *testing.T) {
table := createInitTable()
defer dropTable(table)

View File

@ -163,7 +163,7 @@ CREATE TABLE %s (
gtest.Error(err)
}
defer dropTable(table)
db.SetDebug(true)
// db.SetDebug(true)
gtest.C(t, func(t *gtest.T) {
for i := 1; i <= 10; i++ {
data := g.Map{