From 2e38416e12a29a527e1d7318bd4b88f64ba6344e Mon Sep 17 00:00:00 2001 From: John Guo Date: Tue, 11 May 2021 20:00:50 +0800 Subject: [PATCH] improve struct embedded association case of with feature for package gdb --- database/gdb/gdb_model_with.go | 10 +- .../gdb/gdb_z_mysql_association_with_test.go | 121 +++++++++++++++++- internal/structs/structs_field.go | 9 +- internal/structs/structs_z_unit_test.go | 4 +- util/gvalid/gvalid_validator_check_struct.go | 2 +- 5 files changed, 134 insertions(+), 12 deletions(-) diff --git a/database/gdb/gdb_model_with.go b/database/gdb/gdb_model_with.go index 2810d9a25..892474228 100644 --- a/database/gdb/gdb_model_with.go +++ b/database/gdb/gdb_model_with.go @@ -55,7 +55,7 @@ func (m *Model) WithAll() *Model { // getWithTagObjectArrayFrom retrieves and returns object array that have "with" tag in the struct. func (m *Model) getWithTagObjectArrayFrom(pointer interface{}) ([]interface{}, error) { - fieldMap, err := structs.FieldMap(pointer, nil) + fieldMap, err := structs.FieldMap(pointer, nil, false) if err != nil { return nil, err } @@ -86,6 +86,7 @@ func (m *Model) doWithScanStruct(pointer interface{}) error { err error withArray = m.withArray ) + // If with all feature is enabled, it then retrieves all the attributes which have with tag defined. if m.withAll { withArray, err = m.getWithTagObjectArrayFrom(pointer) if err != nil { @@ -95,10 +96,11 @@ func (m *Model) doWithScanStruct(pointer interface{}) error { if len(withArray) == 0 { return nil } - fieldMap, err := structs.FieldMap(pointer, nil) + fieldMap, err := structs.FieldMap(pointer, nil, false) if err != nil { return err } + // Check the with array and automatically call the ScanList to complete association querying. for withIndex, withItem := range withArray { withItemReflectValueType, err := structs.StructType(withItem) if err != nil { @@ -110,6 +112,7 @@ func (m *Model) doWithScanStruct(pointer interface{}) error { fieldType = fieldValue.Type() fieldTypeStr = gstr.TrimAll(fieldType.String(), "*[]") ) + // It does select operation if the field type is in the specified with type array. if gstr.Compare(fieldTypeStr, withItemReflectValueTypeStr) == 0 { var ( withTag string @@ -174,6 +177,7 @@ func (m *Model) doWithScanStruct(pointer interface{}) error { } // doWithScanStructs handles model association operations feature for struct slice. +// Also see doWithScanStruct. func (m *Model) doWithScanStructs(pointer interface{}) error { var ( err error @@ -188,7 +192,7 @@ func (m *Model) doWithScanStructs(pointer interface{}) error { if len(withArray) == 0 { return nil } - fieldMap, err := structs.FieldMap(pointer, nil) + fieldMap, err := structs.FieldMap(pointer, nil, false) if err != nil { return err } diff --git a/database/gdb/gdb_z_mysql_association_with_test.go b/database/gdb/gdb_z_mysql_association_with_test.go index d8a34cae9..7cf1bfd92 100644 --- a/database/gdb/gdb_z_mysql_association_with_test.go +++ b/database/gdb/gdb_z_mysql_association_with_test.go @@ -212,7 +212,7 @@ PRIMARY KEY (id) }) } -func Test_Table_Relation_With_ScanList(t *testing.T) { +func Test_Table_Relation_With(t *testing.T) { var ( tableUser = "user" tableUserDetail = "user_detail" @@ -411,7 +411,7 @@ PRIMARY KEY (id) }) } -func Test_Table_Relation_WithAll_Scan(t *testing.T) { +func Test_Table_Relation_WithAll(t *testing.T) { var ( tableUser = "user" tableUserDetail = "user_detail" @@ -526,7 +526,7 @@ PRIMARY KEY (id) }) } -func Test_Table_Relation_WithAll_ScanList(t *testing.T) { +func Test_Table_Relation_WithAll_List(t *testing.T) { var ( tableUser = "user" tableUserDetail = "user_detail" @@ -666,3 +666,118 @@ PRIMARY KEY (id) t.Assert(users[1].UserScores[4].Score, 5) }) } + +func Test_Table_Relation_WithAll_Embedded(t *testing.T) { + var ( + tableUser = "user" + tableUserDetail = "user_detail" + tableUserScores = "user_scores" + ) + if _, err := db.Exec(fmt.Sprintf(` +CREATE TABLE IF NOT EXISTS %s ( +id int(10) unsigned NOT NULL AUTO_INCREMENT, +name varchar(45) NOT NULL, +PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, tableUser)); err != nil { + gtest.Error(err) + } + defer dropTable(tableUser) + + if _, err := db.Exec(fmt.Sprintf(` +CREATE TABLE IF NOT EXISTS %s ( +uid int(10) unsigned NOT NULL AUTO_INCREMENT, +address varchar(45) NOT NULL, +PRIMARY KEY (uid) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, tableUserDetail)); err != nil { + gtest.Error(err) + } + defer dropTable(tableUserDetail) + + if _, err := db.Exec(fmt.Sprintf(` +CREATE TABLE IF NOT EXISTS %s ( +id int(10) unsigned NOT NULL AUTO_INCREMENT, +uid int(10) unsigned NOT NULL, +score int(10) unsigned NOT NULL, +PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, tableUserScores)); err != nil { + gtest.Error(err) + } + defer dropTable(tableUserScores) + + type UserDetail struct { + gmeta.Meta `orm:"table:user_detail"` + Uid int `json:"uid"` + Address string `json:"address"` + } + + type UserScores struct { + gmeta.Meta `orm:"table:user_scores"` + Id int `json:"id"` + Uid int `json:"uid"` + Score int `json:"score"` + } + + type User struct { + gmeta.Meta `orm:"table:user"` + *UserDetail `orm:"with:uid=id"` + Id int `json:"id"` + Name string `json:"name"` + UserScores []*UserScores `orm:"with:uid=id"` + } + + // Initialize the data. + var err error + for i := 1; i <= 5; i++ { + // User. + _, err = db.Insert(tableUser, g.Map{ + "id": i, + "name": fmt.Sprintf(`name_%d`, i), + }) + gtest.Assert(err, nil) + // Detail. + _, err = db.Insert(tableUserDetail, g.Map{ + "uid": i, + "address": fmt.Sprintf(`address_%d`, i), + }) + gtest.Assert(err, nil) + // Scores. + for j := 1; j <= 5; j++ { + _, err = db.Insert(tableUserScores, g.Map{ + "uid": i, + "score": j, + }) + gtest.Assert(err, nil) + } + } + gtest.C(t, func(t *gtest.T) { + var user *User + err := db.Model(tableUser).WithAll().Where("id", 3).Scan(&user) + t.AssertNil(err) + t.Assert(user.Id, 3) + t.AssertNE(user.UserDetail, nil) + t.Assert(user.UserDetail.Uid, 3) + t.Assert(user.UserDetail.Address, `address_3`) + t.Assert(len(user.UserScores), 5) + t.Assert(user.UserScores[0].Uid, 3) + t.Assert(user.UserScores[0].Score, 1) + t.Assert(user.UserScores[4].Uid, 3) + t.Assert(user.UserScores[4].Score, 5) + }) + gtest.C(t, func(t *gtest.T) { + var user User + err := db.Model(tableUser).WithAll().Where("id", 4).Scan(&user) + t.AssertNil(err) + t.Assert(user.Id, 4) + t.AssertNE(user.UserDetail, nil) + t.Assert(user.UserDetail.Uid, 4) + t.Assert(user.UserDetail.Address, `address_4`) + t.Assert(len(user.UserScores), 5) + t.Assert(user.UserScores[0].Uid, 4) + t.Assert(user.UserScores[0].Score, 1) + t.Assert(user.UserScores[4].Uid, 4) + t.Assert(user.UserScores[4].Score, 5) + }) +} diff --git a/internal/structs/structs_field.go b/internal/structs/structs_field.go index a0ad1a1cb..dda5a996e 100644 --- a/internal/structs/structs_field.go +++ b/internal/structs/structs_field.go @@ -61,8 +61,11 @@ func (f *Field) OriginalKind() reflect.Kind { // The parameter `priority` specifies the priority tag array for retrieving from high to low. // If it's given `nil`, it returns map[name]*Field, of which the `name` is attribute name. // +// The parameter `recursive` specifies the whether retrieving the fields recursively if the attribute +// is an embedded struct. +// // Note that it only retrieves the exported attributes with first letter up-case from struct. -func FieldMap(pointer interface{}, priority []string) (map[string]*Field, error) { +func FieldMap(pointer interface{}, priority []string, recursive bool) (map[string]*Field, error) { fields, err := getFieldValues(pointer) if err != nil { return nil, err @@ -88,8 +91,8 @@ func FieldMap(pointer interface{}, priority []string) (map[string]*Field, error) if tagValue != "" { mapField[tagValue] = tempField } else { - if field.IsEmbedded() { - m, err := FieldMap(field.Value, priority) + if recursive && field.IsEmbedded() { + m, err := FieldMap(field.Value, priority, recursive) if err != nil { return nil, err } diff --git a/internal/structs/structs_z_unit_test.go b/internal/structs/structs_z_unit_test.go index c4f31d1a2..4268f473a 100644 --- a/internal/structs/structs_z_unit_test.go +++ b/internal/structs/structs_z_unit_test.go @@ -110,7 +110,7 @@ func Test_FieldMap(t *testing.T) { Pass string `my-tag1:"pass1" my-tag2:"pass2" params:"pass"` } var user *User - m, _ := structs.FieldMap(user, []string{"params"}) + m, _ := structs.FieldMap(user, []string{"params"}, true) t.Assert(len(m), 3) _, ok := m["Id"] t.Assert(ok, true) @@ -130,7 +130,7 @@ func Test_FieldMap(t *testing.T) { Pass string `my-tag1:"pass1" my-tag2:"pass2" params:"pass"` } var user *User - m, _ := structs.FieldMap(user, nil) + m, _ := structs.FieldMap(user, nil, true) t.Assert(len(m), 3) _, ok := m["Id"] t.Assert(ok, true) diff --git a/util/gvalid/gvalid_validator_check_struct.go b/util/gvalid/gvalid_validator_check_struct.go index 3180806f4..a3005b74c 100644 --- a/util/gvalid/gvalid_validator_check_struct.go +++ b/util/gvalid/gvalid_validator_check_struct.go @@ -28,7 +28,7 @@ func (v *Validator) CheckStruct(object interface{}, rules interface{}, messages var ( errorMaps = make(ErrorMap) // Returned error. ) - mapField, err := structs.FieldMap(object, aliasNameTagPriority) + mapField, err := structs.FieldMap(object, aliasNameTagPriority, true) if err != nil { return newErrorStr("invalid_object", err.Error()) }