From 1188793f8f4eb3b2f07aefbd49eef2dbec6c3b70 Mon Sep 17 00:00:00 2001 From: John Guo Date: Fri, 29 Oct 2021 16:57:56 +0800 Subject: [PATCH] automatically add column prefix for where conditions --- database/gdb/gdb_core_structure.go | 2 +- database/gdb/gdb_core_utility.go | 67 ++++++++++++++++++++- database/gdb/gdb_func.go | 35 +++-------- database/gdb/gdb_model_fields.go | 20 +----- database/gdb/gdb_model_utility.go | 8 ++- database/gdb/gdb_z_mysql_internal_test.go | 1 + database/gdb/gdb_z_mysql_model_join_test.go | 6 +- 7 files changed, 84 insertions(+), 55 deletions(-) diff --git a/database/gdb/gdb_core_structure.go b/database/gdb/gdb_core_structure.go index 2d990e3af..6df0cd5a6 100644 --- a/database/gdb/gdb_core_structure.go +++ b/database/gdb/gdb_core_structure.go @@ -151,7 +151,7 @@ func (c *Core) convertFieldValueToLocalValue(fieldValue interface{}, fieldType s // mappingAndFilterData automatically mappings the map key to table field and removes // all key-value pairs that are not the field of given table. func (c *Core) mappingAndFilterData(schema, table string, data map[string]interface{}, filter bool) (map[string]interface{}, error) { - if fieldsMap, err := c.db.TableFields(c.GetCtx(), table, schema); err == nil { + if fieldsMap, err := c.db.TableFields(c.GetCtx(), c.guessPrimaryTableName(table), schema); err == nil { fieldsKeyMap := make(map[string]interface{}, len(fieldsMap)) for k, _ := range fieldsMap { fieldsKeyMap[k] = nil diff --git a/database/gdb/gdb_core_utility.go b/database/gdb/gdb_core_utility.go index 5b0f1973a..9ddfd1489 100644 --- a/database/gdb/gdb_core_utility.go +++ b/database/gdb/gdb_core_utility.go @@ -7,6 +7,13 @@ package gdb +import ( + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/text/gregex" + "github.com/gogf/gf/v2/text/gstr" +) + // MasterLink acts like function Master but with additional `schema` parameter specifying // the schema for the connection. It is defined for internal usage. // Also see Master. @@ -29,8 +36,10 @@ func (c *Core) SlaveLink(schema ...string) (Link, error) { return &dbLink{db}, nil } -// QuoteWord checks given string `s` a word, if true quotes it with security chars of the database -// and returns the quoted string; or else return `s` without any change. +// QuoteWord checks given string `s` a word, +// if true it quotes `s` with security chars of the database +// and returns the quoted string; or else it returns `s` without any change. +// // The meaning of a `word` can be considered as a column name. func (c *Core) QuoteWord(s string) string { charLeft, charRight := c.db.GetChars() @@ -39,6 +48,7 @@ func (c *Core) QuoteWord(s string) string { // QuoteString quotes string with quote chars. Strings like: // "user", "user u", "user,user_detail", "user u, user_detail ut", "u.id asc". +// // The meaning of a `string` can be considered as part of a statement string including columns. func (c *Core) QuoteString(s string) string { charLeft, charRight := c.db.GetChars() @@ -84,3 +94,56 @@ func (c *Core) Tables(schema ...string) (tables []string, err error) { func (c *Core) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { return } + +// HasField determine whether the field exists in the table. +func (c *Core) HasField(table, field string, schema ...string) (bool, error) { + table = c.guessPrimaryTableName(table) + tableFields, err := c.db.TableFields(c.GetCtx(), table, schema...) + if err != nil { + return false, err + } + if len(tableFields) == 0 { + return false, gerror.NewCodef( + gcode.CodeNotFound, + `empty table fields for table "%s"`, table, + ) + } + fieldsArray := make([]string, len(tableFields)) + for k, v := range tableFields { + fieldsArray[v.Index] = k + } + charLeft, charRight := c.db.GetChars() + field = gstr.Trim(field, charLeft+charRight) + for _, f := range fieldsArray { + if f == field { + return true, nil + } + } + return false, nil +} + +// guessPrimaryTableName parses and returns the primary table name. +func (c *Core) guessPrimaryTableName(tableStr string) string { + if tableStr == "" { + return "" + } + var ( + guessedTableName = "" + array1 = gstr.SplitAndTrim(tableStr, ",") + array2 = gstr.SplitAndTrim(array1[0], " ") + array3 = gstr.SplitAndTrim(array2[0], ".") + ) + if len(array3) >= 2 { + guessedTableName = array3[1] + } else { + guessedTableName = array3[0] + } + charL, charR := c.db.GetChars() + if charL != "" || charR != "" { + guessedTableName = gstr.Trim(guessedTableName, charL+charR) + } + if !gregex.IsMatchString(regularFieldNameRegPattern, guessedTableName) { + return "" + } + return guessedTableName +} diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index 056fb3409..aae5b68d5 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -58,8 +58,6 @@ type iTableName interface { const ( OrmTagForStruct = "orm" - OrmTagForUnique = "unique" - OrmTagForPrimary = "primary" OrmTagForTable = "table" OrmTagForWith = "with" OrmTagForWithWhere = "where" @@ -74,32 +72,6 @@ var ( structTagPriority = append([]string{OrmTagForStruct}, gconv.StructTagPriority...) ) -// guessPrimaryTableName parses and returns the primary table name. -func (m *Model) guessPrimaryTableName(tableStr string) string { - if tableStr == "" { - return "" - } - var ( - guessedTableName = "" - array1 = gstr.SplitAndTrim(tableStr, ",") - array2 = gstr.SplitAndTrim(array1[0], " ") - array3 = gstr.SplitAndTrim(array2[0], ".") - ) - if len(array3) >= 2 { - guessedTableName = array3[1] - } else { - guessedTableName = array3[0] - } - charL, charR := m.db.GetChars() - if charL != "" || charR != "" { - guessedTableName = gstr.Trim(guessedTableName, charL+charR) - } - if !gregex.IsMatchString(regularFieldNameRegPattern, guessedTableName) { - return "" - } - return guessedTableName -} - // getTableNameFromOrmTag retrieves and returns the table name from struct object. func getTableNameFromOrmTag(object interface{}) string { var tableName string @@ -524,6 +496,13 @@ func formatWhere(db DB, in formatWhereInput) (newWhere string, newArgs []interfa in.Args = in.Args[:0] break } + // If the first part is column name, it automatically adds prefix to the column. + if in.Prefix != "" { + array := gstr.Split(whereStr, " ") + if ok, _ := db.GetCore().HasField(in.Table, array[0]); ok { + whereStr = in.Prefix + "." + whereStr + } + } // Regular string and parameter place holder handling. // Eg: // Where("id in(?) and name=?", g.Slice{1,2,3}, "john") diff --git a/database/gdb/gdb_model_fields.go b/database/gdb/gdb_model_fields.go index e2791a394..c650d3503 100644 --- a/database/gdb/gdb_model_fields.go +++ b/database/gdb/gdb_model_fields.go @@ -9,8 +9,6 @@ package gdb import ( "fmt" "github.com/gogf/gf/v2/container/gset" - "github.com/gogf/gf/v2/errors/gcode" - "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/text/gstr" "github.com/gogf/gf/v2/util/gconv" ) @@ -244,21 +242,5 @@ func (m *Model) GetFieldsExStr(fields string, prefix ...string) string { // HasField determine whether the field exists in the table. func (m *Model) HasField(field string) (bool, error) { - tableFields, err := m.TableFields(m.tablesInit) - if err != nil { - return false, err - } - if len(tableFields) == 0 { - return false, gerror.NewCodef(gcode.CodeNotFound, `empty table fields for table "%s"`, m.tables) - } - fieldsArray := make([]string, len(tableFields)) - for k, v := range tableFields { - fieldsArray[v.Index] = k - } - for _, f := range fieldsArray { - if f == field { - return true, nil - } - } - return false, nil + return m.db.GetCore().HasField(m.tablesInit, field) } diff --git a/database/gdb/gdb_model_utility.go b/database/gdb/gdb_model_utility.go index b01438a62..1d0914e0a 100644 --- a/database/gdb/gdb_model_utility.go +++ b/database/gdb/gdb_model_utility.go @@ -26,7 +26,11 @@ func (m *Model) TableFields(tableStr string, schema ...string) (fields map[strin if len(schema) > 0 && schema[0] != "" { useSchema = schema[0] } - return m.db.TableFields(m.GetCtx(), m.guessPrimaryTableName(tableStr), useSchema) + return m.db.TableFields( + m.GetCtx(), + m.db.GetCore().guessPrimaryTableName(tableStr), + useSchema, + ) } // getModel creates and returns a cloned model of current model if `safe` is true, or else it returns @@ -104,7 +108,7 @@ func (m *Model) filterDataForInsertOrUpdate(data interface{}) (interface{}, erro func (m *Model) doMappingAndFilterForInsertOrUpdateDataMap(data Map, allowOmitEmpty bool) (Map, error) { var err error data, err = m.db.GetCore().mappingAndFilterData( - m.schema, m.guessPrimaryTableName(m.tablesInit), data, m.filter, + m.schema, m.tablesInit, data, m.filter, ) if err != nil { return nil, err diff --git a/database/gdb/gdb_z_mysql_internal_test.go b/database/gdb/gdb_z_mysql_internal_test.go index e8ec33c0e..ad90d3b54 100644 --- a/database/gdb/gdb_z_mysql_internal_test.go +++ b/database/gdb/gdb_z_mysql_internal_test.go @@ -251,6 +251,7 @@ CREATE TABLE %s ( model := db.Model(fmt.Sprintf(`%s as t`, table1)) t.Assert(model.getConditionForSoftDeleting(), "`delete_at` IS NULL") }) + gtest.C(t, func(t *gtest.T) { model := db.Model(fmt.Sprintf(`%s, %s`, table1, table2)) t.Assert(model.getConditionForSoftDeleting(), fmt.Sprintf( diff --git a/database/gdb/gdb_z_mysql_model_join_test.go b/database/gdb/gdb_z_mysql_model_join_test.go index 4230042a8..5ffb07ae8 100644 --- a/database/gdb/gdb_z_mysql_model_join_test.go +++ b/database/gdb/gdb_z_mysql_model_join_test.go @@ -27,7 +27,7 @@ func Test_Model_LeftJoinOnField(t *testing.T) { r, err := db.Model(table1). FieldsPrefix(table1, "*"). LeftJoinOnField(table2, "id"). - Where("id", g.Slice{1, 2}). + WhereIn("id", g.Slice{1, 2}). Order("id asc").All() t.AssertNil(err) t.Assert(len(r), 2) @@ -50,7 +50,7 @@ func Test_Model_RightJoinOnField(t *testing.T) { r, err := db.Model(table1). FieldsPrefix(table1, "*"). RightJoinOnField(table2, "id"). - Where("id", g.Slice{1, 2}). + WhereIn("id", g.Slice{1, 2}). Order("id asc").All() t.AssertNil(err) t.Assert(len(r), 2) @@ -73,7 +73,7 @@ func Test_Model_InnerJoinOnField(t *testing.T) { r, err := db.Model(table1). FieldsPrefix(table1, "*"). InnerJoinOnField(table2, "id"). - Where("id", g.Slice{1, 2}). + WhereIn("id", g.Slice{1, 2}). Order("id asc").All() t.AssertNil(err) t.Assert(len(r), 2)