fix issue in filds filter in join statements for package gdb

This commit is contained in:
jflyfox
2021-01-20 13:09:59 +08:00
parent a304ca8f5b
commit 59285709a6
5 changed files with 102 additions and 6 deletions

View File

@ -271,6 +271,11 @@ var (
// which is a regular field name of table.
regularFieldNameRegPattern = `^[\w\.\-\_]+$`
// regularFieldNameWithoutDotRegPattern is similar to regularFieldNameRegPattern but not allows '.'.
// Note that, although some databases allow char '.' in the field name, but it here does not allow '.'
// in the field name as it conflicts with "db.table.field" pattern in SOME situations.
regularFieldNameWithoutDotRegPattern = `^[\w\-\_]+$`
// internalCache is the memory cache for internal usage.
internalCache = gcache.New()

View File

@ -467,14 +467,20 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) (
newWhere = db.QuoteString(newWhere)
if len(newArgs) > 0 {
if utils.IsArray(newArgs[0]) {
// Eg: Where("id", []int{1,2,3})
// Eg:
// Where("id", []int{1,2,3})
// Where("user.id", []int{1,2,3})
newWhere += " IN (?)"
} else if empty.IsNil(newArgs[0]) {
// Eg: Where("id", nil)
// Eg:
// Where("id", nil)
// Where("user.id", nil)
newWhere += " IS NULL"
newArgs = nil
} else {
// Eg: Where/And/Or("uid", 1)
// Eg:
// Where/And/Or("uid", 1)
// Where/And/Or("user.uid", 1)
newWhere += "=?"
}
}

View File

@ -44,10 +44,12 @@ func (m *Model) mappingAndFilterToTableFields(fields []string) []string {
}
for _, field := range inputFieldsArray {
if _, ok := fieldsKeyMap[field]; !ok {
if !gregex.IsMatchString(regularFieldNameRegPattern, field) {
if !gregex.IsMatchString(regularFieldNameWithoutDotRegPattern, field) {
// Eg: user.id, user.name
outputFieldsArray = append(outputFieldsArray, field)
continue
} else {
// Eg: id, name
if foundKey, _ := gutil.MapPossibleItemByKey(fieldsKeyMap, field); foundKey != "" {
outputFieldsArray = append(outputFieldsArray, foundKey)
}

View File

@ -72,7 +72,6 @@ func init() {
gtest.Error(err)
}
db.SetSchema(SCHEMA1)
createTable(TABLE)
// Prefix db.
if r, err := gdb.New("prefix"); err != nil {
@ -87,7 +86,6 @@ func init() {
gtest.Error(err)
}
dbPrefix.SetSchema(SCHEMA1)
createTable(TABLE)
}
func createTable(table ...string) string {

View File

@ -3236,3 +3236,88 @@ func Test_Model_Fields_Map_Struct(t *testing.T) {
t.Assert(a.XXX_TYPE, 0)
})
}
func Test_Model_Fields_AutoFilterInJoinStatement(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
var err error
table1 := "user"
table2 := "score"
table3 := "info"
if _, err := db.Exec(fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id int(11) NOT NULL AUTO_INCREMENT,
name varchar(500) NOT NULL DEFAULT '',
PRIMARY KEY (id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 AUTO_INCREMENT=1;
`, table1,
)); err != nil {
t.Assert(err, nil)
}
defer dropTable(table1)
_, err = db.Table(table1).Insert(g.Map{
"id": 1,
"name": "john",
})
t.Assert(err, nil)
if _, err := db.Exec(fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id int(11) NOT NULL AUTO_INCREMENT,
user_id int(11) NOT NULL DEFAULT 0,
number varchar(500) NOT NULL DEFAULT '',
PRIMARY KEY (id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 AUTO_INCREMENT=1;
`, table2,
)); err != nil {
t.Assert(err, nil)
}
defer dropTable(table2)
_, err = db.Table(table2).Insert(g.Map{
"id": 1,
"user_id": 1,
"number": "n",
})
t.Assert(err, nil)
if _, err := db.Exec(fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id int(11) NOT NULL AUTO_INCREMENT,
user_id int(11) NOT NULL DEFAULT 0,
description varchar(500) NOT NULL DEFAULT '',
PRIMARY KEY (id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 AUTO_INCREMENT=1;
`, table3,
)); err != nil {
t.Assert(err, nil)
}
defer dropTable(table3)
_, err = db.Table(table3).Insert(g.Map{
"id": 1,
"user_id": 1,
"description": "brief",
})
t.Assert(err, nil)
one, err := db.Table("user").
Where("user.id", 1).
Fields("score.number,user.name").
LeftJoin("score", "user.id=score.user_id").
LeftJoin("info", "info.id=info.user_id").
Order("user.id asc").
One()
t.Assert(err, nil)
t.Assert(len(one), 2)
t.Assert(one["name"].String(), "john")
t.Assert(one["number"].String(), "n")
one, err = db.Table("user").
LeftJoin("score", "user.id=score.user_id").
LeftJoin("info", "info.id=info.user_id").
Fields("score.number,user.name").
One()
t.Assert(err, nil)
t.Assert(len(one), 2)
t.Assert(one["name"].String(), "john")
t.Assert(one["number"].String(), "n")
})
}