From 9f2e69a9e6b8eaa95576f7d29f9c3bd954b7ccea Mon Sep 17 00:00:00 2001 From: John Guo Date: Mon, 25 Jan 2021 00:04:01 +0800 Subject: [PATCH] improve model relation feature for package gdb --- database/gdb/gdb_type_result_scanlist.go | 150 ++++++++++--------- database/gdb/gdb_z_mysql_association_test.go | 116 ++++++++++++++ 2 files changed, 193 insertions(+), 73 deletions(-) diff --git a/database/gdb/gdb_type_result_scanlist.go b/database/gdb/gdb_type_result_scanlist.go index 5da062227..9494ef514 100644 --- a/database/gdb/gdb_type_result_scanlist.go +++ b/database/gdb/gdb_type_result_scanlist.go @@ -8,11 +8,9 @@ package gdb import ( "database/sql" - "fmt" "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/text/gstr" "github.com/gogf/gf/util/gconv" - "github.com/gogf/gf/util/gutil" "reflect" ) @@ -42,19 +40,19 @@ import ( // given parameter. // // See the example or unit testing cases for clear understanding for this function. -func (r Result) ScanList(listPointer interface{}, attributeName string, relation ...string) (err error) { +func (r Result) ScanList(listPointer interface{}, bindToAttrName string, relationKV ...string) (err error) { // Necessary checks for parameters. - if attributeName == "" { - return gerror.New(`attributeName should not be empty`) - } - if len(relation) > 0 { - if len(relation) < 2 { - return gerror.New(`relation name and key should are both necessary`) - } - if relation[0] == "" || relation[1] == "" { - return gerror.New(`relation name and key should not be empty`) - } + if bindToAttrName == "" { + return gerror.New(`bindToAttrName should not be empty`) } + //if len(relation) > 0 { + // if len(relation) < 2 { + // return gerror.New(`relation name and key should are both necessary`) + // } + // if relation[0] == "" || relation[1] == "" { + // return gerror.New(`relation name and key should not be empty`) + // } + //} var ( reflectValue = reflect.ValueOf(listPointer) @@ -65,12 +63,12 @@ func (r Result) ScanList(listPointer interface{}, attributeName string, relation reflectKind = reflectValue.Kind() } if reflectKind != reflect.Ptr { - return fmt.Errorf("parameter should be type of *[]struct/*[]*struct, but got: %v", reflectKind) + return gerror.Newf("parameter should be type of *[]struct/*[]*struct, but got: %v", reflectKind) } reflectValue = reflectValue.Elem() reflectKind = reflectValue.Kind() if reflectKind != reflect.Slice && reflectKind != reflect.Array { - return fmt.Errorf("parameter should be type of *[]struct/*[]*struct, but got: %v", reflectKind) + return gerror.Newf("parameter should be type of *[]struct/*[]*struct, but got: %v", reflectKind) } length := len(r) if length == 0 { @@ -101,59 +99,61 @@ func (r Result) ScanList(listPointer interface{}, attributeName string, relation // Relation variables. var ( - relationDataMap map[string]Value - relationFieldName string - relationAttrName string + relationKVStr string + relationDataMap map[string]Value + relationFromAttrName string // Eg: relationKV: User, uid:Uid -> User + relationResultFieldName string // Eg: relationKV: uid:Uid -> uid + relationBindToSubAttrName string // Eg: relationKV: uid:Uid -> Uid ) - if len(relation) > 0 { - array := gstr.Split(relation[1], ":") - if len(array) > 1 { + if len(relationKV) > 0 { + if len(relationKV) == 1 { + relationKVStr = relationKV[0] + } else { + relationFromAttrName = relationKV[0] + relationKVStr = relationKV[1] + } + array := gstr.SplitAndTrim(relationKVStr, ":") + if len(array) == 2 { // Defined table field to relation attribute name. // Like: // uid:Uid // uid:UserId - relationFieldName = array[0] - relationAttrName = array[1] + relationResultFieldName = array[0] + relationBindToSubAttrName = array[1] } else { - relationAttrName = relation[1] - // Find the possible map key by given only struct attribute name. - // Like: - // Uid - if k, _ := gutil.MapPossibleItemByKey(r[0].Map(), relation[1]); k != "" { - relationFieldName = k - } + return gerror.New(`parameter relationKV should be format of "ResultFieldName:BindToAttrName"`) } - if relationFieldName != "" { - relationDataMap = r.MapKeyValue(relationFieldName) + if relationResultFieldName != "" { + relationDataMap = r.MapKeyValue(relationResultFieldName) } if len(relationDataMap) == 0 { - return fmt.Errorf(`cannot find the relation data map, maybe invalid relation key given: %s`, relation[1]) + return gerror.Newf(`cannot find the relation data map, maybe invalid relation given "%v"`, relationKV) } } // Bind to target attribute. var ( - ok bool - attrValue reflect.Value - attrKind reflect.Kind - attrType reflect.Type - attrField reflect.StructField + ok bool + bindToAttrValue reflect.Value + bindToAttrKind reflect.Kind + bindToAttrType reflect.Type + bindToAttrField reflect.StructField ) if arrayItemType.Kind() == reflect.Ptr { - if attrField, ok = arrayItemType.Elem().FieldByName(attributeName); !ok { - return fmt.Errorf(`invalid field name: %s`, attributeName) + if bindToAttrField, ok = arrayItemType.Elem().FieldByName(bindToAttrName); !ok { + return gerror.Newf(`invalid parameter bindToAttrName: cannot find attribute with name "%s" from slice element`, bindToAttrName) } } else { - if attrField, ok = arrayItemType.FieldByName(attributeName); !ok { - return fmt.Errorf(`invalid field name: %s`, attributeName) + if bindToAttrField, ok = arrayItemType.FieldByName(bindToAttrName); !ok { + return gerror.Newf(`invalid parameter bindToAttrName: cannot find attribute with name "%s" from slice element`, bindToAttrName) } } - attrType = attrField.Type - attrKind = attrType.Kind() + bindToAttrType = bindToAttrField.Type + bindToAttrKind = bindToAttrType.Kind() // Bind to relation conditions. var ( - relationValue reflect.Value - relationField reflect.Value + relationFromAttrValue reflect.Value + relationFromAttrField reflect.Value ) for i := 0; i < arrayValue.Len(); i++ { arrayElemValue := arrayValue.Index(i) @@ -173,41 +173,45 @@ func (r Result) ScanList(listPointer interface{}, attributeName string, relation } else { // Like: []Entity } - attrValue = arrayElemValue.FieldByName(attributeName) - if len(relation) > 0 { - relationValue = arrayElemValue.FieldByName(relation[0]) - if relationValue.Kind() == reflect.Ptr { - relationValue = relationValue.Elem() + bindToAttrValue = arrayElemValue.FieldByName(bindToAttrName) + if relationFromAttrName != "" { + // Attribute value of current slice element. + relationFromAttrValue = arrayElemValue.FieldByName(relationFromAttrName) + if relationFromAttrValue.Kind() == reflect.Ptr { + relationFromAttrValue = relationFromAttrValue.Elem() } + } else { + // Current slice element. + relationFromAttrValue = arrayElemValue } - if len(relationDataMap) > 0 && !relationValue.IsValid() { - return fmt.Errorf(`invalid relation: "%s:%s"`, relation[0], relation[1]) + if len(relationDataMap) > 0 && !relationFromAttrValue.IsValid() { + return gerror.Newf(`invalid relation specified: "%v"`, relationKV) } - switch attrKind { + switch bindToAttrKind { case reflect.Array, reflect.Slice: if len(relationDataMap) > 0 { - relationField = relationValue.FieldByName(relationAttrName) - if relationField.IsValid() { + relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToSubAttrName) + if relationFromAttrField.IsValid() { if err = gconv.Structs( - relationDataMap[gconv.String(relationField.Interface())], - attrValue.Addr(), + relationDataMap[gconv.String(relationFromAttrField.Interface())], + bindToAttrValue.Addr(), ); err != nil { return err } } else { // May be the attribute does not exist yet. - return fmt.Errorf(`invalid relation: "%s:%s"`, relation[0], relation[1]) + return gerror.Newf(`invalid relation specified: "%v"`, relationKV) } } else { - return fmt.Errorf(`relationKey should not be empty as field "%s" is slice`, attributeName) + return gerror.Newf(`relationKey should not be empty as field "%s" is slice`, bindToAttrName) } case reflect.Ptr: - e := reflect.New(attrType.Elem()).Elem() + e := reflect.New(bindToAttrType.Elem()).Elem() if len(relationDataMap) > 0 { - relationField = relationValue.FieldByName(relationAttrName) - if relationField.IsValid() { - v := relationDataMap[gconv.String(relationField.Interface())] + relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToSubAttrName) + if relationFromAttrField.IsValid() { + v := relationDataMap[gconv.String(relationFromAttrField.Interface())] if v == nil { // There's no relational data. continue @@ -217,7 +221,7 @@ func (r Result) ScanList(listPointer interface{}, attributeName string, relation } } else { // May be the attribute does not exist yet. - return fmt.Errorf(`invalid relation: "%s:%s"`, relation[0], relation[1]) + return gerror.Newf(`invalid relation specified: "%v"`, relationKV) } } else { v := r[i] @@ -229,14 +233,14 @@ func (r Result) ScanList(listPointer interface{}, attributeName string, relation return err } } - attrValue.Set(e.Addr()) + bindToAttrValue.Set(e.Addr()) case reflect.Struct: - e := reflect.New(attrType).Elem() + e := reflect.New(bindToAttrType).Elem() if len(relationDataMap) > 0 { - relationField = relationValue.FieldByName(relationAttrName) - if relationField.IsValid() { - relationDataItem := relationDataMap[gconv.String(relationField.Interface())] + relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToSubAttrName) + if relationFromAttrField.IsValid() { + relationDataItem := relationDataMap[gconv.String(relationFromAttrField.Interface())] if relationDataItem == nil { // There's no relational data. continue @@ -246,7 +250,7 @@ func (r Result) ScanList(listPointer interface{}, attributeName string, relation } } else { // May be the attribute does not exist yet. - return fmt.Errorf(`invalid relation: "%s:%s"`, relation[0], relation[1]) + return gerror.Newf(`invalid relation specified: "%v"`, relationKV) } } else { relationDataItem := r[i] @@ -258,10 +262,10 @@ func (r Result) ScanList(listPointer interface{}, attributeName string, relation return err } } - attrValue.Set(e) + bindToAttrValue.Set(e) default: - return fmt.Errorf(`unsupport attribute type: %s`, attrKind.String()) + return gerror.Newf(`unsupported attribute type: %s`, bindToAttrKind.String()) } } reflect.ValueOf(listPointer).Elem().Set(arrayValue) diff --git a/database/gdb/gdb_z_mysql_association_test.go b/database/gdb/gdb_z_mysql_association_test.go index 856377fea..06b462017 100644 --- a/database/gdb/gdb_z_mysql_association_test.go +++ b/database/gdb/gdb_z_mysql_association_test.go @@ -472,3 +472,119 @@ CREATE TABLE %s ( t.Assert(users[1].UserScores[4].Score, 5) }) } + +func Test_Table_Relation_EmbedStruct(t *testing.T) { + var ( + tableUser = "user_" + gtime.TimestampMicroStr() + tableUserDetail = "user_detail_" + gtime.TimestampMicroStr() + tableUserScores = "user_scores_" + gtime.TimestampMicroStr() + ) + if _, err := db.Exec(fmt.Sprintf(` +CREATE TABLE %s ( + uid int(10) unsigned NOT NULL AUTO_INCREMENT, + name varchar(45) NOT NULL, + PRIMARY KEY (uid) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, tableUser)); err != nil { + gtest.Error(err) + } + defer dropTable(tableUser) + + if _, err := db.Exec(fmt.Sprintf(` +CREATE TABLE %s ( + uid int(10) unsigned NOT NULL AUTO_INCREMENT, + address varchar(45) NOT NULL, + PRIMARY KEY (uid) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, tableUserDetail)); err != nil { + gtest.Error(err) + } + defer dropTable(tableUserDetail) + + if _, err := db.Exec(fmt.Sprintf(` +CREATE TABLE %s ( + id int(10) unsigned NOT NULL AUTO_INCREMENT, + uid int(10) unsigned NOT NULL, + score int(10) unsigned NOT NULL, + PRIMARY KEY (id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `, tableUserScores)); err != nil { + gtest.Error(err) + } + defer dropTable(tableUserScores) + + type EntityUser struct { + Uid int `json:"uid"` + Name string `json:"name"` + } + type EntityUserDetail struct { + *EntityUser + Uid int `json:"uid"` + Address string `json:"address"` + } + type EntityUserScores struct { + *EntityUser + *EntityUserDetail + Id int `json:"id"` + Uid int `json:"uid"` + Score int `json:"score"` + } + + // Initialize the data. + var err error + for i := 1; i <= 5; i++ { + // User. + _, err = db.Insert(tableUser, g.Map{ + "uid": i, + "name": fmt.Sprintf(`name_%d`, i), + }) + gtest.Assert(err, nil) + // Detail. + _, err = db.Insert(tableUserDetail, g.Map{ + "uid": i, + "address": fmt.Sprintf(`address_%d`, i), + }) + gtest.Assert(err, nil) + // Scores. + for j := 1; j <= 5; j++ { + _, err = db.Insert(tableUserScores, g.Map{ + "uid": i, + "score": j, + }) + gtest.Assert(err, nil) + } + } + + gtest.C(t, func(t *gtest.T) { + var ( + err error + scores []*EntityUserScores + ) + // SELECT * FROM `user_scores` + err = db.Table(tableUserScores).Scan(&scores) + t.Assert(err, nil) + + // SELECT * FROM `user_scores` WHERE `uid` IN(1,2,3,4,5) + err = db.Table(tableUser). + Where("uid", gdb.ListItemValuesUnique(&scores, "Uid")). + ScanList(&scores, "EntityUser", "uid:Uid") + t.Assert(err, nil) + + // SELECT * FROM `user_detail` WHERE `uid` IN(1,2,3,4,5) + err = db.Table(tableUserDetail). + Where("uid", gdb.ListItemValuesUnique(&scores, "Uid")). + ScanList(&scores, "EntityUserDetail", "uid:Uid") + t.Assert(err, nil) + + // Assertions. + t.Assert(len(scores), 25) + t.Assert(scores[0].Id, 1) + t.Assert(scores[0].Uid, 1) + t.Assert(scores[0].Name, "name_1") + t.Assert(scores[0].Address, "address_1") + t.Assert(scores[24].Id, 25) + t.Assert(scores[24].Uid, 5) + t.Assert(scores[24].Name, "name_5") + t.Assert(scores[24].Address, "address_5") + }) +}