This commit is contained in:
John Guo
2025-12-08 16:28:34 +08:00
parent 67a8a28a18
commit 5f664f331a
6 changed files with 119 additions and 66 deletions

View File

@ -66,7 +66,7 @@ func (d *Driver) doMergeInsert(
// If OnConflict is not specified, automatically get the primary key of the table
conflictKeys := option.OnConflict
if len(conflictKeys) == 0 {
primaryKeys, err := d.getPrimaryKeys(ctx, table)
primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table)
if err != nil {
return nil, gerror.WrapCode(
gcode.CodeInternalError,
@ -82,9 +82,11 @@ func (d *Driver) doMergeInsert(
}
}
if !foundPrimaryKey {
return nil, gerror.NewCode(
return nil, gerror.NewCodef(
gcode.CodeMissingParameter,
`Please specify conflict columns or ensure the record has a primary key for Save/Replace/InsertIgnore operation`,
`Replace/Save/InsertIgnore operation requires conflict detection: `+
`either specify OnConflict() columns or ensure table '%s' has a primary key in the data`,
table,
)
}
conflictKeys = primaryKeys
@ -149,24 +151,6 @@ func (d *Driver) doMergeInsert(
return batchResult, nil
}
// getPrimaryKeys retrieves the primary key field names of the table as a slice of strings.
// This method extracts primary key information from TableFields.
func (d *Driver) getPrimaryKeys(ctx context.Context, table string) ([]string, error) {
tableFields, err := d.TableFields(ctx, table)
if err != nil {
return nil, err
}
var primaryKeys []string
for _, field := range tableFields {
if gstr.Equal(field.Key, "PRI") {
primaryKeys = append(primaryKeys, field.Name)
}
}
return primaryKeys, nil
}
// parseSqlForMerge generates MERGE statement for DM database.
// When updateValues is empty, it only inserts (INSERT IGNORE behavior).
// When updateValues is provided, it performs upsert (INSERT or UPDATE).

View File

@ -21,17 +21,45 @@ import (
// DoInsert inserts or updates data for given table.
// The list parameter must contain at least one record, which was previously validated.
func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption) (result sql.Result, err error) {
func (d *Driver) DoInsert(
ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption,
) (result sql.Result, err error) {
switch option.InsertOption {
case gdb.InsertOptionSave:
case
gdb.InsertOptionSave,
gdb.InsertOptionReplace:
// MSSQL does not support REPLACE INTO syntax.
// Convert Replace to Save operation, using MERGE statement.
// Auto-detect primary keys if OnConflict is not specified.
if len(option.OnConflict) == 0 {
primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table)
if err != nil {
return nil, gerror.WrapCode(
gcode.CodeInternalError,
err,
`failed to get primary keys for Replace operation`,
)
}
foundPrimaryKey := false
for _, primaryKey := range primaryKeys {
if _, ok := list[0][primaryKey]; ok {
foundPrimaryKey = true
break
}
}
if !foundPrimaryKey {
return nil, gerror.NewCodef(
gcode.CodeMissingParameter,
`Save/Replace operation requires conflict detection: `+
`either specify OnConflict() columns or ensure table '%s' has a primary key in the data`,
table,
)
}
option.OnConflict = primaryKeys
}
// Convert to Save operation
return d.doSave(ctx, link, table, list, option)
case gdb.InsertOptionReplace:
return nil, gerror.NewCode(
gcode.CodeNotSupported,
`Replace operation is not supported by mssql driver`,
)
default:
return d.Core.DoInsert(ctx, link, table, list, option)
}
@ -41,17 +69,10 @@ func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list
func (d *Driver) doSave(ctx context.Context,
link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption,
) (result sql.Result, err error) {
if len(option.OnConflict) == 0 {
return nil, gerror.NewCode(
gcode.CodeMissingParameter, `Please specify conflict columns`,
)
}
var (
one = list[0]
oneLen = len(one)
charL, charR = d.GetChars()
one = list[0]
oneLen = len(one)
charL, charR = d.GetChars()
conflictKeys = option.OnConflict
conflictKeySet = gset.New(false)
@ -122,7 +143,10 @@ func parseSqlForUpsert(table string,
insertValueStr = strings.Join(insertValues, ",")
updateValueStr = strings.Join(updateValues, ",")
duplicateKeyStr string
pattern = gstr.Trim(`MERGE INTO %s T1 USING (VALUES(%s)) T2 (%s) ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s;`)
pattern = gstr.Trim(
`MERGE INTO %s T1 USING (VALUES(%s)) T2 (%s) ON (%s) WHEN NOT MATCHED ` +
`THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s;`,
)
)
for index, keys := range duplicateKey {

View File

@ -138,15 +138,17 @@ func TestDoInsert(t *testing.T) {
i := 10
data := g.Map{
"id": i,
// "id": i,
"passport": fmt.Sprintf(`t%d`, i),
"password": fmt.Sprintf(`p%d`, i),
"nickname": fmt.Sprintf(`T%d`, i),
"create_time": gtime.Now(),
}
// Save without OnConflict should fail (missing conflict columns)
_, err := db.Save(context.Background(), "t_user", data, 10)
gtest.AssertNE(err, nil)
// Replace should now work (it will auto-detect primary key)
_, err = db.Replace(context.Background(), "t_user", data, 10)
gtest.AssertNE(err, nil)
})

View File

@ -2658,14 +2658,53 @@ func Test_Model_Replace(t *testing.T) {
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
_, err := db.Model(table).Data(g.Map{
// Insert initial record
result, err := db.Model(table).Data(g.Map{
"id": 1,
"passport": "t1",
"password": "pass1",
"nickname": "T1",
"create_time": "2018-10-24 10:00:00",
}).Insert()
t.AssertNil(err)
n, _ := result.RowsAffected()
t.Assert(n, 1)
// Replace with new data (should update existing record using MERGE)
result, err = db.Model(table).Data(g.Map{
"id": 1,
"passport": "t11",
"password": "25d55ad283aa400af464c76d713c07ad",
"nickname": "T11",
"create_time": "2018-10-24 10:00:00",
}).Replace()
t.Assert(err, "Replace operation is not supported by mssql driver")
t.AssertNil(err)
n, _ = result.RowsAffected()
t.Assert(n, 1)
// Verify the data was replaced
one, err := db.Model(table).WherePri(1).One()
t.AssertNil(err)
t.Assert(one["PASSPORT"].String(), "t11")
t.Assert(one["NICKNAME"].String(), "T11")
// Replace with non-existing record (should insert new record)
result, err = db.Model(table).Data(g.Map{
"id": 2,
"passport": "t222",
"password": "pass2",
"nickname": "T222",
"create_time": "2018-10-24 11:00:00",
}).Replace()
t.AssertNil(err)
n, _ = result.RowsAffected()
t.Assert(n, 1) // MERGE reports: 1 for insert
// Verify the new record was inserted
one, err = db.Model(table).WherePri(2).One()
t.AssertNil(err)
t.Assert(one["PASSPORT"].String(), "t222")
t.Assert(one["NICKNAME"].String(), "T222")
})
}

View File

@ -24,12 +24,12 @@ func (d *Driver) DoInsert(
) (result sql.Result, err error) {
switch option.InsertOption {
case
gdb.InsertOptionReplace,
gdb.InsertOptionSave:
gdb.InsertOptionSave,
gdb.InsertOptionReplace:
// PostgreSQL does not support REPLACE INTO syntax, use Save (ON CONFLICT ... DO UPDATE) instead.
// Automatically detect primary keys if OnConflict is not specified.
if len(option.OnConflict) == 0 {
primaryKeys, err := d.getPrimaryKeys(ctx, table)
primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table)
if err != nil {
return nil, gerror.WrapCode(
gcode.CodeInternalError,
@ -45,9 +45,11 @@ func (d *Driver) DoInsert(
}
}
if !foundPrimaryKey {
return nil, gerror.NewCode(
return nil, gerror.NewCodef(
gcode.CodeMissingParameter,
`Please specify conflict columns or ensure the record has a primary key for Save/Replace operation`,
`Replace/Save operation requires conflict detection: `+
`either specify OnConflict() columns or ensure table '%s' has a primary key in the data`,
table,
)
}
option.OnConflict = primaryKeys
@ -71,21 +73,3 @@ func (d *Driver) DoInsert(
}
return d.Core.DoInsert(ctx, link, table, list, option)
}
// getPrimaryKeys retrieves the primary key field list of the table.
// This method extracts primary key information from TableFields.
func (d *Driver) getPrimaryKeys(ctx context.Context, table string) ([]string, error) {
tableFields, err := d.TableFields(ctx, table)
if err != nil {
return nil, err
}
var primaryKeys []string
for _, field := range tableFields {
if gstr.Equal(field.Key, "pri") {
primaryKeys = append(primaryKeys, field.Name)
}
}
return primaryKeys, nil
}

View File

@ -10,6 +10,7 @@ package gdb
import (
"context"
"fmt"
"strings"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
@ -251,3 +252,22 @@ func (c *Core) guessPrimaryTableName(tableStr string) string {
}
return guessedTableName
}
// GetPrimaryKeys retrieves and returns the primary key field names of the specified table.
// This method extracts primary key information from TableFields.
// The parameter `schema` is optional, if not specified it uses the default schema.
func (c *Core) GetPrimaryKeys(ctx context.Context, table string, schema ...string) ([]string, error) {
tableFields, err := c.db.TableFields(ctx, table, schema...)
if err != nil {
return nil, err
}
var primaryKeys []string
for _, field := range tableFields {
if strings.EqualFold(field.Key, "pri") {
primaryKeys = append(primaryKeys, field.Name)
}
}
return primaryKeys, nil
}