From 59285709a65e8f00d863232ec9e983a7ea2799fd Mon Sep 17 00:00:00 2001 From: jflyfox Date: Wed, 20 Jan 2021 13:09:59 +0800 Subject: [PATCH] fix issue in filds filter in join statements for package gdb --- database/gdb/gdb.go | 5 ++ database/gdb/gdb_func.go | 12 +++- database/gdb/gdb_model_utility.go | 4 +- database/gdb/gdb_z_init_test.go | 2 - database/gdb/gdb_z_mysql_model_test.go | 85 ++++++++++++++++++++++++++ 5 files changed, 102 insertions(+), 6 deletions(-) diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 898624713..0212a0999 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -271,6 +271,11 @@ var ( // which is a regular field name of table. regularFieldNameRegPattern = `^[\w\.\-\_]+$` + // regularFieldNameWithoutDotRegPattern is similar to regularFieldNameRegPattern but not allows '.'. + // Note that, although some databases allow char '.' in the field name, but it here does not allow '.' + // in the field name as it conflicts with "db.table.field" pattern in SOME situations. + regularFieldNameWithoutDotRegPattern = `^[\w\-\_]+$` + // internalCache is the memory cache for internal usage. internalCache = gcache.New() diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index 5e7c9b7bf..2a8878a33 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -467,14 +467,20 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) ( newWhere = db.QuoteString(newWhere) if len(newArgs) > 0 { if utils.IsArray(newArgs[0]) { - // Eg: Where("id", []int{1,2,3}) + // Eg: + // Where("id", []int{1,2,3}) + // Where("user.id", []int{1,2,3}) newWhere += " IN (?)" } else if empty.IsNil(newArgs[0]) { - // Eg: Where("id", nil) + // Eg: + // Where("id", nil) + // Where("user.id", nil) newWhere += " IS NULL" newArgs = nil } else { - // Eg: Where/And/Or("uid", 1) + // Eg: + // Where/And/Or("uid", 1) + // Where/And/Or("user.uid", 1) newWhere += "=?" } } diff --git a/database/gdb/gdb_model_utility.go b/database/gdb/gdb_model_utility.go index 9ec5f74eb..17a484025 100644 --- a/database/gdb/gdb_model_utility.go +++ b/database/gdb/gdb_model_utility.go @@ -44,10 +44,12 @@ func (m *Model) mappingAndFilterToTableFields(fields []string) []string { } for _, field := range inputFieldsArray { if _, ok := fieldsKeyMap[field]; !ok { - if !gregex.IsMatchString(regularFieldNameRegPattern, field) { + if !gregex.IsMatchString(regularFieldNameWithoutDotRegPattern, field) { + // Eg: user.id, user.name outputFieldsArray = append(outputFieldsArray, field) continue } else { + // Eg: id, name if foundKey, _ := gutil.MapPossibleItemByKey(fieldsKeyMap, field); foundKey != "" { outputFieldsArray = append(outputFieldsArray, foundKey) } diff --git a/database/gdb/gdb_z_init_test.go b/database/gdb/gdb_z_init_test.go index 44ad2167e..388c410e6 100644 --- a/database/gdb/gdb_z_init_test.go +++ b/database/gdb/gdb_z_init_test.go @@ -72,7 +72,6 @@ func init() { gtest.Error(err) } db.SetSchema(SCHEMA1) - createTable(TABLE) // Prefix db. if r, err := gdb.New("prefix"); err != nil { @@ -87,7 +86,6 @@ func init() { gtest.Error(err) } dbPrefix.SetSchema(SCHEMA1) - createTable(TABLE) } func createTable(table ...string) string { diff --git a/database/gdb/gdb_z_mysql_model_test.go b/database/gdb/gdb_z_mysql_model_test.go index 1528571b3..297eb417c 100644 --- a/database/gdb/gdb_z_mysql_model_test.go +++ b/database/gdb/gdb_z_mysql_model_test.go @@ -3236,3 +3236,88 @@ func Test_Model_Fields_Map_Struct(t *testing.T) { t.Assert(a.XXX_TYPE, 0) }) } + +func Test_Model_Fields_AutoFilterInJoinStatement(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var err error + table1 := "user" + table2 := "score" + table3 := "info" + if _, err := db.Exec(fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id int(11) NOT NULL AUTO_INCREMENT, + name varchar(500) NOT NULL DEFAULT '', + PRIMARY KEY (id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8 AUTO_INCREMENT=1; + `, table1, + )); err != nil { + t.Assert(err, nil) + } + defer dropTable(table1) + _, err = db.Table(table1).Insert(g.Map{ + "id": 1, + "name": "john", + }) + t.Assert(err, nil) + + if _, err := db.Exec(fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id int(11) NOT NULL AUTO_INCREMENT, + user_id int(11) NOT NULL DEFAULT 0, + number varchar(500) NOT NULL DEFAULT '', + PRIMARY KEY (id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8 AUTO_INCREMENT=1; + `, table2, + )); err != nil { + t.Assert(err, nil) + } + defer dropTable(table2) + _, err = db.Table(table2).Insert(g.Map{ + "id": 1, + "user_id": 1, + "number": "n", + }) + t.Assert(err, nil) + + if _, err := db.Exec(fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id int(11) NOT NULL AUTO_INCREMENT, + user_id int(11) NOT NULL DEFAULT 0, + description varchar(500) NOT NULL DEFAULT '', + PRIMARY KEY (id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8 AUTO_INCREMENT=1; + `, table3, + )); err != nil { + t.Assert(err, nil) + } + defer dropTable(table3) + _, err = db.Table(table3).Insert(g.Map{ + "id": 1, + "user_id": 1, + "description": "brief", + }) + t.Assert(err, nil) + + one, err := db.Table("user"). + Where("user.id", 1). + Fields("score.number,user.name"). + LeftJoin("score", "user.id=score.user_id"). + LeftJoin("info", "info.id=info.user_id"). + Order("user.id asc"). + One() + t.Assert(err, nil) + t.Assert(len(one), 2) + t.Assert(one["name"].String(), "john") + t.Assert(one["number"].String(), "n") + + one, err = db.Table("user"). + LeftJoin("score", "user.id=score.user_id"). + LeftJoin("info", "info.id=info.user_id"). + Fields("score.number,user.name"). + One() + t.Assert(err, nil) + t.Assert(len(one), 2) + t.Assert(one["name"].String(), "john") + t.Assert(one["number"].String(), "n") + }) +}