automatically add column prefix for where conditions

This commit is contained in:
John Guo
2021-10-29 16:57:56 +08:00
parent 09b8df1818
commit 1188793f8f
7 changed files with 84 additions and 55 deletions

View File

@ -151,7 +151,7 @@ func (c *Core) convertFieldValueToLocalValue(fieldValue interface{}, fieldType s
// mappingAndFilterData automatically mappings the map key to table field and removes
// all key-value pairs that are not the field of given table.
func (c *Core) mappingAndFilterData(schema, table string, data map[string]interface{}, filter bool) (map[string]interface{}, error) {
if fieldsMap, err := c.db.TableFields(c.GetCtx(), table, schema); err == nil {
if fieldsMap, err := c.db.TableFields(c.GetCtx(), c.guessPrimaryTableName(table), schema); err == nil {
fieldsKeyMap := make(map[string]interface{}, len(fieldsMap))
for k, _ := range fieldsMap {
fieldsKeyMap[k] = nil

View File

@ -7,6 +7,13 @@
package gdb
import (
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
)
// MasterLink acts like function Master but with additional `schema` parameter specifying
// the schema for the connection. It is defined for internal usage.
// Also see Master.
@ -29,8 +36,10 @@ func (c *Core) SlaveLink(schema ...string) (Link, error) {
return &dbLink{db}, nil
}
// QuoteWord checks given string `s` a word, if true quotes it with security chars of the database
// and returns the quoted string; or else return `s` without any change.
// QuoteWord checks given string `s` a word,
// if true it quotes `s` with security chars of the database
// and returns the quoted string; or else it returns `s` without any change.
//
// The meaning of a `word` can be considered as a column name.
func (c *Core) QuoteWord(s string) string {
charLeft, charRight := c.db.GetChars()
@ -39,6 +48,7 @@ func (c *Core) QuoteWord(s string) string {
// QuoteString quotes string with quote chars. Strings like:
// "user", "user u", "user,user_detail", "user u, user_detail ut", "u.id asc".
//
// The meaning of a `string` can be considered as part of a statement string including columns.
func (c *Core) QuoteString(s string) string {
charLeft, charRight := c.db.GetChars()
@ -84,3 +94,56 @@ func (c *Core) Tables(schema ...string) (tables []string, err error) {
func (c *Core) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) {
return
}
// HasField determine whether the field exists in the table.
func (c *Core) HasField(table, field string, schema ...string) (bool, error) {
table = c.guessPrimaryTableName(table)
tableFields, err := c.db.TableFields(c.GetCtx(), table, schema...)
if err != nil {
return false, err
}
if len(tableFields) == 0 {
return false, gerror.NewCodef(
gcode.CodeNotFound,
`empty table fields for table "%s"`, table,
)
}
fieldsArray := make([]string, len(tableFields))
for k, v := range tableFields {
fieldsArray[v.Index] = k
}
charLeft, charRight := c.db.GetChars()
field = gstr.Trim(field, charLeft+charRight)
for _, f := range fieldsArray {
if f == field {
return true, nil
}
}
return false, nil
}
// guessPrimaryTableName parses and returns the primary table name.
func (c *Core) guessPrimaryTableName(tableStr string) string {
if tableStr == "" {
return ""
}
var (
guessedTableName = ""
array1 = gstr.SplitAndTrim(tableStr, ",")
array2 = gstr.SplitAndTrim(array1[0], " ")
array3 = gstr.SplitAndTrim(array2[0], ".")
)
if len(array3) >= 2 {
guessedTableName = array3[1]
} else {
guessedTableName = array3[0]
}
charL, charR := c.db.GetChars()
if charL != "" || charR != "" {
guessedTableName = gstr.Trim(guessedTableName, charL+charR)
}
if !gregex.IsMatchString(regularFieldNameRegPattern, guessedTableName) {
return ""
}
return guessedTableName
}

View File

@ -58,8 +58,6 @@ type iTableName interface {
const (
OrmTagForStruct = "orm"
OrmTagForUnique = "unique"
OrmTagForPrimary = "primary"
OrmTagForTable = "table"
OrmTagForWith = "with"
OrmTagForWithWhere = "where"
@ -74,32 +72,6 @@ var (
structTagPriority = append([]string{OrmTagForStruct}, gconv.StructTagPriority...)
)
// guessPrimaryTableName parses and returns the primary table name.
func (m *Model) guessPrimaryTableName(tableStr string) string {
if tableStr == "" {
return ""
}
var (
guessedTableName = ""
array1 = gstr.SplitAndTrim(tableStr, ",")
array2 = gstr.SplitAndTrim(array1[0], " ")
array3 = gstr.SplitAndTrim(array2[0], ".")
)
if len(array3) >= 2 {
guessedTableName = array3[1]
} else {
guessedTableName = array3[0]
}
charL, charR := m.db.GetChars()
if charL != "" || charR != "" {
guessedTableName = gstr.Trim(guessedTableName, charL+charR)
}
if !gregex.IsMatchString(regularFieldNameRegPattern, guessedTableName) {
return ""
}
return guessedTableName
}
// getTableNameFromOrmTag retrieves and returns the table name from struct object.
func getTableNameFromOrmTag(object interface{}) string {
var tableName string
@ -524,6 +496,13 @@ func formatWhere(db DB, in formatWhereInput) (newWhere string, newArgs []interfa
in.Args = in.Args[:0]
break
}
// If the first part is column name, it automatically adds prefix to the column.
if in.Prefix != "" {
array := gstr.Split(whereStr, " ")
if ok, _ := db.GetCore().HasField(in.Table, array[0]); ok {
whereStr = in.Prefix + "." + whereStr
}
}
// Regular string and parameter place holder handling.
// Eg:
// Where("id in(?) and name=?", g.Slice{1,2,3}, "john")

View File

@ -9,8 +9,6 @@ package gdb
import (
"fmt"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
)
@ -244,21 +242,5 @@ func (m *Model) GetFieldsExStr(fields string, prefix ...string) string {
// HasField determine whether the field exists in the table.
func (m *Model) HasField(field string) (bool, error) {
tableFields, err := m.TableFields(m.tablesInit)
if err != nil {
return false, err
}
if len(tableFields) == 0 {
return false, gerror.NewCodef(gcode.CodeNotFound, `empty table fields for table "%s"`, m.tables)
}
fieldsArray := make([]string, len(tableFields))
for k, v := range tableFields {
fieldsArray[v.Index] = k
}
for _, f := range fieldsArray {
if f == field {
return true, nil
}
}
return false, nil
return m.db.GetCore().HasField(m.tablesInit, field)
}

View File

@ -26,7 +26,11 @@ func (m *Model) TableFields(tableStr string, schema ...string) (fields map[strin
if len(schema) > 0 && schema[0] != "" {
useSchema = schema[0]
}
return m.db.TableFields(m.GetCtx(), m.guessPrimaryTableName(tableStr), useSchema)
return m.db.TableFields(
m.GetCtx(),
m.db.GetCore().guessPrimaryTableName(tableStr),
useSchema,
)
}
// getModel creates and returns a cloned model of current model if `safe` is true, or else it returns
@ -104,7 +108,7 @@ func (m *Model) filterDataForInsertOrUpdate(data interface{}) (interface{}, erro
func (m *Model) doMappingAndFilterForInsertOrUpdateDataMap(data Map, allowOmitEmpty bool) (Map, error) {
var err error
data, err = m.db.GetCore().mappingAndFilterData(
m.schema, m.guessPrimaryTableName(m.tablesInit), data, m.filter,
m.schema, m.tablesInit, data, m.filter,
)
if err != nil {
return nil, err

View File

@ -251,6 +251,7 @@ CREATE TABLE %s (
model := db.Model(fmt.Sprintf(`%s as t`, table1))
t.Assert(model.getConditionForSoftDeleting(), "`delete_at` IS NULL")
})
gtest.C(t, func(t *gtest.T) {
model := db.Model(fmt.Sprintf(`%s, %s`, table1, table2))
t.Assert(model.getConditionForSoftDeleting(), fmt.Sprintf(

View File

@ -27,7 +27,7 @@ func Test_Model_LeftJoinOnField(t *testing.T) {
r, err := db.Model(table1).
FieldsPrefix(table1, "*").
LeftJoinOnField(table2, "id").
Where("id", g.Slice{1, 2}).
WhereIn("id", g.Slice{1, 2}).
Order("id asc").All()
t.AssertNil(err)
t.Assert(len(r), 2)
@ -50,7 +50,7 @@ func Test_Model_RightJoinOnField(t *testing.T) {
r, err := db.Model(table1).
FieldsPrefix(table1, "*").
RightJoinOnField(table2, "id").
Where("id", g.Slice{1, 2}).
WhereIn("id", g.Slice{1, 2}).
Order("id asc").All()
t.AssertNil(err)
t.Assert(len(r), 2)
@ -73,7 +73,7 @@ func Test_Model_InnerJoinOnField(t *testing.T) {
r, err := db.Model(table1).
FieldsPrefix(table1, "*").
InnerJoinOnField(table2, "id").
Where("id", g.Slice{1, 2}).
WhereIn("id", g.Slice{1, 2}).
Order("id asc").All()
t.AssertNil(err)
t.Assert(len(r), 2)