improve model relation feature for package gdb

This commit is contained in:
John Guo
2021-01-25 00:04:01 +08:00
parent 9b02f5220a
commit 9f2e69a9e6
2 changed files with 193 additions and 73 deletions

View File

@ -8,11 +8,9 @@ package gdb
import (
"database/sql"
"fmt"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/util/gconv"
"github.com/gogf/gf/util/gutil"
"reflect"
)
@ -42,19 +40,19 @@ import (
// given <relation> parameter.
//
// See the example or unit testing cases for clear understanding for this function.
func (r Result) ScanList(listPointer interface{}, attributeName string, relation ...string) (err error) {
func (r Result) ScanList(listPointer interface{}, bindToAttrName string, relationKV ...string) (err error) {
// Necessary checks for parameters.
if attributeName == "" {
return gerror.New(`attributeName should not be empty`)
}
if len(relation) > 0 {
if len(relation) < 2 {
return gerror.New(`relation name and key should are both necessary`)
}
if relation[0] == "" || relation[1] == "" {
return gerror.New(`relation name and key should not be empty`)
}
if bindToAttrName == "" {
return gerror.New(`bindToAttrName should not be empty`)
}
//if len(relation) > 0 {
// if len(relation) < 2 {
// return gerror.New(`relation name and key should are both necessary`)
// }
// if relation[0] == "" || relation[1] == "" {
// return gerror.New(`relation name and key should not be empty`)
// }
//}
var (
reflectValue = reflect.ValueOf(listPointer)
@ -65,12 +63,12 @@ func (r Result) ScanList(listPointer interface{}, attributeName string, relation
reflectKind = reflectValue.Kind()
}
if reflectKind != reflect.Ptr {
return fmt.Errorf("parameter should be type of *[]struct/*[]*struct, but got: %v", reflectKind)
return gerror.Newf("parameter should be type of *[]struct/*[]*struct, but got: %v", reflectKind)
}
reflectValue = reflectValue.Elem()
reflectKind = reflectValue.Kind()
if reflectKind != reflect.Slice && reflectKind != reflect.Array {
return fmt.Errorf("parameter should be type of *[]struct/*[]*struct, but got: %v", reflectKind)
return gerror.Newf("parameter should be type of *[]struct/*[]*struct, but got: %v", reflectKind)
}
length := len(r)
if length == 0 {
@ -101,59 +99,61 @@ func (r Result) ScanList(listPointer interface{}, attributeName string, relation
// Relation variables.
var (
relationDataMap map[string]Value
relationFieldName string
relationAttrName string
relationKVStr string
relationDataMap map[string]Value
relationFromAttrName string // Eg: relationKV: User, uid:Uid -> User
relationResultFieldName string // Eg: relationKV: uid:Uid -> uid
relationBindToSubAttrName string // Eg: relationKV: uid:Uid -> Uid
)
if len(relation) > 0 {
array := gstr.Split(relation[1], ":")
if len(array) > 1 {
if len(relationKV) > 0 {
if len(relationKV) == 1 {
relationKVStr = relationKV[0]
} else {
relationFromAttrName = relationKV[0]
relationKVStr = relationKV[1]
}
array := gstr.SplitAndTrim(relationKVStr, ":")
if len(array) == 2 {
// Defined table field to relation attribute name.
// Like:
// uid:Uid
// uid:UserId
relationFieldName = array[0]
relationAttrName = array[1]
relationResultFieldName = array[0]
relationBindToSubAttrName = array[1]
} else {
relationAttrName = relation[1]
// Find the possible map key by given only struct attribute name.
// Like:
// Uid
if k, _ := gutil.MapPossibleItemByKey(r[0].Map(), relation[1]); k != "" {
relationFieldName = k
}
return gerror.New(`parameter relationKV should be format of "ResultFieldName:BindToAttrName"`)
}
if relationFieldName != "" {
relationDataMap = r.MapKeyValue(relationFieldName)
if relationResultFieldName != "" {
relationDataMap = r.MapKeyValue(relationResultFieldName)
}
if len(relationDataMap) == 0 {
return fmt.Errorf(`cannot find the relation data map, maybe invalid relation key given: %s`, relation[1])
return gerror.Newf(`cannot find the relation data map, maybe invalid relation given "%v"`, relationKV)
}
}
// Bind to target attribute.
var (
ok bool
attrValue reflect.Value
attrKind reflect.Kind
attrType reflect.Type
attrField reflect.StructField
ok bool
bindToAttrValue reflect.Value
bindToAttrKind reflect.Kind
bindToAttrType reflect.Type
bindToAttrField reflect.StructField
)
if arrayItemType.Kind() == reflect.Ptr {
if attrField, ok = arrayItemType.Elem().FieldByName(attributeName); !ok {
return fmt.Errorf(`invalid field name: %s`, attributeName)
if bindToAttrField, ok = arrayItemType.Elem().FieldByName(bindToAttrName); !ok {
return gerror.Newf(`invalid parameter bindToAttrName: cannot find attribute with name "%s" from slice element`, bindToAttrName)
}
} else {
if attrField, ok = arrayItemType.FieldByName(attributeName); !ok {
return fmt.Errorf(`invalid field name: %s`, attributeName)
if bindToAttrField, ok = arrayItemType.FieldByName(bindToAttrName); !ok {
return gerror.Newf(`invalid parameter bindToAttrName: cannot find attribute with name "%s" from slice element`, bindToAttrName)
}
}
attrType = attrField.Type
attrKind = attrType.Kind()
bindToAttrType = bindToAttrField.Type
bindToAttrKind = bindToAttrType.Kind()
// Bind to relation conditions.
var (
relationValue reflect.Value
relationField reflect.Value
relationFromAttrValue reflect.Value
relationFromAttrField reflect.Value
)
for i := 0; i < arrayValue.Len(); i++ {
arrayElemValue := arrayValue.Index(i)
@ -173,41 +173,45 @@ func (r Result) ScanList(listPointer interface{}, attributeName string, relation
} else {
// Like: []Entity
}
attrValue = arrayElemValue.FieldByName(attributeName)
if len(relation) > 0 {
relationValue = arrayElemValue.FieldByName(relation[0])
if relationValue.Kind() == reflect.Ptr {
relationValue = relationValue.Elem()
bindToAttrValue = arrayElemValue.FieldByName(bindToAttrName)
if relationFromAttrName != "" {
// Attribute value of current slice element.
relationFromAttrValue = arrayElemValue.FieldByName(relationFromAttrName)
if relationFromAttrValue.Kind() == reflect.Ptr {
relationFromAttrValue = relationFromAttrValue.Elem()
}
} else {
// Current slice element.
relationFromAttrValue = arrayElemValue
}
if len(relationDataMap) > 0 && !relationValue.IsValid() {
return fmt.Errorf(`invalid relation: "%s:%s"`, relation[0], relation[1])
if len(relationDataMap) > 0 && !relationFromAttrValue.IsValid() {
return gerror.Newf(`invalid relation specified: "%v"`, relationKV)
}
switch attrKind {
switch bindToAttrKind {
case reflect.Array, reflect.Slice:
if len(relationDataMap) > 0 {
relationField = relationValue.FieldByName(relationAttrName)
if relationField.IsValid() {
relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToSubAttrName)
if relationFromAttrField.IsValid() {
if err = gconv.Structs(
relationDataMap[gconv.String(relationField.Interface())],
attrValue.Addr(),
relationDataMap[gconv.String(relationFromAttrField.Interface())],
bindToAttrValue.Addr(),
); err != nil {
return err
}
} else {
// May be the attribute does not exist yet.
return fmt.Errorf(`invalid relation: "%s:%s"`, relation[0], relation[1])
return gerror.Newf(`invalid relation specified: "%v"`, relationKV)
}
} else {
return fmt.Errorf(`relationKey should not be empty as field "%s" is slice`, attributeName)
return gerror.Newf(`relationKey should not be empty as field "%s" is slice`, bindToAttrName)
}
case reflect.Ptr:
e := reflect.New(attrType.Elem()).Elem()
e := reflect.New(bindToAttrType.Elem()).Elem()
if len(relationDataMap) > 0 {
relationField = relationValue.FieldByName(relationAttrName)
if relationField.IsValid() {
v := relationDataMap[gconv.String(relationField.Interface())]
relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToSubAttrName)
if relationFromAttrField.IsValid() {
v := relationDataMap[gconv.String(relationFromAttrField.Interface())]
if v == nil {
// There's no relational data.
continue
@ -217,7 +221,7 @@ func (r Result) ScanList(listPointer interface{}, attributeName string, relation
}
} else {
// May be the attribute does not exist yet.
return fmt.Errorf(`invalid relation: "%s:%s"`, relation[0], relation[1])
return gerror.Newf(`invalid relation specified: "%v"`, relationKV)
}
} else {
v := r[i]
@ -229,14 +233,14 @@ func (r Result) ScanList(listPointer interface{}, attributeName string, relation
return err
}
}
attrValue.Set(e.Addr())
bindToAttrValue.Set(e.Addr())
case reflect.Struct:
e := reflect.New(attrType).Elem()
e := reflect.New(bindToAttrType).Elem()
if len(relationDataMap) > 0 {
relationField = relationValue.FieldByName(relationAttrName)
if relationField.IsValid() {
relationDataItem := relationDataMap[gconv.String(relationField.Interface())]
relationFromAttrField = relationFromAttrValue.FieldByName(relationBindToSubAttrName)
if relationFromAttrField.IsValid() {
relationDataItem := relationDataMap[gconv.String(relationFromAttrField.Interface())]
if relationDataItem == nil {
// There's no relational data.
continue
@ -246,7 +250,7 @@ func (r Result) ScanList(listPointer interface{}, attributeName string, relation
}
} else {
// May be the attribute does not exist yet.
return fmt.Errorf(`invalid relation: "%s:%s"`, relation[0], relation[1])
return gerror.Newf(`invalid relation specified: "%v"`, relationKV)
}
} else {
relationDataItem := r[i]
@ -258,10 +262,10 @@ func (r Result) ScanList(listPointer interface{}, attributeName string, relation
return err
}
}
attrValue.Set(e)
bindToAttrValue.Set(e)
default:
return fmt.Errorf(`unsupport attribute type: %s`, attrKind.String())
return gerror.Newf(`unsupported attribute type: %s`, bindToAttrKind.String())
}
}
reflect.ValueOf(listPointer).Elem().Set(arrayValue)

View File

@ -472,3 +472,119 @@ CREATE TABLE %s (
t.Assert(users[1].UserScores[4].Score, 5)
})
}
func Test_Table_Relation_EmbedStruct(t *testing.T) {
var (
tableUser = "user_" + gtime.TimestampMicroStr()
tableUserDetail = "user_detail_" + gtime.TimestampMicroStr()
tableUserScores = "user_scores_" + gtime.TimestampMicroStr()
)
if _, err := db.Exec(fmt.Sprintf(`
CREATE TABLE %s (
uid int(10) unsigned NOT NULL AUTO_INCREMENT,
name varchar(45) NOT NULL,
PRIMARY KEY (uid)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
`, tableUser)); err != nil {
gtest.Error(err)
}
defer dropTable(tableUser)
if _, err := db.Exec(fmt.Sprintf(`
CREATE TABLE %s (
uid int(10) unsigned NOT NULL AUTO_INCREMENT,
address varchar(45) NOT NULL,
PRIMARY KEY (uid)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
`, tableUserDetail)); err != nil {
gtest.Error(err)
}
defer dropTable(tableUserDetail)
if _, err := db.Exec(fmt.Sprintf(`
CREATE TABLE %s (
id int(10) unsigned NOT NULL AUTO_INCREMENT,
uid int(10) unsigned NOT NULL,
score int(10) unsigned NOT NULL,
PRIMARY KEY (id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
`, tableUserScores)); err != nil {
gtest.Error(err)
}
defer dropTable(tableUserScores)
type EntityUser struct {
Uid int `json:"uid"`
Name string `json:"name"`
}
type EntityUserDetail struct {
*EntityUser
Uid int `json:"uid"`
Address string `json:"address"`
}
type EntityUserScores struct {
*EntityUser
*EntityUserDetail
Id int `json:"id"`
Uid int `json:"uid"`
Score int `json:"score"`
}
// Initialize the data.
var err error
for i := 1; i <= 5; i++ {
// User.
_, err = db.Insert(tableUser, g.Map{
"uid": i,
"name": fmt.Sprintf(`name_%d`, i),
})
gtest.Assert(err, nil)
// Detail.
_, err = db.Insert(tableUserDetail, g.Map{
"uid": i,
"address": fmt.Sprintf(`address_%d`, i),
})
gtest.Assert(err, nil)
// Scores.
for j := 1; j <= 5; j++ {
_, err = db.Insert(tableUserScores, g.Map{
"uid": i,
"score": j,
})
gtest.Assert(err, nil)
}
}
gtest.C(t, func(t *gtest.T) {
var (
err error
scores []*EntityUserScores
)
// SELECT * FROM `user_scores`
err = db.Table(tableUserScores).Scan(&scores)
t.Assert(err, nil)
// SELECT * FROM `user_scores` WHERE `uid` IN(1,2,3,4,5)
err = db.Table(tableUser).
Where("uid", gdb.ListItemValuesUnique(&scores, "Uid")).
ScanList(&scores, "EntityUser", "uid:Uid")
t.Assert(err, nil)
// SELECT * FROM `user_detail` WHERE `uid` IN(1,2,3,4,5)
err = db.Table(tableUserDetail).
Where("uid", gdb.ListItemValuesUnique(&scores, "Uid")).
ScanList(&scores, "EntityUserDetail", "uid:Uid")
t.Assert(err, nil)
// Assertions.
t.Assert(len(scores), 25)
t.Assert(scores[0].Id, 1)
t.Assert(scores[0].Uid, 1)
t.Assert(scores[0].Name, "name_1")
t.Assert(scores[0].Address, "address_1")
t.Assert(scores[24].Id, 25)
t.Assert(scores[24].Uid, 5)
t.Assert(scores[24].Name, "name_5")
t.Assert(scores[24].Address, "address_5")
})
}