diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index 7c7e096da..0a8a28cab 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -489,11 +489,20 @@ func formatSql(sql string, args []interface{}) (newSql string, newArgs []interfa return handleArguments(sql, args) } +type formatWhereInput struct { + Where interface{} + Args []interface{} + OmitNil bool + OmitEmpty bool + Schema string + Table string +} + // formatWhere formats where statement and its arguments for `Where` and `Having` statements. -func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool, schema, table string) (newWhere string, newArgs []interface{}) { +func formatWhere(db DB, in formatWhereInput) (newWhere string, newArgs []interface{}) { var ( buffer = bytes.NewBuffer(nil) - rv = reflect.ValueOf(where) + rv = reflect.ValueOf(in.Where) kind = rv.Kind() ) for kind == reflect.Ptr { @@ -502,12 +511,17 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool, s } switch kind { case reflect.Array, reflect.Slice: - newArgs = formatWhereInterfaces(db, gconv.Interfaces(where), buffer, newArgs) + newArgs = formatWhereInterfaces(db, gconv.Interfaces(in.Where), buffer, newArgs) case reflect.Map: - for key, value := range DataToMapDeep(where) { - if gregex.IsMatchString(regularFieldNameRegPattern, key) && omitEmpty && empty.IsEmpty(value) { - continue + for key, value := range DataToMapDeep(in.Where) { + if gregex.IsMatchString(regularFieldNameRegPattern, key) { + if in.OmitNil && empty.IsNil(value) { + continue + } + if in.OmitEmpty && empty.IsEmpty(value) { + continue + } } newArgs = formatWhereKeyValue(db, buffer, newArgs, key, value) } @@ -517,11 +531,16 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool, s // it then uses its Iterate function to iterates its key-value pairs. // For example, ListMap and TreeMap are ordered map, // which implement apiIterator interface and are index-friendly for where conditions. - if iterator, ok := where.(apiIterator); ok { + if iterator, ok := in.Where.(apiIterator); ok { iterator.Iterator(func(key, value interface{}) bool { ketStr := gconv.String(key) - if gregex.IsMatchString(regularFieldNameRegPattern, ketStr) && omitEmpty && empty.IsEmpty(value) { - return true + if gregex.IsMatchString(regularFieldNameRegPattern, ketStr) { + if in.OmitNil && empty.IsNil(value) { + return true + } + if in.OmitEmpty && empty.IsEmpty(value) { + return true + } } newArgs = formatWhereKeyValue(db, buffer, newArgs, ketStr, value) return true @@ -529,12 +548,15 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool, s break } // Automatically mapping and filtering the struct attribute. - data := DataToMapDeep(where) - if table != "" { - data, _ = db.GetCore().mappingAndFilterData(schema, table, data, true) + data := DataToMapDeep(in.Where) + if in.Table != "" { + data, _ = db.GetCore().mappingAndFilterData(in.Schema, in.Table, data, true) } for key, value := range data { - if omitEmpty && empty.IsEmpty(value) { + if in.OmitNil && empty.IsNil(value) { + continue + } + if in.OmitEmpty && empty.IsEmpty(value) { continue } newArgs = formatWhereKeyValue(db, buffer, newArgs, key, value) @@ -544,14 +566,14 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool, s // Usually a string. var ( i = 0 - whereStr = gconv.String(where) + whereStr = gconv.String(in.Where) ) for { - if i >= len(args) { + if i >= len(in.Args) { break } // Sub query, which is always used along with a string condition. - if model, ok := args[i].(*Model); ok { + if model, ok := in.Args[i].(*Model); ok { var ( index = -1 ) @@ -565,7 +587,7 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool, s } return s }) - args = gutil.SliceDelete(args, i) + in.Args = gutil.SliceDelete(in.Args, i) continue } i++ @@ -574,9 +596,9 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool, s } if buffer.Len() == 0 { - return "", args + return "", in.Args } - newArgs = append(newArgs, args...) + newArgs = append(newArgs, in.Args...) newWhere = buffer.String() if len(newArgs) > 0 { if gstr.Pos(newWhere, "?") == -1 { diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index fe761dca1..e7d4f823f 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -104,9 +104,14 @@ func (c *Core) Model(tableNameQueryOrStruct ...interface{}) *Model { if len(tableNameQueryOrStruct) > 1 { conditionStr := gconv.String(tableNameQueryOrStruct[0]) if gstr.Contains(conditionStr, "?") { - tableStr, extraArgs = formatWhere( - c.db, conditionStr, tableNameQueryOrStruct[1:], false, "", "", - ) + tableStr, extraArgs = formatWhere(c.db, formatWhereInput{ + Where: conditionStr, + Args: tableNameQueryOrStruct[1:], + OmitNil: false, + OmitEmpty: false, + Schema: "", + Table: "", + }) } } // Normal model creation. diff --git a/database/gdb/gdb_model_condition.go b/database/gdb/gdb_model_condition.go index 5853fe7df..2ff16eb4c 100644 --- a/database/gdb/gdb_model_condition.go +++ b/database/gdb/gdb_model_condition.go @@ -388,9 +388,14 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh switch v.Operator { case whereHolderOperatorWhere: if conditionWhere == "" { - newWhere, newArgs := formatWhere( - m.db, v.Where, v.Args, m.option&optionOmitEmptyWhere > 0, m.schema, m.tables, - ) + newWhere, newArgs := formatWhere(m.db, formatWhereInput{ + Where: v.Where, + Args: v.Args, + OmitNil: m.option&optionOmitNilWhere > 0, + OmitEmpty: m.option&optionOmitEmptyWhere > 0, + Schema: m.schema, + Table: m.tables, + }) if len(newWhere) > 0 { conditionWhere = newWhere conditionArgs = newArgs @@ -400,9 +405,14 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh fallthrough case whereHolderOperatorAnd: - newWhere, newArgs := formatWhere( - m.db, v.Where, v.Args, m.option&optionOmitEmptyWhere > 0, m.schema, m.tables, - ) + newWhere, newArgs := formatWhere(m.db, formatWhereInput{ + Where: v.Where, + Args: v.Args, + OmitNil: m.option&optionOmitNilWhere > 0, + OmitEmpty: m.option&optionOmitEmptyWhere > 0, + Schema: m.schema, + Table: m.tables, + }) if len(newWhere) > 0 { if len(conditionWhere) == 0 { conditionWhere = newWhere @@ -415,9 +425,14 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh } case whereHolderOperatorOr: - newWhere, newArgs := formatWhere( - m.db, v.Where, v.Args, m.option&optionOmitEmptyWhere > 0, m.schema, m.tables, - ) + newWhere, newArgs := formatWhere(m.db, formatWhereInput{ + Where: v.Where, + Args: v.Args, + OmitNil: m.option&optionOmitNilWhere > 0, + OmitEmpty: m.option&optionOmitEmptyWhere > 0, + Schema: m.schema, + Table: m.tables, + }) if len(newWhere) > 0 { if len(conditionWhere) == 0 { conditionWhere = newWhere @@ -457,9 +472,14 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh } // HAVING. if len(m.having) > 0 { - havingStr, havingArgs := formatWhere( - m.db, m.having[0], gconv.Interfaces(m.having[1]), m.option&optionOmitEmptyWhere > 0, m.schema, m.tables, - ) + havingStr, havingArgs := formatWhere(m.db, formatWhereInput{ + Where: m.having[0], + Args: gconv.Interfaces(m.having[1]), + OmitNil: m.option&optionOmitNilWhere > 0, + OmitEmpty: m.option&optionOmitEmptyWhere > 0, + Schema: m.schema, + Table: m.tables, + }) if len(havingStr) > 0 { conditionExtra += " HAVING " + havingStr conditionArgs = append(conditionArgs, havingArgs...) diff --git a/database/gdb/gdb_model_option.go b/database/gdb/gdb_model_option.go index 6b3de8813..68991fce4 100644 --- a/database/gdb/gdb_model_option.go +++ b/database/gdb/gdb_model_option.go @@ -7,9 +7,12 @@ package gdb const ( + optionOmitNil = optionOmitNilWhere | optionOmitNilData optionOmitEmpty = optionOmitEmptyWhere | optionOmitEmptyData optionOmitEmptyWhere = 1 << iota // 8 optionOmitEmptyData // 16 + optionOmitNilWhere // 32 + optionOmitNilData // 64 ) // Option adds extra operation option for the model. @@ -20,7 +23,7 @@ func (m *Model) Option(option int) *Model { return model } -// OmitEmpty sets OmitEmpty option for the model, which automatically filers +// OmitEmpty sets optionOmitEmpty option for the model, which automatically filers // the data and where parameters for `empty` values. func (m *Model) OmitEmpty() *Model { model := m.getModel() @@ -28,7 +31,7 @@ func (m *Model) OmitEmpty() *Model { return model } -// OmitEmptyWhere sets OmitEmptyWhere option for the model, which automatically filers +// OmitEmptyWhere sets optionOmitEmptyWhere option for the model, which automatically filers // the Where/Having parameters for `empty` values. func (m *Model) OmitEmptyWhere() *Model { model := m.getModel() @@ -36,10 +39,34 @@ func (m *Model) OmitEmptyWhere() *Model { return model } -// OmitEmptyData sets OmitEmptyData option for the model, which automatically filers +// OmitEmptyData sets optionOmitEmptyData option for the model, which automatically filers // the Data parameters for `empty` values. func (m *Model) OmitEmptyData() *Model { model := m.getModel() model.option = model.option | optionOmitEmptyData return model } + +// OmitNil sets optionOmitNil option for the model, which automatically filers +// the data and where parameters for `nil` values. +func (m *Model) OmitNil() *Model { + model := m.getModel() + model.option = model.option | optionOmitNil + return model +} + +// OmitNilWhere sets optionOmitNilWhere option for the model, which automatically filers +// the Where/Having parameters for `nil` values. +func (m *Model) OmitNilWhere() *Model { + model := m.getModel() + model.option = model.option | optionOmitNilWhere + return model +} + +// OmitNilData sets optionOmitNilData option for the model, which automatically filers +// the Data parameters for `nil` values. +func (m *Model) OmitNilData() *Model { + model := m.getModel() + model.option = model.option | optionOmitNilData + return model +} diff --git a/database/gdb/gdb_model_utility.go b/database/gdb/gdb_model_utility.go index a1a12b75e..be9acfd41 100644 --- a/database/gdb/gdb_model_utility.go +++ b/database/gdb/gdb_model_utility.go @@ -109,6 +109,18 @@ func (m *Model) doMappingAndFilterForInsertOrUpdateDataMap(data Map, allowOmitEm if err != nil { return nil, err } + // Remove key-value pairs of which the value is nil. + if allowOmitEmpty && m.option&optionOmitNilData > 0 { + tempMap := make(Map, len(data)) + for k, v := range data { + if empty.IsNil(v) { + continue + } + tempMap[k] = v + } + data = tempMap + } + // Remove key-value pairs of which the value is empty. if allowOmitEmpty && m.option&optionOmitEmptyData > 0 { tempMap := make(Map, len(data)) diff --git a/database/gdb/gdb_z_mysql_association_with_test.go b/database/gdb/gdb_z_mysql_association_with_test.go index d02316e41..d1b98727e 100644 --- a/database/gdb/gdb_z_mysql_association_with_test.go +++ b/database/gdb/gdb_z_mysql_association_with_test.go @@ -368,6 +368,7 @@ PRIMARY KEY (id) gtest.Assert(err, nil) } } + gtest.C(t, func(t *gtest.T) { var users []*User err := db.With(User{}). diff --git a/database/gdb/gdb_z_mysql_model_test.go b/database/gdb/gdb_z_mysql_model_test.go index 806cc80bc..ee079006a 100644 --- a/database/gdb/gdb_z_mysql_model_test.go +++ b/database/gdb/gdb_z_mysql_model_test.go @@ -2132,24 +2132,75 @@ func Test_Model_Option_List(t *testing.T) { } func Test_Model_OmitEmpty(t *testing.T) { - gtest.C(t, func(t *gtest.T) { - table := fmt.Sprintf(`table_%s`, gtime.TimestampNanoStr()) - if _, err := db.Exec(fmt.Sprintf(` + table := fmt.Sprintf(`table_%s`, gtime.TimestampNanoStr()) + if _, err := db.Exec(fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id int(10) unsigned NOT NULL AUTO_INCREMENT, name varchar(45) NOT NULL, PRIMARY KEY (id) ) ENGINE=InnoDB DEFAULT CHARSET=utf8; `, table)); err != nil { - gtest.Error(err) - } - defer dropTable(table) + gtest.Error(err) + } + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { _, err := db.Model(table).OmitEmpty().Data(g.Map{ "id": 1, "name": "", }).Save() t.AssertNE(err, nil) }) + gtest.C(t, func(t *gtest.T) { + _, err := db.Model(table).OmitEmptyData().Data(g.Map{ + "id": 1, + "name": "", + }).Save() + t.AssertNE(err, nil) + }) + gtest.C(t, func(t *gtest.T) { + _, err := db.Model(table).OmitEmptyWhere().Data(g.Map{ + "id": 1, + "name": "", + }).Save() + t.Assert(err, nil) + }) +} + +func Test_Model_OmitNil(t *testing.T) { + table := fmt.Sprintf(`table_%s`, gtime.TimestampNanoStr()) + if _, err := db.Exec(fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id int(10) unsigned NOT NULL AUTO_INCREMENT, + name varchar(45) NOT NULL, + PRIMARY KEY (id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, table)); err != nil { + gtest.Error(err) + } + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + _, err := db.Model(table).OmitNil().Data(g.Map{ + "id": 1, + "name": nil, + }).Save() + t.AssertNE(err, nil) + }) + gtest.C(t, func(t *gtest.T) { + _, err := db.Model(table).OmitNil().Data(g.Map{ + "id": 1, + "name": "", + }).Save() + t.Assert(err, nil) + }) + gtest.C(t, func(t *gtest.T) { + _, err := db.Model(table).OmitNilWhere().Data(g.Map{ + "id": 1, + "name": "", + }).Save() + t.Assert(err, nil) + }) } func Test_Model_Option_Where(t *testing.T) {