diff --git a/database/gdb/gdb_model_fields.go b/database/gdb/gdb_model_fields.go index 28c317ba3..05b97429d 100644 --- a/database/gdb/gdb_model_fields.go +++ b/database/gdb/gdb_model_fields.go @@ -10,6 +10,8 @@ import ( "fmt" "github.com/gogf/gf/container/gset" "github.com/gogf/gf/text/gstr" + "github.com/gogf/gf/util/gconv" + "github.com/gogf/gf/util/gutil" ) // Filter marks filtering the fields which does not exist in the fields of the operated table. @@ -24,10 +26,25 @@ func (m *Model) Filter() *Model { } // Fields sets the operation fields of the model, multiple fields joined using char ','. -func (m *Model) Fields(fields ...string) *Model { - if len(fields) > 0 { +// The parameter can be type of string/map/*map/struct/*struct. +func (m *Model) Fields(fieldNamesOrMapStruct ...interface{}) *Model { + length := len(fieldNamesOrMapStruct) + if length == 0 { + return m + } + switch { + case length >= 2: model := m.getModel() - model.fields = gstr.Join(m.mappingToTableFields(fields), ",") + model.fields = gstr.Join(m.mappingToTableFields(gconv.Strings(fieldNamesOrMapStruct)), ",") + return model + case length == 1: + model := m.getModel() + switch r := fieldNamesOrMapStruct[0].(type) { + case string: + model.fields = gstr.Join(m.mappingToTableFields([]string{r}), ",") + default: + model.fields = gstr.Join(m.mappingToTableFields(gutil.Keys(r)), ",") + } return model } return m @@ -35,10 +52,24 @@ func (m *Model) Fields(fields ...string) *Model { // FieldsEx sets the excluded operation fields of the model, multiple fields joined using char ','. // Note that this function supports only single table operations. -func (m *Model) FieldsEx(fields ...string) *Model { - if len(fields) > 0 { - model := m.getModel() - model.fieldsEx = gstr.Join(m.mappingToTableFields(fields), ",") +// The parameter can be type of string/map/*map/struct/*struct. +func (m *Model) FieldsEx(fieldNamesOrMapStruct ...interface{}) *Model { + length := len(fieldNamesOrMapStruct) + if length == 0 { + return m + } + model := m.getModel() + switch { + case length >= 2: + model.fieldsEx = gstr.Join(m.mappingToTableFields(gconv.Strings(fieldNamesOrMapStruct)), ",") + return model + case length == 1: + switch r := fieldNamesOrMapStruct[0].(type) { + case string: + model.fieldsEx = gstr.Join(m.mappingToTableFields([]string{r}), ",") + default: + model.fieldsEx = gstr.Join(m.mappingToTableFields(gutil.Keys(r)), ",") + } return model } return m diff --git a/database/gdb/gdb_z_mysql_model_test.go b/database/gdb/gdb_z_mysql_model_test.go index 874c47f1d..0c06a6535 100644 --- a/database/gdb/gdb_z_mysql_model_test.go +++ b/database/gdb/gdb_z_mysql_model_test.go @@ -2616,6 +2616,84 @@ func Test_Model_Fields_AutoMapping(t *testing.T) { t.Assert(err, nil) t.Assert(value.String(), "name_2") }) + // Map + gtest.C(t, func(t *gtest.T) { + one, err := db.Table(table).Fields(g.Map{ + "ID": 1, + "NICK_NAME": 1, + }).Where("id", 2).One() + t.Assert(err, nil) + t.Assert(len(one), 2) + t.Assert(one["id"], 2) + t.Assert(one["nickname"], "name_2") + }) + // Struct + gtest.C(t, func(t *gtest.T) { + type T struct { + ID int + NICKNAME int + } + one, err := db.Table(table).Fields(&T{ + ID: 0, + NICKNAME: 0, + }).Where("id", 2).One() + t.Assert(err, nil) + t.Assert(len(one), 2) + t.Assert(one["id"], 2) + t.Assert(one["nickname"], "name_2") + }) +} + +func Test_Model_FieldsEx_AutoMapping(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + // "id": i, + // "passport": fmt.Sprintf(`user_%d`, i), + // "password": fmt.Sprintf(`pass_%d`, i), + // "nickname": fmt.Sprintf(`name_%d`, i), + // "create_time": gtime.NewFromStr("2018-10-24 10:00:00").String(), + + gtest.C(t, func(t *gtest.T) { + value, err := db.Table(table).FieldsEx("Passport, Password, NickName, CreateTime").Where("id", 2).Value() + t.Assert(err, nil) + t.Assert(value.Int(), 2) + }) + + gtest.C(t, func(t *gtest.T) { + value, err := db.Table(table).FieldsEx("ID, Passport, Password, CreateTime").Where("id", 2).Value() + t.Assert(err, nil) + t.Assert(value.String(), "name_2") + }) + // Map + gtest.C(t, func(t *gtest.T) { + one, err := db.Table(table).FieldsEx(g.Map{ + "Passport": 1, + "Password": 1, + "CreateTime": 1, + }).Where("id", 2).One() + t.Assert(err, nil) + t.Assert(len(one), 2) + t.Assert(one["id"], 2) + t.Assert(one["nickname"], "name_2") + }) + // Struct + gtest.C(t, func(t *gtest.T) { + type T struct { + Passport int + Password int + CreateTime int + } + one, err := db.Table(table).FieldsEx(&T{ + Passport: 0, + Password: 0, + CreateTime: 0, + }).Where("id", 2).One() + t.Assert(err, nil) + t.Assert(len(one), 2) + t.Assert(one["id"], 2) + t.Assert(one["nickname"], "name_2") + }) } func Test_Model_NullField(t *testing.T) {