mirror of
https://gitee.com/johng/gf
synced 2026-06-06 02:25:47 +08:00
up
This commit is contained in:
@ -66,7 +66,7 @@ func (d *Driver) doMergeInsert(
|
||||
// If OnConflict is not specified, automatically get the primary key of the table
|
||||
conflictKeys := option.OnConflict
|
||||
if len(conflictKeys) == 0 {
|
||||
primaryKeys, err := d.getPrimaryKeys(ctx, table)
|
||||
primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table)
|
||||
if err != nil {
|
||||
return nil, gerror.WrapCode(
|
||||
gcode.CodeInternalError,
|
||||
@ -82,9 +82,11 @@ func (d *Driver) doMergeInsert(
|
||||
}
|
||||
}
|
||||
if !foundPrimaryKey {
|
||||
return nil, gerror.NewCode(
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeMissingParameter,
|
||||
`Please specify conflict columns or ensure the record has a primary key for Save/Replace/InsertIgnore operation`,
|
||||
`Replace/Save/InsertIgnore operation requires conflict detection: `+
|
||||
`either specify OnConflict() columns or ensure table '%s' has a primary key in the data`,
|
||||
table,
|
||||
)
|
||||
}
|
||||
conflictKeys = primaryKeys
|
||||
@ -149,24 +151,6 @@ func (d *Driver) doMergeInsert(
|
||||
return batchResult, nil
|
||||
}
|
||||
|
||||
// getPrimaryKeys retrieves the primary key field names of the table as a slice of strings.
|
||||
// This method extracts primary key information from TableFields.
|
||||
func (d *Driver) getPrimaryKeys(ctx context.Context, table string) ([]string, error) {
|
||||
tableFields, err := d.TableFields(ctx, table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var primaryKeys []string
|
||||
for _, field := range tableFields {
|
||||
if gstr.Equal(field.Key, "PRI") {
|
||||
primaryKeys = append(primaryKeys, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return primaryKeys, nil
|
||||
}
|
||||
|
||||
// parseSqlForMerge generates MERGE statement for DM database.
|
||||
// When updateValues is empty, it only inserts (INSERT IGNORE behavior).
|
||||
// When updateValues is provided, it performs upsert (INSERT or UPDATE).
|
||||
|
||||
@ -21,17 +21,45 @@ import (
|
||||
|
||||
// DoInsert inserts or updates data for given table.
|
||||
// The list parameter must contain at least one record, which was previously validated.
|
||||
func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption) (result sql.Result, err error) {
|
||||
func (d *Driver) DoInsert(
|
||||
ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption,
|
||||
) (result sql.Result, err error) {
|
||||
switch option.InsertOption {
|
||||
case gdb.InsertOptionSave:
|
||||
case
|
||||
gdb.InsertOptionSave,
|
||||
gdb.InsertOptionReplace:
|
||||
// MSSQL does not support REPLACE INTO syntax.
|
||||
// Convert Replace to Save operation, using MERGE statement.
|
||||
// Auto-detect primary keys if OnConflict is not specified.
|
||||
if len(option.OnConflict) == 0 {
|
||||
primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table)
|
||||
if err != nil {
|
||||
return nil, gerror.WrapCode(
|
||||
gcode.CodeInternalError,
|
||||
err,
|
||||
`failed to get primary keys for Replace operation`,
|
||||
)
|
||||
}
|
||||
foundPrimaryKey := false
|
||||
for _, primaryKey := range primaryKeys {
|
||||
if _, ok := list[0][primaryKey]; ok {
|
||||
foundPrimaryKey = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundPrimaryKey {
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeMissingParameter,
|
||||
`Save/Replace operation requires conflict detection: `+
|
||||
`either specify OnConflict() columns or ensure table '%s' has a primary key in the data`,
|
||||
table,
|
||||
)
|
||||
}
|
||||
option.OnConflict = primaryKeys
|
||||
}
|
||||
// Convert to Save operation
|
||||
return d.doSave(ctx, link, table, list, option)
|
||||
|
||||
case gdb.InsertOptionReplace:
|
||||
return nil, gerror.NewCode(
|
||||
gcode.CodeNotSupported,
|
||||
`Replace operation is not supported by mssql driver`,
|
||||
)
|
||||
|
||||
default:
|
||||
return d.Core.DoInsert(ctx, link, table, list, option)
|
||||
}
|
||||
@ -41,17 +69,10 @@ func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list
|
||||
func (d *Driver) doSave(ctx context.Context,
|
||||
link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption,
|
||||
) (result sql.Result, err error) {
|
||||
if len(option.OnConflict) == 0 {
|
||||
return nil, gerror.NewCode(
|
||||
gcode.CodeMissingParameter, `Please specify conflict columns`,
|
||||
)
|
||||
}
|
||||
|
||||
var (
|
||||
one = list[0]
|
||||
oneLen = len(one)
|
||||
charL, charR = d.GetChars()
|
||||
|
||||
one = list[0]
|
||||
oneLen = len(one)
|
||||
charL, charR = d.GetChars()
|
||||
conflictKeys = option.OnConflict
|
||||
conflictKeySet = gset.New(false)
|
||||
|
||||
@ -122,7 +143,10 @@ func parseSqlForUpsert(table string,
|
||||
insertValueStr = strings.Join(insertValues, ",")
|
||||
updateValueStr = strings.Join(updateValues, ",")
|
||||
duplicateKeyStr string
|
||||
pattern = gstr.Trim(`MERGE INTO %s T1 USING (VALUES(%s)) T2 (%s) ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s;`)
|
||||
pattern = gstr.Trim(
|
||||
`MERGE INTO %s T1 USING (VALUES(%s)) T2 (%s) ON (%s) WHEN NOT MATCHED ` +
|
||||
`THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s;`,
|
||||
)
|
||||
)
|
||||
|
||||
for index, keys := range duplicateKey {
|
||||
|
||||
@ -138,15 +138,17 @@ func TestDoInsert(t *testing.T) {
|
||||
|
||||
i := 10
|
||||
data := g.Map{
|
||||
"id": i,
|
||||
// "id": i,
|
||||
"passport": fmt.Sprintf(`t%d`, i),
|
||||
"password": fmt.Sprintf(`p%d`, i),
|
||||
"nickname": fmt.Sprintf(`T%d`, i),
|
||||
"create_time": gtime.Now(),
|
||||
}
|
||||
// Save without OnConflict should fail (missing conflict columns)
|
||||
_, err := db.Save(context.Background(), "t_user", data, 10)
|
||||
gtest.AssertNE(err, nil)
|
||||
|
||||
// Replace should now work (it will auto-detect primary key)
|
||||
_, err = db.Replace(context.Background(), "t_user", data, 10)
|
||||
gtest.AssertNE(err, nil)
|
||||
})
|
||||
|
||||
@ -2658,14 +2658,53 @@ func Test_Model_Replace(t *testing.T) {
|
||||
defer dropTable(table)
|
||||
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
_, err := db.Model(table).Data(g.Map{
|
||||
// Insert initial record
|
||||
result, err := db.Model(table).Data(g.Map{
|
||||
"id": 1,
|
||||
"passport": "t1",
|
||||
"password": "pass1",
|
||||
"nickname": "T1",
|
||||
"create_time": "2018-10-24 10:00:00",
|
||||
}).Insert()
|
||||
t.AssertNil(err)
|
||||
n, _ := result.RowsAffected()
|
||||
t.Assert(n, 1)
|
||||
|
||||
// Replace with new data (should update existing record using MERGE)
|
||||
result, err = db.Model(table).Data(g.Map{
|
||||
"id": 1,
|
||||
"passport": "t11",
|
||||
"password": "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname": "T11",
|
||||
"create_time": "2018-10-24 10:00:00",
|
||||
}).Replace()
|
||||
t.Assert(err, "Replace operation is not supported by mssql driver")
|
||||
t.AssertNil(err)
|
||||
n, _ = result.RowsAffected()
|
||||
t.Assert(n, 1)
|
||||
|
||||
// Verify the data was replaced
|
||||
one, err := db.Model(table).WherePri(1).One()
|
||||
t.AssertNil(err)
|
||||
t.Assert(one["PASSPORT"].String(), "t11")
|
||||
t.Assert(one["NICKNAME"].String(), "T11")
|
||||
|
||||
// Replace with non-existing record (should insert new record)
|
||||
result, err = db.Model(table).Data(g.Map{
|
||||
"id": 2,
|
||||
"passport": "t222",
|
||||
"password": "pass2",
|
||||
"nickname": "T222",
|
||||
"create_time": "2018-10-24 11:00:00",
|
||||
}).Replace()
|
||||
t.AssertNil(err)
|
||||
n, _ = result.RowsAffected()
|
||||
t.Assert(n, 1) // MERGE reports: 1 for insert
|
||||
|
||||
// Verify the new record was inserted
|
||||
one, err = db.Model(table).WherePri(2).One()
|
||||
t.AssertNil(err)
|
||||
t.Assert(one["PASSPORT"].String(), "t222")
|
||||
t.Assert(one["NICKNAME"].String(), "T222")
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -24,12 +24,12 @@ func (d *Driver) DoInsert(
|
||||
) (result sql.Result, err error) {
|
||||
switch option.InsertOption {
|
||||
case
|
||||
gdb.InsertOptionReplace,
|
||||
gdb.InsertOptionSave:
|
||||
gdb.InsertOptionSave,
|
||||
gdb.InsertOptionReplace:
|
||||
// PostgreSQL does not support REPLACE INTO syntax, use Save (ON CONFLICT ... DO UPDATE) instead.
|
||||
// Automatically detect primary keys if OnConflict is not specified.
|
||||
if len(option.OnConflict) == 0 {
|
||||
primaryKeys, err := d.getPrimaryKeys(ctx, table)
|
||||
primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table)
|
||||
if err != nil {
|
||||
return nil, gerror.WrapCode(
|
||||
gcode.CodeInternalError,
|
||||
@ -45,9 +45,11 @@ func (d *Driver) DoInsert(
|
||||
}
|
||||
}
|
||||
if !foundPrimaryKey {
|
||||
return nil, gerror.NewCode(
|
||||
return nil, gerror.NewCodef(
|
||||
gcode.CodeMissingParameter,
|
||||
`Please specify conflict columns or ensure the record has a primary key for Save/Replace operation`,
|
||||
`Replace/Save operation requires conflict detection: `+
|
||||
`either specify OnConflict() columns or ensure table '%s' has a primary key in the data`,
|
||||
table,
|
||||
)
|
||||
}
|
||||
option.OnConflict = primaryKeys
|
||||
@ -71,21 +73,3 @@ func (d *Driver) DoInsert(
|
||||
}
|
||||
return d.Core.DoInsert(ctx, link, table, list, option)
|
||||
}
|
||||
|
||||
// getPrimaryKeys retrieves the primary key field list of the table.
|
||||
// This method extracts primary key information from TableFields.
|
||||
func (d *Driver) getPrimaryKeys(ctx context.Context, table string) ([]string, error) {
|
||||
tableFields, err := d.TableFields(ctx, table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var primaryKeys []string
|
||||
for _, field := range tableFields {
|
||||
if gstr.Equal(field.Key, "pri") {
|
||||
primaryKeys = append(primaryKeys, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return primaryKeys, nil
|
||||
}
|
||||
|
||||
@ -10,6 +10,7 @@ package gdb
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
@ -251,3 +252,22 @@ func (c *Core) guessPrimaryTableName(tableStr string) string {
|
||||
}
|
||||
return guessedTableName
|
||||
}
|
||||
|
||||
// GetPrimaryKeys retrieves and returns the primary key field names of the specified table.
|
||||
// This method extracts primary key information from TableFields.
|
||||
// The parameter `schema` is optional, if not specified it uses the default schema.
|
||||
func (c *Core) GetPrimaryKeys(ctx context.Context, table string, schema ...string) ([]string, error) {
|
||||
tableFields, err := c.db.TableFields(ctx, table, schema...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var primaryKeys []string
|
||||
for _, field := range tableFields {
|
||||
if strings.EqualFold(field.Key, "pri") {
|
||||
primaryKeys = append(primaryKeys, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return primaryKeys, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user