diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 394b49eb9..9c5a487e4 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -196,20 +196,26 @@ func (c *Core) GetStructs(ctx context.Context, pointer interface{}, sql string, // the conversion. If parameter `pointer` is type of slice, it calls GetStructs internally // for conversion. func (c *Core) GetScan(ctx context.Context, pointer interface{}, sql string, args ...interface{}) error { - t := reflect.TypeOf(pointer) - k := t.Kind() - if k != reflect.Ptr { - return gerror.NewCodef(gcode.CodeInvalidParameter, "params should be type of pointer, but got: %v", k) + reflectInfo := utils.OriginTypeAndKind(pointer) + if reflectInfo.InputKind != reflect.Ptr { + return gerror.NewCodef( + gcode.CodeInvalidParameter, + "params should be type of pointer, but got: %v", + reflectInfo.InputKind, + ) } - k = t.Elem().Kind() - switch k { + switch reflectInfo.OriginKind { case reflect.Array, reflect.Slice: return c.db.GetCore().GetStructs(ctx, pointer, sql, args...) case reflect.Struct: return c.db.GetCore().GetStruct(ctx, pointer, sql, args...) } - return gerror.NewCodef(gcode.CodeInvalidParameter, "element type should be type of struct/slice, unsupported: %v", k) + return gerror.NewCodef( + gcode.CodeInvalidParameter, + `in valid parameter type "%v", of which element type should be type of struct/slice`, + reflectInfo.InputType, + ) } // GetValue queries and returns the field value from database. diff --git a/database/gdb/gdb_core_transaction.go b/database/gdb/gdb_core_transaction.go index fb574e580..d0c5cf613 100644 --- a/database/gdb/gdb_core_transaction.go +++ b/database/gdb/gdb_core_transaction.go @@ -11,6 +11,7 @@ import ( "database/sql" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/internal/utils" "reflect" "github.com/gogf/gf/v2/container/gtype" @@ -399,20 +400,26 @@ func (tx *TX) GetStructs(objPointerSlice interface{}, sql string, args ...interf // the conversion. If parameter `pointer` is type of slice, it calls GetStructs internally // for conversion. func (tx *TX) GetScan(pointer interface{}, sql string, args ...interface{}) error { - t := reflect.TypeOf(pointer) - k := t.Kind() - if k != reflect.Ptr { - return gerror.NewCodef(gcode.CodeInvalidParameter, "params should be type of pointer, but got: %v", k) + reflectInfo := utils.OriginTypeAndKind(pointer) + if reflectInfo.InputKind != reflect.Ptr { + return gerror.NewCodef( + gcode.CodeInvalidParameter, + "params should be type of pointer, but got: %v", + reflectInfo.InputKind, + ) } - k = t.Elem().Kind() - switch k { + switch reflectInfo.OriginKind { case reflect.Array, reflect.Slice: return tx.GetStructs(pointer, sql, args...) case reflect.Struct: return tx.GetStruct(pointer, sql, args...) } - return gerror.NewCodef(gcode.CodeInvalidParameter, "element type should be type of struct/slice, unsupported: %v", k) + return gerror.NewCodef( + gcode.CodeInvalidParameter, + `in valid parameter type "%v", of which element type should be type of struct/slice`, + reflectInfo.InputType, + ) } // GetValue queries and returns the field value from database. diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index 46bba0259..d82583538 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -58,6 +58,8 @@ type iTableName interface { const ( OrmTagForStruct = "orm" + OrmTagForUnique = "unique" + OrmTagForPrimary = "primary" OrmTagForTable = "table" OrmTagForWith = "with" OrmTagForWithWhere = "where" @@ -400,26 +402,22 @@ func formatSql(sql string, args []interface{}) (newSql string, newArgs []interfa } type formatWhereInput struct { - Where interface{} - Args []interface{} - OmitNil bool - OmitEmpty bool - Schema string - Table string + Where interface{} + Args []interface{} + OmitNil bool + OmitEmpty bool + IgnoreEmptySliceWhere bool + Schema string + Table string } // formatWhere formats where statement and its arguments for `Where` and `Having` statements. func formatWhere(db DB, in formatWhereInput) (newWhere string, newArgs []interface{}) { var ( - buffer = bytes.NewBuffer(nil) - reflectValue = reflect.ValueOf(in.Where) - reflectKind = reflectValue.Kind() + buffer = bytes.NewBuffer(nil) + reflectInfo = utils.OriginValueAndKind(in.Where) ) - for reflectKind == reflect.Ptr { - reflectValue = reflectValue.Elem() - reflectKind = reflectValue.Kind() - } - switch reflectKind { + switch reflectInfo.OriginKind { case reflect.Array, reflect.Slice: newArgs = formatWhereInterfaces(db, gconv.Interfaces(in.Where), buffer, newArgs) @@ -433,7 +431,14 @@ func formatWhere(db DB, in formatWhereInput) (newWhere string, newArgs []interfa continue } } - newArgs = formatWhereKeyValue(db, buffer, newArgs, key, value) + newArgs = formatWhereKeyValue(formatWhereKeyValueInput{ + Db: db, + Buffer: buffer, + Args: newArgs, + Key: key, + Value: value, + IgnoreEmptySliceWhere: in.IgnoreEmptySliceWhere, + }) } case reflect.Struct: @@ -452,14 +457,21 @@ func formatWhere(db DB, in formatWhereInput) (newWhere string, newArgs []interfa return true } } - newArgs = formatWhereKeyValue(db, buffer, newArgs, ketStr, value) + newArgs = formatWhereKeyValue(formatWhereKeyValueInput{ + Db: db, + Buffer: buffer, + Args: newArgs, + Key: ketStr, + Value: value, + IgnoreEmptySliceWhere: in.IgnoreEmptySliceWhere, + }) return true }) break } // Automatically mapping and filtering the struct attribute. var ( - reflectType = reflectValue.Type() + reflectType = reflectInfo.OriginValue.Type() structField reflect.StructField ) data := DataToMapDeep(in.Where) @@ -477,7 +489,14 @@ func formatWhere(db DB, in formatWhereInput) (newWhere string, newArgs []interfa if in.OmitEmpty && empty.IsEmpty(foundValue) { continue } - newArgs = formatWhereKeyValue(db, buffer, newArgs, foundKey, foundValue) + newArgs = formatWhereKeyValue(formatWhereKeyValueInput{ + Db: db, + Buffer: buffer, + Args: newArgs, + Key: foundKey, + Value: foundValue, + IgnoreEmptySliceWhere: in.IgnoreEmptySliceWhere, + }) } } @@ -487,6 +506,15 @@ func formatWhere(db DB, in formatWhereInput) (newWhere string, newArgs []interfa i = 0 whereStr = gconv.String(in.Where) ) + // Eg: + // Where("id", []int{}).All() -> SELECT xxx FROM xxx WHERE 0=1 + // IgnoreEmptySliceWhere().Where("id", []int{}).One() -> SELECT xxx FROM xxx + if in.IgnoreEmptySliceWhere && len(in.Args) == 1 && utils.IsArray(in.Args[0]) { + if gstr.Count(whereStr, "?") == 0 && utils.IsEmpty(in.Args[0]) { + in.Args = in.Args[:0] + break + } + } for { if i >= len(in.Args) { break @@ -517,7 +545,9 @@ func formatWhere(db DB, in formatWhereInput) (newWhere string, newArgs []interfa if buffer.Len() == 0 { return "", in.Args } - newArgs = append(newArgs, in.Args...) + if len(in.Args) > 0 { + newArgs = append(newArgs, in.Args...) + } newWhere = buffer.String() if len(newArgs) > 0 { if gstr.Pos(newWhere, "?") == -1 { @@ -577,77 +607,100 @@ func formatWhereInterfaces(db DB, where []interface{}, buffer *bytes.Buffer, new return newArgs } +type formatWhereKeyValueInput struct { + Db DB + Buffer *bytes.Buffer + Args []interface{} + Key string + Value interface{} + IgnoreEmptySliceWhere bool +} + // formatWhereKeyValue handles each key-value pair of the parameter map. -func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key string, value interface{}) []interface{} { - quotedKey := db.GetCore().QuoteWord(key) - if buffer.Len() > 0 { - buffer.WriteString(" AND ") - } +func formatWhereKeyValue(in formatWhereKeyValueInput) (newArgs []interface{}) { + quotedKey := in.Db.GetCore().QuoteWord(in.Key) // If the value is type of slice, and there's only one '?' holder in // the key string, it automatically adds '?' holder chars according to its arguments count // and converts it to "IN" statement. var ( - rv = reflect.ValueOf(value) - kind = rv.Kind() + reflectValue = reflect.ValueOf(in.Value) + reflectKind = reflectValue.Kind() ) - switch kind { + switch reflectKind { + // Slice argument. case reflect.Slice, reflect.Array: count := gstr.Count(quotedKey, "?") - if count == 0 { - buffer.WriteString(quotedKey + " IN(?)") - newArgs = append(newArgs, value) - } else if count != rv.Len() { - buffer.WriteString(quotedKey) - newArgs = append(newArgs, value) - } else { - buffer.WriteString(quotedKey) - newArgs = append(newArgs, gconv.Interfaces(value)...) + // Eg: + // Where("id", []int{}).All() -> SELECT xxx FROM xxx WHERE 0=1 + // IgnoreEmptySliceWhere().Where("id", []int{}).One() -> SELECT xxx FROM xxx + if count == 0 && reflectValue.Len() == 0 && in.IgnoreEmptySliceWhere { + return in.Args } + + if in.Buffer.Len() > 0 { + in.Buffer.WriteString(" AND ") + } + if count == 0 { + in.Buffer.WriteString(quotedKey + " IN(?)") + in.Args = append(in.Args, in.Value) + } else { + if count != reflectValue.Len() { + in.Buffer.WriteString(quotedKey) + in.Args = append(in.Args, in.Value) + } else { + in.Buffer.WriteString(quotedKey) + in.Args = append(in.Args, gconv.Interfaces(in.Value)...) + } + } + default: - if value == nil || empty.IsNil(rv) { - if gregex.IsMatchString(regularFieldNameRegPattern, key) { + if in.Buffer.Len() > 0 { + in.Buffer.WriteString(" AND ") + } + if in.Value == nil || empty.IsNil(reflectValue) { + if gregex.IsMatchString(regularFieldNameRegPattern, in.Key) { // The key is a single field name. - buffer.WriteString(quotedKey + " IS NULL") + in.Buffer.WriteString(quotedKey + " IS NULL") } else { // The key may have operation chars. - buffer.WriteString(quotedKey) + in.Buffer.WriteString(quotedKey) } } else { // It also supports "LIKE" statement, which we considers it an operator. quotedKey = gstr.Trim(quotedKey) if gstr.Pos(quotedKey, "?") == -1 { - like := " like" + like := " LIKE" if len(quotedKey) > len(like) && gstr.Equal(quotedKey[len(quotedKey)-len(like):], like) { // Eg: Where(g.Map{"name like": "john%"}) - buffer.WriteString(quotedKey + " ?") + in.Buffer.WriteString(quotedKey + " ?") } else if gregex.IsMatchString(lastOperatorRegPattern, quotedKey) { // Eg: Where(g.Map{"age > ": 16}) - buffer.WriteString(quotedKey + " ?") - } else if gregex.IsMatchString(regularFieldNameRegPattern, key) { + in.Buffer.WriteString(quotedKey + " ?") + } else if gregex.IsMatchString(regularFieldNameRegPattern, in.Key) { // The key is a regular field name. - buffer.WriteString(quotedKey + "=?") + in.Buffer.WriteString(quotedKey + "=?") } else { // The key is not a regular field name. // Eg: Where(g.Map{"age > 16": nil}) // Issue: https://github.com/gogf/gf/issues/765 - if empty.IsEmpty(value) { - buffer.WriteString(quotedKey) + if empty.IsEmpty(in.Value) { + in.Buffer.WriteString(quotedKey) break } else { - buffer.WriteString(quotedKey + "=?") + in.Buffer.WriteString(quotedKey + "=?") } } } else { - buffer.WriteString(quotedKey) + in.Buffer.WriteString(quotedKey) } - if s, ok := value.(Raw); ok { - buffer.WriteString(gconv.String(s)) + if s, ok := in.Value.(Raw); ok { + in.Buffer.WriteString(gconv.String(s)) } else { - newArgs = append(newArgs, value) + in.Args = append(in.Args, in.Value) } } } - return newArgs + return in.Args } // handleArguments is an important function, which handles the sql and all its arguments @@ -659,15 +712,8 @@ func handleArguments(sql string, args []interface{}) (newSql string, newArgs []i // Handles the slice arguments. if len(args) > 0 { for index, arg := range args { - var ( - reflectValue = reflect.ValueOf(arg) - reflectKind = reflectValue.Kind() - ) - for reflectKind == reflect.Ptr { - reflectValue = reflectValue.Elem() - reflectKind = reflectValue.Kind() - } - switch reflectKind { + reflectInfo := utils.OriginValueAndKind(arg) + switch reflectInfo.OriginKind { case reflect.Slice, reflect.Array: // It does not split the type of []byte. // Eg: table.Where("name = ?", []byte("john")) @@ -676,7 +722,7 @@ func handleArguments(sql string, args []interface{}) (newSql string, newArgs []i continue } - if reflectValue.Len() == 0 { + if reflectInfo.OriginValue.Len() == 0 { // Empty slice argument, it converts the sql to a false sql. // Eg: // Query("select * from xxx where id in(?)", g.Slice{}) -> select * from xxx where 0=1 @@ -690,15 +736,15 @@ func handleArguments(sql string, args []interface{}) (newSql string, newArgs []i } } } else { - for i := 0; i < reflectValue.Len(); i++ { - newArgs = append(newArgs, reflectValue.Index(i).Interface()) + for i := 0; i < reflectInfo.OriginValue.Len(); i++ { + newArgs = append(newArgs, reflectInfo.OriginValue.Index(i).Interface()) } } // If the '?' holder count equals the length of the slice, // it does not implement the arguments splitting logic. // Eg: db.Query("SELECT ?+?", g.Slice{1, 2}) - if len(args) == 1 && gstr.Count(newSql, "?") == reflectValue.Len() { + if len(args) == 1 && gstr.Count(newSql, "?") == reflectInfo.OriginValue.Len() { break } // counter is used to finding the inserting position for the '?' holder. @@ -713,8 +759,8 @@ func handleArguments(sql string, args []interface{}) (newSql string, newArgs []i counter++ if counter == index+insertHolderCount+1 { replaced = true - insertHolderCount += reflectValue.Len() - 1 - return "?" + strings.Repeat(",?", reflectValue.Len()-1) + insertHolderCount += reflectInfo.OriginValue.Len() - 1 + return "?" + strings.Repeat(",?", reflectInfo.OriginValue.Len()-1) } return s }) @@ -781,17 +827,13 @@ func FormatSqlWithArgs(sql string, args []interface{}) string { return "null" } var ( - rv = reflect.ValueOf(args[index]) - kind = rv.Kind() + reflectInfo = utils.OriginValueAndKind(args[index]) ) - if kind == reflect.Ptr { - if rv.IsNil() || !rv.IsValid() { - return "null" - } - rv = rv.Elem() - kind = rv.Kind() + if reflectInfo.OriginKind == reflect.Ptr && + (reflectInfo.OriginValue.IsNil() || !reflectInfo.OriginValue.IsValid()) { + return "null" } - switch kind { + switch reflectInfo.OriginKind { case reflect.String, reflect.Map, reflect.Slice, reflect.Array: return `'` + gstr.QuoteMeta(gconv.String(args[index]), `'`) + `'` diff --git a/database/gdb/gdb_model_condition.go b/database/gdb/gdb_model_condition.go index 43b898cb8..5c9e7f36f 100644 --- a/database/gdb/gdb_model_condition.go +++ b/database/gdb/gdb_model_condition.go @@ -340,12 +340,13 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh case whereHolderOperatorWhere: if conditionWhere == "" { newWhere, newArgs := formatWhere(m.db, formatWhereInput{ - Where: v.Where, - Args: v.Args, - OmitNil: m.option&optionOmitNilWhere > 0, - OmitEmpty: m.option&optionOmitEmptyWhere > 0, - Schema: m.schema, - Table: m.tables, + Where: v.Where, + Args: v.Args, + OmitNil: m.option&optionOmitNilWhere > 0, + OmitEmpty: m.option&optionOmitEmptyWhere > 0, + IgnoreEmptySliceWhere: m.option&optionIgnoreEmptySliceWhere > 0, + Schema: m.schema, + Table: m.tables, }) if len(newWhere) > 0 { conditionWhere = newWhere @@ -357,12 +358,13 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh case whereHolderOperatorAnd: newWhere, newArgs := formatWhere(m.db, formatWhereInput{ - Where: v.Where, - Args: v.Args, - OmitNil: m.option&optionOmitNilWhere > 0, - OmitEmpty: m.option&optionOmitEmptyWhere > 0, - Schema: m.schema, - Table: m.tables, + Where: v.Where, + Args: v.Args, + OmitNil: m.option&optionOmitNilWhere > 0, + OmitEmpty: m.option&optionOmitEmptyWhere > 0, + IgnoreEmptySliceWhere: m.option&optionIgnoreEmptySliceWhere > 0, + Schema: m.schema, + Table: m.tables, }) if len(newWhere) > 0 { if len(conditionWhere) == 0 { @@ -377,12 +379,13 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh case whereHolderOperatorOr: newWhere, newArgs := formatWhere(m.db, formatWhereInput{ - Where: v.Where, - Args: v.Args, - OmitNil: m.option&optionOmitNilWhere > 0, - OmitEmpty: m.option&optionOmitEmptyWhere > 0, - Schema: m.schema, - Table: m.tables, + Where: v.Where, + Args: v.Args, + OmitNil: m.option&optionOmitNilWhere > 0, + OmitEmpty: m.option&optionOmitEmptyWhere > 0, + IgnoreEmptySliceWhere: m.option&optionIgnoreEmptySliceWhere > 0, + Schema: m.schema, + Table: m.tables, }) if len(newWhere) > 0 { if len(conditionWhere) == 0 { @@ -424,12 +427,13 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh // HAVING. if len(m.having) > 0 { havingStr, havingArgs := formatWhere(m.db, formatWhereInput{ - Where: m.having[0], - Args: gconv.Interfaces(m.having[1]), - OmitNil: m.option&optionOmitNilWhere > 0, - OmitEmpty: m.option&optionOmitEmptyWhere > 0, - Schema: m.schema, - Table: m.tables, + Where: m.having[0], + Args: gconv.Interfaces(m.having[1]), + OmitNil: m.option&optionOmitNilWhere > 0, + OmitEmpty: m.option&optionOmitEmptyWhere > 0, + IgnoreEmptySliceWhere: m.option&optionIgnoreEmptySliceWhere > 0, + Schema: m.schema, + Table: m.tables, }) if len(havingStr) > 0 { conditionExtra += " HAVING " + havingStr diff --git a/database/gdb/gdb_model_insert.go b/database/gdb/gdb_model_insert.go index b1b042738..3bff29e18 100644 --- a/database/gdb/gdb_model_insert.go +++ b/database/gdb/gdb_model_insert.go @@ -10,6 +10,7 @@ import ( "database/sql" "github.com/gogf/gf/v2/container/gset" "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/internal/utils" "reflect" "github.com/gogf/gf/v2/errors/gerror" @@ -69,18 +70,13 @@ func (m *Model) Data(data ...interface{}) *Model { default: var ( - rv = reflect.ValueOf(params) - kind = rv.Kind() + reflectInfo = utils.OriginValueAndKind(params) ) - if kind == reflect.Ptr { - rv = rv.Elem() - kind = rv.Kind() - } - switch kind { + switch reflectInfo.OriginKind { case reflect.Slice, reflect.Array: - list := make(List, rv.Len()) - for i := 0; i < rv.Len(); i++ { - list[i] = ConvertDataForTableRecord(rv.Index(i).Interface()) + list := make(List, reflectInfo.OriginValue.Len()) + for i := 0; i < reflectInfo.OriginValue.Len(); i++ { + list[i] = ConvertDataForTableRecord(reflectInfo.OriginValue.Index(i).Interface()) } model.data = list @@ -246,19 +242,14 @@ func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err err default: var ( - rv = reflect.ValueOf(newData) - kind = rv.Kind() + reflectInfo = utils.OriginValueAndKind(newData) ) - if kind == reflect.Ptr { - rv = rv.Elem() - kind = rv.Kind() - } - switch kind { + switch reflectInfo.OriginKind { // If it's slice type, it then converts it to List type. case reflect.Slice, reflect.Array: - list = make(List, rv.Len()) - for i := 0; i < rv.Len(); i++ { - list[i] = ConvertDataForTableRecord(rv.Index(i).Interface()) + list = make(List, reflectInfo.OriginValue.Len()) + for i := 0; i < reflectInfo.OriginValue.Len(); i++ { + list[i] = ConvertDataForTableRecord(reflectInfo.OriginValue.Index(i).Interface()) } case reflect.Map: @@ -278,7 +269,11 @@ func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err err } default: - return result, gerror.NewCodef(gcode.CodeInvalidParameter, "unsupported list type:%v", kind) + return result, gerror.NewCodef( + gcode.CodeInvalidParameter, + "unsupported data list type: %v", + reflectInfo.InputValue.Type(), + ) } } @@ -331,17 +326,12 @@ func (m *Model) formatDoInsertOption(insertOption int, columnNames []string) (op default: var ( - reflectValue = reflect.ValueOf(m.onDuplicate) - reflectKind = reflectValue.Kind() + reflectInfo = utils.OriginValueAndKind(m.onDuplicate) ) - for reflectKind == reflect.Ptr { - reflectValue = reflectValue.Elem() - reflectKind = reflectValue.Kind() - } - switch reflectKind { + switch reflectInfo.OriginKind { case reflect.String: option.OnDuplicateMap = make(map[string]interface{}) - for _, v := range gstr.SplitAndTrim(reflectValue.String(), ",") { + for _, v := range gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ",") { if onDuplicateExKeySet.Contains(v) { continue } @@ -393,16 +383,11 @@ func (m *Model) formatOnDuplicateExKeys(onDuplicateEx interface{}) ([]string, er } var ( - reflectValue = reflect.ValueOf(onDuplicateEx) - reflectKind = reflectValue.Kind() + reflectInfo = utils.OriginValueAndKind(onDuplicateEx) ) - for reflectKind == reflect.Ptr { - reflectValue = reflectValue.Elem() - reflectKind = reflectValue.Kind() - } - switch reflectKind { + switch reflectInfo.OriginKind { case reflect.String: - return gstr.SplitAndTrim(reflectValue.String(), ","), nil + return gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ","), nil case reflect.Map: return gutil.Keys(onDuplicateEx), nil diff --git a/database/gdb/gdb_model_option.go b/database/gdb/gdb_model_option.go index 61238c3f5..3f155038e 100644 --- a/database/gdb/gdb_model_option.go +++ b/database/gdb/gdb_model_option.go @@ -7,14 +7,27 @@ package gdb const ( - optionOmitNil = optionOmitNilWhere | optionOmitNilData - optionOmitEmpty = optionOmitEmptyWhere | optionOmitEmptyData - optionOmitEmptyWhere = 1 << iota // 8 - optionOmitEmptyData // 16 - optionOmitNilWhere // 32 - optionOmitNilData // 64 + optionOmitNil = optionOmitNilWhere | optionOmitNilData + optionOmitEmpty = optionOmitEmptyWhere | optionOmitEmptyData + optionOmitEmptyWhere = 1 << iota // 8 + optionOmitEmptyData // 16 + optionOmitNilWhere // 32 + optionOmitNilData // 64 + optionIgnoreEmptySliceWhere // 128 ) +// IgnoreEmptySliceWhere sets optionIgnoreEmptySliceWhere option for the model, which automatically filers +// the where parameters for `empty` slice values. +// +// Eg: +// Where("id", []int{}).All() -> SELECT xxx FROM xxx WHERE 0=1 +// OmitEmptyWhereSlice().Where("id", []int{}).One() -> SELECT xxx FROM xxx +func (m *Model) IgnoreEmptySliceWhere() *Model { + model := m.getModel() + model.option = model.option | optionIgnoreEmptySliceWhere + return model +} + // OmitEmpty sets optionOmitEmpty option for the model, which automatically filers // the data and where parameters for `empty` values. func (m *Model) OmitEmpty() *Model { diff --git a/database/gdb/gdb_model_select.go b/database/gdb/gdb_model_select.go index e1a4a18d8..6c1445794 100644 --- a/database/gdb/gdb_model_select.go +++ b/database/gdb/gdb_model_select.go @@ -10,6 +10,7 @@ import ( "fmt" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/internal/utils" "reflect" "github.com/gogf/gf/v2/container/gset" @@ -293,24 +294,15 @@ func (m *Model) doStructs(pointer interface{}, where ...interface{}) error { // err := db.Model("user").Scan(&users) func (m *Model) Scan(pointer interface{}, where ...interface{}) error { var ( - reflectValue reflect.Value - reflectKind reflect.Kind + reflectInfo = utils.OriginTypeAndKind(pointer) ) - if v, ok := pointer.(reflect.Value); ok { - reflectValue = v - } else { - reflectValue = reflect.ValueOf(pointer) + if reflectInfo.InputKind != reflect.Ptr { + return gerror.NewCode( + gcode.CodeInvalidParameter, + `the parameter "pointer" for function Scan should type of pointer`, + ) } - - reflectKind = reflectValue.Kind() - if reflectKind != reflect.Ptr { - return gerror.NewCode(gcode.CodeInvalidParameter, `the parameter "pointer" for function Scan should type of pointer`) - } - for reflectKind == reflect.Ptr { - reflectValue = reflectValue.Elem() - reflectKind = reflectValue.Kind() - } - switch reflectKind { + switch reflectInfo.OriginKind { case reflect.Slice, reflect.Array: return m.doStructs(pointer, where...) diff --git a/database/gdb/gdb_model_update.go b/database/gdb/gdb_model_update.go index 7df735488..976912e2d 100644 --- a/database/gdb/gdb_model_update.go +++ b/database/gdb/gdb_model_update.go @@ -10,6 +10,7 @@ import ( "database/sql" "fmt" "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/internal/utils" "reflect" "github.com/gogf/gf/v2/errors/gerror" @@ -49,14 +50,9 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro // Automatically update the record updating time. if !m.unscoped && fieldNameUpdate != "" { var ( - refValue = reflect.ValueOf(m.data) - refKind = refValue.Kind() + reflectInfo = utils.OriginTypeAndKind(m.data) ) - if refKind == reflect.Ptr { - refValue = refValue.Elem() - refKind = refValue.Kind() - } - switch refKind { + switch reflectInfo.OriginKind { case reflect.Map, reflect.Struct: dataMap := ConvertDataForTableRecord(m.data) if fieldNameUpdate != "" { diff --git a/database/gdb/gdb_type_result_scanlist.go b/database/gdb/gdb_type_result_scanlist.go index e481a4c28..6fb16baae 100644 --- a/database/gdb/gdb_type_result_scanlist.go +++ b/database/gdb/gdb_type_result_scanlist.go @@ -68,12 +68,20 @@ func doScanList(model *Model, result Result, listPointer interface{}, bindToAttr reflectKind = reflectValue.Kind() } if reflectKind != reflect.Ptr { - return gerror.NewCodef(gcode.CodeInvalidParameter, "listPointer should be type of *[]struct/*[]*struct, but got: %v", reflectKind) + return gerror.NewCodef( + gcode.CodeInvalidParameter, + "listPointer should be type of *[]struct/*[]*struct, but got: %v", + reflectKind, + ) } reflectValue = reflectValue.Elem() reflectKind = reflectValue.Kind() if reflectKind != reflect.Slice && reflectKind != reflect.Array { - return gerror.NewCodef(gcode.CodeInvalidParameter, "listPointer should be type of *[]struct/*[]*struct, but got: %v", reflectKind) + return gerror.NewCodef( + gcode.CodeInvalidParameter, + "listPointer should be type of *[]struct/*[]*struct, but got: %v", + reflectKind, + ) } length := len(result) if length == 0 { @@ -146,14 +154,21 @@ func doScanList(model *Model, result Result, listPointer interface{}, bindToAttr relationResultFieldName = key } } else { - return gerror.NewCode(gcode.CodeInvalidParameter, `parameter relationKV should be format of "ResultFieldName:BindToAttrName"`) + return gerror.NewCode( + gcode.CodeInvalidParameter, + `parameter relationKV should be format of "ResultFieldName:BindToAttrName"`, + ) } if relationResultFieldName != "" { // Note that the value might be type of slice. relationDataMap = result.MapKeyValue(relationResultFieldName) } if len(relationDataMap) == 0 { - return gerror.NewCodef(gcode.CodeInvalidParameter, `cannot find the relation data map, maybe invalid relation given "%v"`, relationKV) + return gerror.NewCodef( + gcode.CodeInvalidParameter, + `cannot find the relation data map, maybe invalid relation given "%v"`, + relationKV, + ) } } // Bind to target attribute. diff --git a/database/gdb/gdb_z_mysql_method_test.go b/database/gdb/gdb_z_mysql_method_test.go index c89cc3e05..ae8f7a383 100644 --- a/database/gdb/gdb_z_mysql_method_test.go +++ b/database/gdb/gdb_z_mysql_method_test.go @@ -662,6 +662,19 @@ func Test_DB_GetScan(t *testing.T) { t.AssertNil(err) t.Assert(user.NickName, "name_3") }) + gtest.C(t, func(t *gtest.T) { + type User struct { + Id int + Passport string + Password string + NickName string + CreateTime gtime.Time + } + var user *User + err := db.GetScan(ctx, &user, fmt.Sprintf("SELECT * FROM %s WHERE id=?", table), 3) + t.AssertNil(err) + t.Assert(user.NickName, "name_3") + }) gtest.C(t, func(t *gtest.T) { type User struct { Id int diff --git a/database/gdb/gdb_z_mysql_model_test.go b/database/gdb/gdb_z_mysql_model_test.go index 9b4b744b8..42d308aaf 100644 --- a/database/gdb/gdb_z_mysql_model_test.go +++ b/database/gdb/gdb_z_mysql_model_test.go @@ -3676,6 +3676,68 @@ func Test_Model_FieldAvg(t *testing.T) { }) } +func Test_Model_IgnoreEmptySliceWhere(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + // Key-Value where. + gtest.C(t, func(t *gtest.T) { + count, err := db.Model(table).Where("id", g.Slice{1, 2, 3}).Count() + t.AssertNil(err) + t.Assert(count, 3) + }) + gtest.C(t, func(t *gtest.T) { + count, err := db.Model(table).Where("id", g.Slice{}).Count() + t.AssertNil(err) + t.Assert(count, 0) + }) + gtest.C(t, func(t *gtest.T) { + count, err := db.Model(table).IgnoreEmptySliceWhere().Where("id", g.Slice{}).Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) + gtest.C(t, func(t *gtest.T) { + count, err := db.Model(table).Where("id", g.Slice{}).IgnoreEmptySliceWhere().Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) + // Struct Where. + gtest.C(t, func(t *gtest.T) { + type Input struct { + Id []int + } + count, err := db.Model(table).Where(Input{}).Count() + t.AssertNil(err) + t.Assert(count, 0) + }) + gtest.C(t, func(t *gtest.T) { + type Input struct { + Id []int + } + count, err := db.Model(table).Where(Input{}).IgnoreEmptySliceWhere().Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) + // Map Where. + gtest.C(t, func(t *gtest.T) { + count, err := db.Model(table).Where(g.Map{ + "id": []int{}, + }).Count() + t.AssertNil(err) + t.Assert(count, 0) + }) + gtest.C(t, func(t *gtest.T) { + type Input struct { + Id []int + } + count, err := db.Model(table).Where(g.Map{ + "id": []int{}, + }).IgnoreEmptySliceWhere().Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) +} + // https://github.com/gogf/gf/issues/1387 func Test_Model_GTime_DefaultValue(t *testing.T) { table := createTable() diff --git a/database/gredis/gredis_redis_conn.go b/database/gredis/gredis_redis_conn.go index 7649c7d8a..9a330fba9 100644 --- a/database/gredis/gredis_redis_conn.go +++ b/database/gredis/gredis_redis_conn.go @@ -10,6 +10,7 @@ import ( "context" "github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/internal/json" + "github.com/gogf/gf/v2/internal/utils" "github.com/gogf/gf/v2/os/gtime" "reflect" ) @@ -23,18 +24,11 @@ type RedisConn struct { // Do sends a command to the server and returns the received reply. // It uses json.Marshal for struct/slice/map type values before committing them to redis. func (c *RedisConn) Do(ctx context.Context, command string, args ...interface{}) (reply *gvar.Var, err error) { - var ( - reflectValue reflect.Value - reflectKind reflect.Kind - ) for k, v := range args { - reflectValue = reflect.ValueOf(v) - reflectKind = reflectValue.Kind() - if reflectKind == reflect.Ptr { - reflectValue = reflectValue.Elem() - reflectKind = reflectValue.Kind() - } - switch reflectKind { + var ( + reflectInfo = utils.OriginTypeAndKind(v) + ) + switch reflectInfo.OriginKind { case reflect.Struct, reflect.Map, diff --git a/encoding/gjson/gjson.go b/encoding/gjson/gjson.go index 2461ca5ff..85cf14b56 100644 --- a/encoding/gjson/gjson.go +++ b/encoding/gjson/gjson.go @@ -259,13 +259,10 @@ func (j *Json) convertValue(value interface{}) interface{} { case []interface{}: return value default: - rv := reflect.ValueOf(value) - kind := rv.Kind() - if kind == reflect.Ptr { - rv = rv.Elem() - kind = rv.Kind() - } - switch kind { + var ( + reflectInfo = utils.OriginTypeAndKind(value) + ) + switch reflectInfo.OriginKind { case reflect.Array: return gconv.Interfaces(value) case reflect.Slice: diff --git a/encoding/gjson/gjson_api_new_load.go b/encoding/gjson/gjson_api_new_load.go index ddaac9097..a09e17ba5 100644 --- a/encoding/gjson/gjson_api_new_load.go +++ b/encoding/gjson/gjson_api_new_load.go @@ -10,6 +10,7 @@ import ( "bytes" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/internal/utils" "reflect" "github.com/gogf/gf/v2/internal/json" @@ -68,14 +69,9 @@ func NewWithOptions(data interface{}, options Options) *Json { } default: var ( - rv = reflect.ValueOf(data) - kind = rv.Kind() + reflectInfo = utils.OriginTypeAndKind(data) ) - if kind == reflect.Ptr { - rv = rv.Elem() - kind = rv.Kind() - } - switch kind { + switch reflectInfo.OriginKind { case reflect.Slice, reflect.Array: i := interface{}(nil) i = gconv.Interfaces(data) diff --git a/internal/utils/utils_reflect.go b/internal/utils/utils_reflect.go new file mode 100644 index 000000000..f8c4df94e --- /dev/null +++ b/internal/utils/utils_reflect.go @@ -0,0 +1,61 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package utils + +import "reflect" + +type OriginValueAndKindOutput struct { + InputValue reflect.Value + InputKind reflect.Kind + OriginValue reflect.Value + OriginKind reflect.Kind +} + +// OriginValueAndKind retrieves and returns the original reflect value and kind. +func OriginValueAndKind(value interface{}) (out OriginValueAndKindOutput) { + if v, ok := value.(reflect.Value); ok { + out.InputValue = v + } else { + out.InputValue = reflect.ValueOf(value) + } + out.InputKind = out.InputValue.Kind() + out.OriginValue = out.InputValue + out.OriginKind = out.InputKind + for out.OriginKind == reflect.Ptr { + out.OriginValue = out.OriginValue.Elem() + out.OriginKind = out.OriginValue.Kind() + } + return +} + +type OriginTypeAndKindOutput struct { + InputType reflect.Type + InputKind reflect.Kind + OriginType reflect.Type + OriginKind reflect.Kind +} + +// OriginTypeAndKind retrieves and returns the original reflect type and kind. +func OriginTypeAndKind(value interface{}) (out OriginTypeAndKindOutput) { + if reflectType, ok := value.(reflect.Type); ok { + out.InputType = reflectType + } else { + if reflectValue, ok := value.(reflect.Value); ok { + out.InputType = reflectValue.Type() + } else { + out.InputType = reflect.TypeOf(value) + } + } + out.InputKind = out.InputType.Kind() + out.OriginType = out.InputType + out.OriginKind = out.InputKind + for out.OriginKind == reflect.Ptr { + out.OriginType = out.OriginType.Elem() + out.OriginKind = out.OriginType.Kind() + } + return +} diff --git a/internal/utils/utils_z_test.go b/internal/utils/utils_z_test.go index 2b7242e77..27e1892a8 100644 --- a/internal/utils/utils_z_test.go +++ b/internal/utils/utils_z_test.go @@ -10,6 +10,7 @@ import ( "github.com/gogf/gf/v2/internal/utils" "github.com/gogf/gf/v2/test/gtest" "io/ioutil" + "reflect" "testing" ) @@ -69,3 +70,57 @@ func Test_RemoveSymbols(t *testing.T) { t.Assert(utils.RemoveSymbols(`-a-b._a c1!@#$%^&*()_+:";'.,'01`), `abac101`) }) } + +func Test_OriginValueAndKind(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var s = "s" + out := utils.OriginValueAndKind(s) + t.Assert(out.InputKind, reflect.String) + t.Assert(out.OriginKind, reflect.String) + }) + gtest.C(t, func(t *gtest.T) { + var s = "s" + out := utils.OriginValueAndKind(&s) + t.Assert(out.InputKind, reflect.Ptr) + t.Assert(out.OriginKind, reflect.String) + }) + gtest.C(t, func(t *gtest.T) { + var s []int + out := utils.OriginValueAndKind(s) + t.Assert(out.InputKind, reflect.Slice) + t.Assert(out.OriginKind, reflect.Slice) + }) + gtest.C(t, func(t *gtest.T) { + var s []int + out := utils.OriginValueAndKind(&s) + t.Assert(out.InputKind, reflect.Ptr) + t.Assert(out.OriginKind, reflect.Slice) + }) +} + +func Test_OriginTypeAndKind(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var s = "s" + out := utils.OriginTypeAndKind(s) + t.Assert(out.InputKind, reflect.String) + t.Assert(out.OriginKind, reflect.String) + }) + gtest.C(t, func(t *gtest.T) { + var s = "s" + out := utils.OriginTypeAndKind(&s) + t.Assert(out.InputKind, reflect.Ptr) + t.Assert(out.OriginKind, reflect.String) + }) + gtest.C(t, func(t *gtest.T) { + var s []int + out := utils.OriginTypeAndKind(s) + t.Assert(out.InputKind, reflect.Slice) + t.Assert(out.OriginKind, reflect.Slice) + }) + gtest.C(t, func(t *gtest.T) { + var s []int + out := utils.OriginTypeAndKind(&s) + t.Assert(out.InputKind, reflect.Ptr) + t.Assert(out.OriginKind, reflect.Slice) + }) +} diff --git a/net/ghttp/ghttp_request_param.go b/net/ghttp/ghttp_request_param.go index f985c207e..075faa908 100644 --- a/net/ghttp/ghttp_request_param.go +++ b/net/ghttp/ghttp_request_param.go @@ -72,7 +72,8 @@ func (r *Request) doParse(pointer interface{}, requestType int) error { if reflectKind1 != reflect.Ptr { return gerror.NewCodef( gcode.CodeInvalidParameter, - "parameter should be type of *struct/**struct/*[]struct/*[]*struct, but got: %v", + `invalid parameter type "%v", of which kind should be of *struct/**struct/*[]struct/*[]*struct, but got: "%v"`, + reflectVal1.Type(), reflectKind1, ) }