From 5f664f331aadc1e698e54572d872260cfe392e4d Mon Sep 17 00:00:00 2001 From: John Guo Date: Mon, 8 Dec 2025 16:28:34 +0800 Subject: [PATCH] up --- contrib/drivers/dm/dm_do_insert.go | 26 ++------ contrib/drivers/mssql/mssql_do_insert.go | 62 +++++++++++++------ .../drivers/mssql/mssql_z_unit_basic_test.go | 4 +- .../drivers/mssql/mssql_z_unit_model_test.go | 43 ++++++++++++- contrib/drivers/pgsql/pgsql_do_insert.go | 30 +++------ database/gdb/gdb_core_utility.go | 20 ++++++ 6 files changed, 119 insertions(+), 66 deletions(-) diff --git a/contrib/drivers/dm/dm_do_insert.go b/contrib/drivers/dm/dm_do_insert.go index b5ccb336c..01a00354f 100644 --- a/contrib/drivers/dm/dm_do_insert.go +++ b/contrib/drivers/dm/dm_do_insert.go @@ -66,7 +66,7 @@ func (d *Driver) doMergeInsert( // If OnConflict is not specified, automatically get the primary key of the table conflictKeys := option.OnConflict if len(conflictKeys) == 0 { - primaryKeys, err := d.getPrimaryKeys(ctx, table) + primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table) if err != nil { return nil, gerror.WrapCode( gcode.CodeInternalError, @@ -82,9 +82,11 @@ func (d *Driver) doMergeInsert( } } if !foundPrimaryKey { - return nil, gerror.NewCode( + return nil, gerror.NewCodef( gcode.CodeMissingParameter, - `Please specify conflict columns or ensure the record has a primary key for Save/Replace/InsertIgnore operation`, + `Replace/Save/InsertIgnore operation requires conflict detection: `+ + `either specify OnConflict() columns or ensure table '%s' has a primary key in the data`, + table, ) } conflictKeys = primaryKeys @@ -149,24 +151,6 @@ func (d *Driver) doMergeInsert( return batchResult, nil } -// getPrimaryKeys retrieves the primary key field names of the table as a slice of strings. -// This method extracts primary key information from TableFields. -func (d *Driver) getPrimaryKeys(ctx context.Context, table string) ([]string, error) { - tableFields, err := d.TableFields(ctx, table) - if err != nil { - return nil, err - } - - var primaryKeys []string - for _, field := range tableFields { - if gstr.Equal(field.Key, "PRI") { - primaryKeys = append(primaryKeys, field.Name) - } - } - - return primaryKeys, nil -} - // parseSqlForMerge generates MERGE statement for DM database. // When updateValues is empty, it only inserts (INSERT IGNORE behavior). // When updateValues is provided, it performs upsert (INSERT or UPDATE). diff --git a/contrib/drivers/mssql/mssql_do_insert.go b/contrib/drivers/mssql/mssql_do_insert.go index 4284780b7..48142a478 100644 --- a/contrib/drivers/mssql/mssql_do_insert.go +++ b/contrib/drivers/mssql/mssql_do_insert.go @@ -21,17 +21,45 @@ import ( // DoInsert inserts or updates data for given table. // The list parameter must contain at least one record, which was previously validated. -func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption) (result sql.Result, err error) { +func (d *Driver) DoInsert( + ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, +) (result sql.Result, err error) { switch option.InsertOption { - case gdb.InsertOptionSave: + case + gdb.InsertOptionSave, + gdb.InsertOptionReplace: + // MSSQL does not support REPLACE INTO syntax. + // Convert Replace to Save operation, using MERGE statement. + // Auto-detect primary keys if OnConflict is not specified. + if len(option.OnConflict) == 0 { + primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table) + if err != nil { + return nil, gerror.WrapCode( + gcode.CodeInternalError, + err, + `failed to get primary keys for Replace operation`, + ) + } + foundPrimaryKey := false + for _, primaryKey := range primaryKeys { + if _, ok := list[0][primaryKey]; ok { + foundPrimaryKey = true + break + } + } + if !foundPrimaryKey { + return nil, gerror.NewCodef( + gcode.CodeMissingParameter, + `Save/Replace operation requires conflict detection: `+ + `either specify OnConflict() columns or ensure table '%s' has a primary key in the data`, + table, + ) + } + option.OnConflict = primaryKeys + } + // Convert to Save operation return d.doSave(ctx, link, table, list, option) - case gdb.InsertOptionReplace: - return nil, gerror.NewCode( - gcode.CodeNotSupported, - `Replace operation is not supported by mssql driver`, - ) - default: return d.Core.DoInsert(ctx, link, table, list, option) } @@ -41,17 +69,10 @@ func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list func (d *Driver) doSave(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, ) (result sql.Result, err error) { - if len(option.OnConflict) == 0 { - return nil, gerror.NewCode( - gcode.CodeMissingParameter, `Please specify conflict columns`, - ) - } - var ( - one = list[0] - oneLen = len(one) - charL, charR = d.GetChars() - + one = list[0] + oneLen = len(one) + charL, charR = d.GetChars() conflictKeys = option.OnConflict conflictKeySet = gset.New(false) @@ -122,7 +143,10 @@ func parseSqlForUpsert(table string, insertValueStr = strings.Join(insertValues, ",") updateValueStr = strings.Join(updateValues, ",") duplicateKeyStr string - pattern = gstr.Trim(`MERGE INTO %s T1 USING (VALUES(%s)) T2 (%s) ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s;`) + pattern = gstr.Trim( + `MERGE INTO %s T1 USING (VALUES(%s)) T2 (%s) ON (%s) WHEN NOT MATCHED ` + + `THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s;`, + ) ) for index, keys := range duplicateKey { diff --git a/contrib/drivers/mssql/mssql_z_unit_basic_test.go b/contrib/drivers/mssql/mssql_z_unit_basic_test.go index 0999a1eee..9933f751c 100644 --- a/contrib/drivers/mssql/mssql_z_unit_basic_test.go +++ b/contrib/drivers/mssql/mssql_z_unit_basic_test.go @@ -138,15 +138,17 @@ func TestDoInsert(t *testing.T) { i := 10 data := g.Map{ - "id": i, + // "id": i, "passport": fmt.Sprintf(`t%d`, i), "password": fmt.Sprintf(`p%d`, i), "nickname": fmt.Sprintf(`T%d`, i), "create_time": gtime.Now(), } + // Save without OnConflict should fail (missing conflict columns) _, err := db.Save(context.Background(), "t_user", data, 10) gtest.AssertNE(err, nil) + // Replace should now work (it will auto-detect primary key) _, err = db.Replace(context.Background(), "t_user", data, 10) gtest.AssertNE(err, nil) }) diff --git a/contrib/drivers/mssql/mssql_z_unit_model_test.go b/contrib/drivers/mssql/mssql_z_unit_model_test.go index b3a1daa81..ce225dcad 100644 --- a/contrib/drivers/mssql/mssql_z_unit_model_test.go +++ b/contrib/drivers/mssql/mssql_z_unit_model_test.go @@ -2658,14 +2658,53 @@ func Test_Model_Replace(t *testing.T) { defer dropTable(table) gtest.C(t, func(t *gtest.T) { - _, err := db.Model(table).Data(g.Map{ + // Insert initial record + result, err := db.Model(table).Data(g.Map{ + "id": 1, + "passport": "t1", + "password": "pass1", + "nickname": "T1", + "create_time": "2018-10-24 10:00:00", + }).Insert() + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + // Replace with new data (should update existing record using MERGE) + result, err = db.Model(table).Data(g.Map{ "id": 1, "passport": "t11", "password": "25d55ad283aa400af464c76d713c07ad", "nickname": "T11", "create_time": "2018-10-24 10:00:00", }).Replace() - t.Assert(err, "Replace operation is not supported by mssql driver") + t.AssertNil(err) + n, _ = result.RowsAffected() + t.Assert(n, 1) + + // Verify the data was replaced + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["PASSPORT"].String(), "t11") + t.Assert(one["NICKNAME"].String(), "T11") + + // Replace with non-existing record (should insert new record) + result, err = db.Model(table).Data(g.Map{ + "id": 2, + "passport": "t222", + "password": "pass2", + "nickname": "T222", + "create_time": "2018-10-24 11:00:00", + }).Replace() + t.AssertNil(err) + n, _ = result.RowsAffected() + t.Assert(n, 1) // MERGE reports: 1 for insert + + // Verify the new record was inserted + one, err = db.Model(table).WherePri(2).One() + t.AssertNil(err) + t.Assert(one["PASSPORT"].String(), "t222") + t.Assert(one["NICKNAME"].String(), "T222") }) } diff --git a/contrib/drivers/pgsql/pgsql_do_insert.go b/contrib/drivers/pgsql/pgsql_do_insert.go index 4bd7194c2..55e6f761c 100644 --- a/contrib/drivers/pgsql/pgsql_do_insert.go +++ b/contrib/drivers/pgsql/pgsql_do_insert.go @@ -24,12 +24,12 @@ func (d *Driver) DoInsert( ) (result sql.Result, err error) { switch option.InsertOption { case - gdb.InsertOptionReplace, - gdb.InsertOptionSave: + gdb.InsertOptionSave, + gdb.InsertOptionReplace: // PostgreSQL does not support REPLACE INTO syntax, use Save (ON CONFLICT ... DO UPDATE) instead. // Automatically detect primary keys if OnConflict is not specified. if len(option.OnConflict) == 0 { - primaryKeys, err := d.getPrimaryKeys(ctx, table) + primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table) if err != nil { return nil, gerror.WrapCode( gcode.CodeInternalError, @@ -45,9 +45,11 @@ func (d *Driver) DoInsert( } } if !foundPrimaryKey { - return nil, gerror.NewCode( + return nil, gerror.NewCodef( gcode.CodeMissingParameter, - `Please specify conflict columns or ensure the record has a primary key for Save/Replace operation`, + `Replace/Save operation requires conflict detection: `+ + `either specify OnConflict() columns or ensure table '%s' has a primary key in the data`, + table, ) } option.OnConflict = primaryKeys @@ -71,21 +73,3 @@ func (d *Driver) DoInsert( } return d.Core.DoInsert(ctx, link, table, list, option) } - -// getPrimaryKeys retrieves the primary key field list of the table. -// This method extracts primary key information from TableFields. -func (d *Driver) getPrimaryKeys(ctx context.Context, table string) ([]string, error) { - tableFields, err := d.TableFields(ctx, table) - if err != nil { - return nil, err - } - - var primaryKeys []string - for _, field := range tableFields { - if gstr.Equal(field.Key, "pri") { - primaryKeys = append(primaryKeys, field.Name) - } - } - - return primaryKeys, nil -} diff --git a/database/gdb/gdb_core_utility.go b/database/gdb/gdb_core_utility.go index 4872f13a1..b97d7431e 100644 --- a/database/gdb/gdb_core_utility.go +++ b/database/gdb/gdb_core_utility.go @@ -10,6 +10,7 @@ package gdb import ( "context" "fmt" + "strings" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" @@ -251,3 +252,22 @@ func (c *Core) guessPrimaryTableName(tableStr string) string { } return guessedTableName } + +// GetPrimaryKeys retrieves and returns the primary key field names of the specified table. +// This method extracts primary key information from TableFields. +// The parameter `schema` is optional, if not specified it uses the default schema. +func (c *Core) GetPrimaryKeys(ctx context.Context, table string, schema ...string) ([]string, error) { + tableFields, err := c.db.TableFields(ctx, table, schema...) + if err != nil { + return nil, err + } + + var primaryKeys []string + for _, field := range tableFields { + if strings.EqualFold(field.Key, "pri") { + primaryKeys = append(primaryKeys, field.Name) + } + } + + return primaryKeys, nil +}