mirror of
https://gitee.com/johng/gf
synced 2026-07-04 21:03:13 +08:00
This pull request introduces significant improvements to the handling of the `Replace` and `Save` operations for multiple database drivers, especially for MSSQL and PostgreSQL. The changes ensure that these operations now auto-detect primary keys when conflict columns are not explicitly provided, improving usability and aligning behavior across drivers. Additionally, the pull request updates related tests to reflect these enhancements and includes some minor documentation and code cleanup. **Key changes:** ### Enhanced Replace/Save Logic for Database Drivers * **MSSQL Driver:** - `Replace` and `Save` operations now auto-detect primary keys if `OnConflict` is not specified, using the `MERGE` statement for upsert functionality. If no primary key is found in the data, a detailed error is returned. [[1]](diffhunk://#diff-87815aa559a927e2de09bd05148f9841dfc06a1b5f3ecc5e3d5fcb80323a87f8L23-R61) [[2]](diffhunk://#diff-87815aa559a927e2de09bd05148f9841dfc06a1b5f3ecc5e3d5fcb80323a87f8L43-L59) - Updated tests to verify that `Replace` correctly updates or inserts records, and that missing conflict columns are properly handled. [[1]](diffhunk://#diff-bdbde9d7d6ee14c795343767b414740c4396f4dd3e97788b1f9d4e615405a42dL141-R151) [[2]](diffhunk://#diff-26338e93e473300b1313936eb0f6826546473793442f24715fa294b595f7a805L2661-R2707) * **PostgreSQL Driver:** - Similar to MSSQL, `Replace` and `Save` now auto-detect primary keys for conflict resolution if `OnConflict` is not set, and treat `Replace` as a `Save` operation. - Adjusted tests to ensure `Save` and `Replace` work as expected, including verifying data replacement and insertion. 
[[1]](diffhunk://#diff-c22703c37ebb6836c332f7cd2ada570577ba4564fe39886db02f7c2d0e7a2048L93-R93) [[2]](diffhunk://#diff-c22703c37ebb6836c332f7cd2ada570577ba4564fe39886db02f7c2d0e7a2048R102) [[3]](diffhunk://#diff-c22703c37ebb6836c332f7cd2ada570577ba4564fe39886db02f7c2d0e7a2048L110-R130) * **DM Driver:** - Improved conflict detection: now checks that at least one primary key exists in the provided data when `OnConflict` is not specified, and provides clearer error messages. - Refactored to use the core method for primary key detection and removed redundant code. ### Minor Improvements and Documentation * Added clarifying comments to `DoInsert` methods for ClickHouse, DM, MSSQL, Oracle, and PostgreSQL drivers, specifying that the input list must have at least one validated record. [[1]](diffhunk://#diff-f2e003895041ed3c52b91bb8c270696adc3528d77c39d2f7137af3396267444cR19) [[2]](diffhunk://#diff-f51b30e3f0b0f1284b905385a89992efd0de2fe9ff8c5a4062344dfab17d428eR23) [[3]](diffhunk://#diff-87815aa559a927e2de09bd05148f9841dfc06a1b5f3ecc5e3d5fcb80323a87f8L23-R61) [[4]](diffhunk://#diff-f61dac3fcfd5df4a3936cd8743499c8c0fc45f4f5d0f5398ed84a0cb1603202cR24) [[5]](diffhunk://#diff-c1dfed79aaa3a432057d2bd74d270e4b4094ebcf72984f1161d4972bea009410R16-R72) * Minor code and comment cleanups, including improved formatting and error handling. [[1]](diffhunk://#diff-f61dac3fcfd5df4a3936cd8743499c8c0fc45f4f5d0f5398ed84a0cb1603202cR37) [[2]](diffhunk://#diff-f61dac3fcfd5df4a3936cd8743499c8c0fc45f4f5d0f5398ed84a0cb1603202cL96-R98) [[3]](diffhunk://#diff-f61dac3fcfd5df4a3936cd8743499c8c0fc45f4f5d0f5398ed84a0cb1603202cL106-L116) [[4]](diffhunk://#diff-a17b44c76aaac53d1f164a2bb9440a5531659f4355e7ccfabdadff8dc8633c09L170-R171) [[5]](diffhunk://#diff-56189fa9ae1df51716b50d34d7fe56bfe67a330e8ac2c6b0de7b958db6817ed5R83-R98) ### Workflow and Documentation Updates * Updated example Docker commands in the CI workflow for consistency and clarity. 
[[1]](diffhunk://#diff-a1a3cb9bdeb5541d148091d973cf266aa3b317e6415a86630e816cbe27cf8b9cL57-R57) [[2]](diffhunk://#diff-a1a3cb9bdeb5541d148091d973cf266aa3b317e6415a86630e816cbe27cf8b9cL78-R78) [[3]](diffhunk://#diff-a1a3cb9bdeb5541d148091d973cf266aa3b317e6415a86630e816cbe27cf8b9cL92-R92) [[4]](diffhunk://#diff-a1a3cb9bdeb5541d148091d973cf266aa3b317e6415a86630e816cbe27cf8b9cL106-R106) [[5]](diffhunk://#diff-a1a3cb9bdeb5541d148091d973cf266aa3b317e6415a86630e816cbe27cf8b9cL153-R153) [[6]](diffhunk://#diff-a1a3cb9bdeb5541d148091d973cf266aa3b317e6415a86630e816cbe27cf8b9cL164-R164) * Removed outdated note about `Replace` support from the SQLite driver documentation. These changes improve the consistency, reliability, and developer experience when performing upsert operations across different database backends. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Lance Add <1196661499@qq.com>
193 lines
6.1 KiB
Go
193 lines
6.1 KiB
Go
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
|
//
|
|
// This Source Code Form is subject to the terms of the MIT License.
|
|
// If a copy of the MIT was not distributed with this file,
|
|
// You can obtain one at https://github.com/gogf/gf.
|
|
|
|
package mssql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/gogf/gf/v2/database/gdb"
|
|
"github.com/gogf/gf/v2/errors/gcode"
|
|
"github.com/gogf/gf/v2/errors/gerror"
|
|
)
|
|
|
|
// Prefixes used to recognise INSERT statements.
const (
	insertPrefixDefault = "INSERT INTO"
	insertPrefixIgnore  = "INSERT IGNORE INTO"
)

// Table field attribute values compared against field metadata.
const (
	fieldExtraIdentity = "IDENTITY"
	fieldKeyPrimary    = "PRI"
)

// SQL keywords and syntax markers.
const (
	outputKeyword = "OUTPUT"

	// insertValuesMarker is searched for in an INSERT statement to locate
	// the position where the OUTPUT clause (used for retrieving the last
	// inserted ID) is embedded.
	insertValuesMarker = ") VALUES"
)

// insertedObjectName is the object name that prefixes column references in
// the generated OUTPUT clause.
const insertedObjectName = "INSERTED"

// Expressions and aliases of the columns produced by the OUTPUT clause.
const (
	affectCountExpression  = " 1 as AffectCount"
	lastInsertIdFieldAlias = "ID"
)
|
|
|
|
// DoExec commits the sql string and its arguments to underlying driver
|
|
// through given link object and returns the execution result.
|
|
func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sqlStr string, args ...interface{}) (result sql.Result, err error) {
|
|
// Transaction checks.
|
|
if link == nil {
|
|
if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil {
|
|
// Firstly, check and retrieve transaction link from context.
|
|
link = &txLinkMssql{tx.GetSqlTX()}
|
|
} else if link, err = d.MasterLink(); err != nil {
|
|
// Or else it creates one from master node.
|
|
return nil, err
|
|
}
|
|
} else if !link.IsTransaction() {
|
|
// If current link is not transaction link, it checks and retrieves transaction from context.
|
|
if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil {
|
|
link = &txLinkMssql{tx.GetSqlTX()}
|
|
}
|
|
}
|
|
|
|
// SQL filtering.
|
|
sqlStr, args = d.FormatSqlBeforeExecuting(sqlStr, args)
|
|
sqlStr, args, err = d.DoFilter(ctx, link, sqlStr, args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !strings.HasPrefix(sqlStr, insertPrefixDefault) && !strings.HasPrefix(sqlStr, insertPrefixIgnore) {
|
|
return d.Core.DoExec(ctx, link, sqlStr, args)
|
|
}
|
|
// Find the first position of VALUES marker in the INSERT statement.
|
|
pos := strings.Index(sqlStr, insertValuesMarker)
|
|
|
|
table := d.GetTableNameFromSql(sqlStr)
|
|
outPutSql := d.GetInsertOutputSql(ctx, table)
|
|
// rebuild sql add output
|
|
var (
|
|
sqlValueBefore = sqlStr[:pos+1]
|
|
sqlValueAfter = sqlStr[pos+1:]
|
|
)
|
|
|
|
sqlStr = fmt.Sprintf("%s%s%s", sqlValueBefore, outPutSql, sqlValueAfter)
|
|
|
|
// fmt.Println("sql str:", sqlStr)
|
|
// Link execution.
|
|
var out gdb.DoCommitOutput
|
|
out, err = d.DoCommit(ctx, gdb.DoCommitInput{
|
|
Link: link,
|
|
Sql: sqlStr,
|
|
Args: args,
|
|
Stmt: nil,
|
|
Type: gdb.SqlTypeQueryContext,
|
|
IsTransaction: link.IsTransaction(),
|
|
})
|
|
if err != nil {
|
|
return &Result{lastInsertId: 0, rowsAffected: 0, err: err}, err
|
|
}
|
|
stdSqlResult := out.Records
|
|
if len(stdSqlResult) == 0 {
|
|
err = gerror.WrapCode(
|
|
gcode.CodeDbOperationError,
|
|
gerror.New("affected count is zero"),
|
|
`sql.Result.RowsAffected failed`,
|
|
)
|
|
return &Result{lastInsertId: 0, rowsAffected: 0, err: err}, err
|
|
}
|
|
// For batch insert, OUTPUT clause returns one row per inserted row.
|
|
// So the rowsAffected should be the count of returned records.
|
|
rowsAffected := int64(len(stdSqlResult))
|
|
// get last_insert_id from the first returned row
|
|
lastInsertId := stdSqlResult[0].GMap().GetVar(lastInsertIdFieldAlias).Int64()
|
|
|
|
return &Result{lastInsertId: lastInsertId, rowsAffected: rowsAffected}, err
|
|
}
|
|
|
|
// GetTableNameFromSql get table name from sql statement
|
|
// It handles table string like:
|
|
// "user"
|
|
// "user u"
|
|
// "DbLog.dbo.user",
|
|
// "user as u".
|
|
func (d *Driver) GetTableNameFromSql(sqlStr string) (table string) {
|
|
// INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?)
|
|
var (
|
|
leftChars, rightChars = d.GetChars()
|
|
trimStr = leftChars + rightChars + "[] "
|
|
pattern = "INTO(.+?)\\("
|
|
regCompile = regexp.MustCompile(pattern)
|
|
tableInfo = regCompile.FindStringSubmatch(sqlStr)
|
|
)
|
|
// get the first one. after the first it may be content of the value, it's not table name.
|
|
table = tableInfo[1]
|
|
table = strings.Trim(table, " ")
|
|
if strings.Contains(table, ".") {
|
|
tmpAry := strings.Split(table, ".")
|
|
// the last one is table name
|
|
table = tmpAry[len(tmpAry)-1]
|
|
} else if strings.Contains(table, "as") || strings.Contains(table, " ") {
|
|
tmpAry := strings.Split(table, "as")
|
|
if len(tmpAry) < 2 {
|
|
tmpAry = strings.Split(table, " ")
|
|
}
|
|
// get the first one
|
|
table = tmpAry[0]
|
|
}
|
|
table = strings.Trim(table, trimStr)
|
|
return table
|
|
}
|
|
|
|
// txLinkMssql wraps a *sql.Tx and adds the IsTransaction and IsOnMaster
// methods, so that a transaction can be used as a link by this driver.
type txLinkMssql struct {
	*sql.Tx
}

// IsTransaction reports whether the link is a transaction.
// It always returns true for txLinkMssql.
func (t *txLinkMssql) IsTransaction() bool {
	return true
}

// IsOnMaster reports whether the link operates on the master node.
// It always returns true, as transaction operations are always performed on
// the master node.
func (t *txLinkMssql) IsOnMaster() bool {
	return true
}
|
|
|
|
// GetInsertOutputSql gen get last_insert_id code
|
|
func (d *Driver) GetInsertOutputSql(ctx context.Context, table string) string {
|
|
fds, errFd := d.GetDB().TableFields(ctx, table)
|
|
if errFd != nil {
|
|
return ""
|
|
}
|
|
extraSqlAry := make([]string, 0)
|
|
extraSqlAry = append(extraSqlAry, fmt.Sprintf(" %s %s", outputKeyword, affectCountExpression))
|
|
incrNo := 0
|
|
if len(fds) > 0 {
|
|
for _, fd := range fds {
|
|
// has primary key and is auto-increment
|
|
if fd.Extra == fieldExtraIdentity && fd.Key == fieldKeyPrimary && !fd.Null {
|
|
incrNoStr := ""
|
|
if incrNo == 0 { // fixed first field named id, convenient to get
|
|
incrNoStr = fmt.Sprintf(" as %s", lastInsertIdFieldAlias)
|
|
}
|
|
|
|
extraSqlAry = append(extraSqlAry, fmt.Sprintf("%s.%s%s", insertedObjectName, fd.Name, incrNoStr))
|
|
incrNo++
|
|
}
|
|
// fmt.Printf("null:%t name:%s key:%s k:%s \n", fd.Null, fd.Name, fd.Key, k)
|
|
}
|
|
}
|
|
return strings.Join(extraSqlAry, ",")
|
|
// sql example:INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?)
|
|
}
|