diff --git a/contrib/drivers/README.MD b/contrib/drivers/README.MD index 37a8a5407..c79d60926 100644 --- a/contrib/drivers/README.MD +++ b/contrib/drivers/README.MD @@ -52,7 +52,7 @@ Note: import _ "github.com/gogf/gf/contrib/drivers/mssql/v2" ``` Note: -- It does not support `Save/Replace` features. +- It does not support `Replace` features. - It does not support `LastInsertId`. - It supports server version >= `SQL Server2005` - It ONLY supports datetime2 and datetimeoffset types for auto handling created_at/updated_at/deleted_at columns, because datetime type does not support microseconds precision when column value is passed as string. @@ -62,7 +62,7 @@ Note: import _ "github.com/gogf/gf/contrib/drivers/oracle/v2" ``` Note: -- It does not support `Save/Replace` features. +- It does not support `Replace` features. - It does not support `LastInsertId`. ## ClickHouse diff --git a/contrib/drivers/mssql/mssql.go b/contrib/drivers/mssql/mssql.go index 0d65d80e4..20b105af6 100644 --- a/contrib/drivers/mssql/mssql.go +++ b/contrib/drivers/mssql/mssql.go @@ -7,7 +7,7 @@ // Package mssql implements gdb.Driver, which supports operations for database MSSql. // // Note: -// 1. It does not support Save/Replace features. +// 1. It does not support Replace features. // 2. It does not support LastInsertId. package mssql diff --git a/contrib/drivers/mssql/mssql_do_insert.go b/contrib/drivers/mssql/mssql_do_insert.go index 02a3facaa..7b71dab42 100644 --- a/contrib/drivers/mssql/mssql_do_insert.go +++ b/contrib/drivers/mssql/mssql_do_insert.go @@ -9,20 +9,21 @@ package mssql import ( "context" "database/sql" + "fmt" + "strings" + "github.com/gogf/gf/v2/container/gset" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/text/gstr" ) // DoInsert inserts or updates data for given table. func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption) (result sql.Result, err error) { switch option.InsertOption { case gdb.InsertOptionSave: - return nil, gerror.NewCode( - gcode.CodeNotSupported, - `Save operation is not supported by mssql driver`, - ) + return d.doSave(ctx, link, table, list, option) case gdb.InsertOptionReplace: return nil, gerror.NewCode( @@ -34,3 +35,116 @@ func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list return d.Core.DoInsert(ctx, link, table, list, option) } } + +// doSave support upsert for SQL server +func (d *Driver) doSave(ctx context.Context, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, +) (result sql.Result, err error) { + if len(option.OnConflict) == 0 { + return nil, gerror.NewCode( + gcode.CodeMissingParameter, `Please specify conflict columns`, + ) + } + + if len(list) == 0 { + return nil, gerror.NewCode( + gcode.CodeInvalidRequest, `Save operation list is empty by mssql driver`, + ) + } + + var ( + one = list[0] + oneLen = len(one) + charL, charR = d.GetChars() + + conflictKeys = option.OnConflict + conflictKeySet = gset.New(false) + + // queryHolders: Handle data with Holder that need to be upsert + // queryValues: Handle data that need to be upsert + // insertKeys: Handle valid keys that need to be inserted + // insertValues: Handle values that need to be inserted + // updateValues: Handle values that need to be updated + queryHolders = make([]string, oneLen) + queryValues = make([]interface{}, oneLen) + insertKeys = make([]string, oneLen) + insertValues = make([]string, oneLen) + updateValues []string + ) + + // conflictKeys slice type conv to set type + for _, conflictKey := range conflictKeys { + conflictKeySet.Add(gstr.ToUpper(conflictKey)) + } + + index := 0 + for key, value := range one { + queryHolders[index] = "?" + queryValues[index] = value + insertKeys[index] = charL + key + charR + insertValues[index] = "T2." + charL + key + charR + + // filter conflict keys in updateValues. + // And the key is not a soft created field. + if !(conflictKeySet.Contains(key) || d.Core.IsSoftCreatedFieldName(key)) { + updateValues = append( + updateValues, + fmt.Sprintf(`T1.%s = T2.%s`, charL+key+charR, charL+key+charR), + ) + } + index++ + } + + batchResult := new(gdb.SqlResult) + sqlStr := parseSqlForUpsert(table, queryHolders, insertKeys, insertValues, updateValues, conflictKeys) + r, err := d.DoExec(ctx, link, sqlStr, queryValues...) + if err != nil { + return r, err + } + if n, err := r.RowsAffected(); err != nil { + return r, err + } else { + batchResult.Result = r + batchResult.Affected += n + } + return batchResult, nil +} + +// parseSqlForUpsert +// MERGE INTO {{table}} T1 +// USING ( VALUES( {{queryHolders}}) T2 ({{insertKeyStr}}) +// ON (T1.{{duplicateKey}} = T2.{{duplicateKey}} AND ...) +// WHEN NOT MATCHED THEN +// INSERT {{insertKeys}} VALUES {{insertValues}} +// WHEN MATCHED THEN +// UPDATE SET {{updateValues}} +func parseSqlForUpsert(table string, + queryHolders, insertKeys, insertValues, updateValues, duplicateKey []string, +) (sqlStr string) { + var ( + queryHolderStr = strings.Join(queryHolders, ",") + insertKeyStr = strings.Join(insertKeys, ",") + insertValueStr = strings.Join(insertValues, ",") + updateValueStr = strings.Join(updateValues, ",") + duplicateKeyStr string + pattern = gstr.Trim(`MERGE INTO %s T1 USING (VALUES(%s)) T2 (%s) ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s;`) + ) + + for index, keys := range duplicateKey { + if index != 0 { + duplicateKeyStr += " AND " + } + duplicateTmp := fmt.Sprintf("T1.%s = T2.%s", keys, keys) + duplicateKeyStr += duplicateTmp + } + + return fmt.Sprintf(pattern, + table, + queryHolderStr, + insertKeyStr, + duplicateKeyStr, + insertKeyStr, + insertValueStr, + updateValueStr, + ) +} diff --git a/contrib/drivers/mssql/mssql_z_unit_model_test.go b/contrib/drivers/mssql/mssql_z_unit_model_test.go index fe092043b..5a3a94aa5 100644 --- a/contrib/drivers/mssql/mssql_z_unit_model_test.go +++ b/contrib/drivers/mssql/mssql_z_unit_model_test.go @@ -24,10 +24,10 @@ import ( "github.com/gogf/gf/v2/util/gutil" ) -func TestPage(t *testing.T) { +func Test_Page(t *testing.T) { table := createInitTable() defer dropTable(table) - //db.SetDebug(true) + // db.SetDebug(true) result, err := db.Model(table).Page(1, 2).Order("id").All() gtest.Assert(err, nil) fmt.Println("page:1--------", result) @@ -2588,3 +2588,80 @@ func Test_Model_ScanAndCount(t *testing.T) { t.Assert(total, TableSize) }) } + +func Test_Model_Save(t *testing.T) { + table := createTable("test") + defer dropTable(table) + gtest.C(t, func(t *gtest.T) { + type User struct { + Id int + Passport string + Password string + NickName string + CreatedAt *gtime.Time + UpdatedAt *gtime.Time + } + var ( + user User + count int + result sql.Result + err error + ) + + result, err = db.Model(table).Data(g.Map{ + "id": 1, + "passport": "p1", + "password": "15d55ad283aa400af464c76d713c07ad", + "nickname": "n1", + }).OnConflict("id").Save() + + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + err = db.Model(table).Scan(&user) + t.AssertNil(err) + t.Assert(user.Id, 1) + t.Assert(user.Passport, "p1") + t.Assert(user.Password, "15d55ad283aa400af464c76d713c07ad") + t.Assert(user.NickName, "n1") + + // Sleep 1 second to make sure the updated time is different. + time.Sleep(1 * time.Second) + _, err = db.Model(table).Data(g.Map{ + "id": 1, + "passport": "p1", + "password": "25d55ad283aa400af464c76d713c07ad", + "nickname": "n2", + }).OnConflict("id").Save() + t.AssertNil(err) + + err = db.Model(table).Scan(&user) + t.AssertNil(err) + t.Assert(user.Passport, "p1") + t.Assert(user.Password, "25d55ad283aa400af464c76d713c07ad") + t.Assert(user.NickName, "n2") + // check created_at not equal to updated_at + t.AssertNE(user.CreatedAt, user.UpdatedAt) + + count, err = db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, 1) + }) +} + +func Test_Model_Replace(t *testing.T) { + table := createTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + _, err := db.Model(table).Data(g.Map{ + "id": 1, + "passport": "t11", + "password": "25d55ad283aa400af464c76d713c07ad", + "nickname": "T11", + "create_time": "2018-10-24 10:00:00", + }).Replace() + t.Assert(err, "Replace operation is not supported by mssql driver") + }) +}