From 6cddfdb313563f8bda3737387fe146740b2155e4 Mon Sep 17 00:00:00 2001 From: John Guo Date: Tue, 5 Sep 2023 19:29:28 +0800 Subject: [PATCH] improve join feature for package gdb (#2929) --- container/gset/gset_any_set.go | 2 +- contrib/drivers/mysql/mysql_core_test.go | 21 ++---- database/gdb/gdb.go | 21 ++++-- database/gdb/gdb_model_join.go | 81 +++++++++++++----------- 4 files changed, 66 insertions(+), 59 deletions(-) diff --git a/container/gset/gset_any_set.go b/container/gset/gset_any_set.go index 179862ca6..73f412d4b 100644 --- a/container/gset/gset_any_set.go +++ b/container/gset/gset_any_set.go @@ -184,7 +184,7 @@ func (set *Set) Clear() { set.mu.Unlock() } -// Slice returns the an of items of the set as slice. +// Slice returns all items of the set as slice. func (set *Set) Slice() []interface{} { set.mu.RLock() var ( diff --git a/contrib/drivers/mysql/mysql_core_test.go b/contrib/drivers/mysql/mysql_core_test.go index cc07f232f..874056245 100644 --- a/contrib/drivers/mysql/mysql_core_test.go +++ b/contrib/drivers/mysql/mysql_core_test.go @@ -1404,29 +1404,18 @@ func Test_Model_LeftJoin(t *testing.T) { defer dropTable(table2) res, err := db.Model(table2).Where("id > ?", 3).Delete() - if err != nil { - t.Fatal(err) - } + t.AssertNil(err) n, err := res.RowsAffected() - if err != nil { - t.Fatal(err) - } else { - t.Assert(n, 7) - } + t.AssertNil(err) + t.Assert(n, 7) result, err := db.Model(table1+" u1").LeftJoin(table2+" u2", "u1.id = u2.id").All() - if err != nil { - t.Fatal(err) - } - + t.AssertNil(err) t.Assert(len(result), 10) result, err = db.Model(table1+" u1").LeftJoin(table2+" u2", "u1.id = u2.id").Where("u1.id > ? ", 2).All() - if err != nil { - t.Fatal(err) - } - + t.AssertNil(err) t.Assert(len(result), 8) }) } diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 1ea5ffa7d..c2c50144a 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -356,15 +356,10 @@ type CatchSQLManager struct { DoCommit bool } -type queryType int - const ( defaultModelSafe = false defaultCharset = `utf8` defaultProtocol = `tcp` - queryTypeNormal queryType = 0 - queryTypeCount queryType = 1 - queryTypeValue queryType = 2 unionTypeNormal = 0 unionTypeAll = 1 defaultMaxIdleConnCount = 10 // Max idle connection count in pool. @@ -386,6 +381,22 @@ const ( linkPattern = `(\w+):([\w\-]*):(.*?)@(\w+?)\((.+?)\)/{0,1}([^\?]*)\?{0,1}(.*)` ) +type queryType int + +const ( + queryTypeNormal queryType = 0 + queryTypeCount queryType = 1 + queryTypeValue queryType = 2 +) + +type joinOperator string + +const ( + joinOperatorLeft joinOperator = "LEFT" + joinOperatorRight joinOperator = "RIGHT" + joinOperatorInner joinOperator = "INNER" +) + type InsertOption int const ( diff --git a/database/gdb/gdb_model_join.go b/database/gdb/gdb_model_join.go index ce4dbf1ab..eacb0d5a8 100644 --- a/database/gdb/gdb_model_join.go +++ b/database/gdb/gdb_model_join.go @@ -12,17 +12,6 @@ import ( "github.com/gogf/gf/v2/text/gstr" ) -// isSubQuery checks and returns whether given string a sub-query sql string. -func isSubQuery(s string) bool { - s = gstr.TrimLeft(s, "()") - if p := gstr.Pos(s, " "); p != -1 { - if gstr.Equal(s[:p], "select") { - return true - } - } - return false -} - // LeftJoin does "LEFT JOIN ... ON ..." statement on the model. // The parameter `table` can be joined table and its joined condition, // and also with its alias name. @@ -31,8 +20,8 @@ func isSubQuery(s string) bool { // Model("user").LeftJoin("user_detail", "user_detail.uid=user.uid") // Model("user", "u").LeftJoin("user_detail", "ud", "ud.uid=u.uid") // Model("user", "u").LeftJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid"). -func (m *Model) LeftJoin(tableAliasOrSubquery ...string) *Model { - return m.doJoin("LEFT", tableAliasOrSubquery...) +func (m *Model) LeftJoin(tableOrSubQueryAndJoinConditions ...string) *Model { + return m.doJoin(joinOperatorLeft, tableOrSubQueryAndJoinConditions...) } // RightJoin does "RIGHT JOIN ... ON ..." statement on the model. @@ -43,8 +32,8 @@ func (m *Model) LeftJoin(tableAliasOrSubquery ...string) *Model { // Model("user").RightJoin("user_detail", "user_detail.uid=user.uid") // Model("user", "u").RightJoin("user_detail", "ud", "ud.uid=u.uid") // Model("user", "u").RightJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid"). -func (m *Model) RightJoin(tableAliasOrSubquery ...string) *Model { - return m.doJoin("RIGHT", tableAliasOrSubquery...) +func (m *Model) RightJoin(tableOrSubQueryAndJoinConditions ...string) *Model { + return m.doJoin(joinOperatorRight, tableOrSubQueryAndJoinConditions...) } // InnerJoin does "INNER JOIN ... ON ..." statement on the model. @@ -55,17 +44,17 @@ func (m *Model) RightJoin(tableAliasOrSubquery ...string) *Model { // Model("user").InnerJoin("user_detail", "user_detail.uid=user.uid") // Model("user", "u").InnerJoin("user_detail", "ud", "ud.uid=u.uid") // Model("user", "u").InnerJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid"). -func (m *Model) InnerJoin(tableAliasOrSubquery ...string) *Model { - return m.doJoin("INNER", tableAliasOrSubquery...) +func (m *Model) InnerJoin(tableOrSubQueryAndJoinConditions ...string) *Model { + return m.doJoin(joinOperatorInner, tableOrSubQueryAndJoinConditions...) } -// LeftJoinOnField performs as LeftJoin, but it joins both tables with the same field name. +// LeftJoinOnField performs as LeftJoin, but it joins both tables with the `same field name`. // // Eg: // Model("order").LeftJoinOnField("user", "user_id") // Model("order").LeftJoinOnField("product", "product_id"). func (m *Model) LeftJoinOnField(table, field string) *Model { - return m.doJoin("LEFT", table, fmt.Sprintf( + return m.doJoin(joinOperatorLeft, table, fmt.Sprintf( `%s.%s=%s.%s`, m.tablesInit, m.db.GetCore().QuoteWord(field), @@ -74,13 +63,13 @@ func (m *Model) LeftJoinOnField(table, field string) *Model { )) } -// RightJoinOnField performs as RightJoin, but it joins both tables with the same field name. +// RightJoinOnField performs as RightJoin, but it joins both tables with the `same field name`. // // Eg: // Model("order").InnerJoinOnField("user", "user_id") // Model("order").InnerJoinOnField("product", "product_id"). func (m *Model) RightJoinOnField(table, field string) *Model { - return m.doJoin("RIGHT", table, fmt.Sprintf( + return m.doJoin(joinOperatorRight, table, fmt.Sprintf( `%s.%s=%s.%s`, m.tablesInit, m.db.GetCore().QuoteWord(field), @@ -89,13 +78,13 @@ func (m *Model) RightJoinOnField(table, field string) *Model { )) } -// InnerJoinOnField performs as InnerJoin, but it joins both tables with the same field name. +// InnerJoinOnField performs as InnerJoin, but it joins both tables with the `same field name`. // // Eg: // Model("order").InnerJoinOnField("user", "user_id") // Model("order").InnerJoinOnField("product", "product_id"). func (m *Model) InnerJoinOnField(table, field string) *Model { - return m.doJoin("INNER", table, fmt.Sprintf( + return m.doJoin(joinOperatorInner, table, fmt.Sprintf( `%s.%s=%s.%s`, m.tablesInit, m.db.GetCore().QuoteWord(field), @@ -111,7 +100,7 @@ func (m *Model) InnerJoinOnField(table, field string) *Model { // Model("user").LeftJoinOnFields("order", "id", ">", "user_id") // Model("user").LeftJoinOnFields("order", "id", "<", "user_id") func (m *Model) LeftJoinOnFields(table, firstField, operator, secondField string) *Model { - return m.doJoin("LEFT", table, fmt.Sprintf( + return m.doJoin(joinOperatorLeft, table, fmt.Sprintf( `%s.%s %s %s.%s`, m.tablesInit, m.db.GetCore().QuoteWord(firstField), @@ -128,7 +117,7 @@ func (m *Model) LeftJoinOnFields(table, firstField, operator, secondField string // Model("user").RightJoinOnFields("order", "id", ">", "user_id") // Model("user").RightJoinOnFields("order", "id", "<", "user_id") func (m *Model) RightJoinOnFields(table, firstField, operator, secondField string) *Model { - return m.doJoin("RIGHT", table, fmt.Sprintf( + return m.doJoin(joinOperatorRight, table, fmt.Sprintf( `%s.%s %s %s.%s`, m.tablesInit, m.db.GetCore().QuoteWord(firstField), @@ -145,7 +134,7 @@ func (m *Model) RightJoinOnFields(table, firstField, operator, secondField strin // Model("user").InnerJoinOnFields("order", "id", ">", "user_id") // Model("user").InnerJoinOnFields("order", "id", "<", "user_id") func (m *Model) InnerJoinOnFields(table, firstField, operator, secondField string) *Model { - return m.doJoin("INNER", table, fmt.Sprintf( + return m.doJoin(joinOperatorInner, table, fmt.Sprintf( `%s.%s %s %s.%s`, m.tablesInit, m.db.GetCore().QuoteWord(firstField), @@ -156,44 +145,62 @@ func (m *Model) InnerJoinOnFields(table, firstField, operator, secondField strin } // doJoin does "LEFT/RIGHT/INNER JOIN ... ON ..." statement on the model. -// The parameter `table` can be joined table and its joined condition, +// The parameter `tableOrSubQueryAndJoinConditions` can be joined table and its joined condition, // and also with its alias name. // // Eg: // Model("user").InnerJoin("user_detail", "user_detail.uid=user.uid") // Model("user", "u").InnerJoin("user_detail", "ud", "ud.uid=u.uid") +// Model("user", "u").InnerJoin("user_detail", "ud", "ud.uid>u.uid") // Model("user", "u").InnerJoin("SELECT xxx FROM xxx","a", "a.uid=u.uid") // Related issues: // https://github.com/gogf/gf/issues/1024 -func (m *Model) doJoin(operator string, table ...string) *Model { +func (m *Model) doJoin(operator joinOperator, tableOrSubQueryAndJoinConditions ...string) *Model { var ( model = m.getModel() joinStr = "" ) - if len(table) > 0 { - if isSubQuery(table[0]) { - joinStr = gstr.Trim(table[0]) + // Check the first parameter table or sub-query. + if len(tableOrSubQueryAndJoinConditions) > 0 { + if isSubQuery(tableOrSubQueryAndJoinConditions[0]) { + joinStr = gstr.Trim(tableOrSubQueryAndJoinConditions[0]) if joinStr[0] != '(' { joinStr = "(" + joinStr + ")" } } else { - joinStr = m.db.GetCore().QuotePrefixTableName(table[0]) + joinStr = m.db.GetCore().QuotePrefixTableName(tableOrSubQueryAndJoinConditions[0]) } } - if len(table) > 2 { + // Generate join condition statement string. + conditionLength := len(tableOrSubQueryAndJoinConditions) + switch { + case conditionLength > 2: model.tables += fmt.Sprintf( " %s JOIN %s AS %s ON (%s)", - operator, joinStr, m.db.GetCore().QuoteWord(table[1]), table[2], + operator, joinStr, + m.db.GetCore().QuoteWord(tableOrSubQueryAndJoinConditions[1]), + tableOrSubQueryAndJoinConditions[2], ) - } else if len(table) == 2 { + case conditionLength == 2: model.tables += fmt.Sprintf( " %s JOIN %s ON (%s)", - operator, joinStr, table[1], + operator, joinStr, tableOrSubQueryAndJoinConditions[1], ) - } else if len(table) == 1 { + case conditionLength == 1: model.tables += fmt.Sprintf( " %s JOIN %s", operator, joinStr, ) } return model } + +// isSubQuery checks and returns whether given string a sub-query sql string. +func isSubQuery(s string) bool { + s = gstr.TrimLeft(s, "()") + if p := gstr.Pos(s, " "); p != -1 { + if gstr.Equal(s[:p], "select") { + return true + } + } + return false +}