mirror of
https://gitee.com/johng/gf
synced 2026-06-07 02:12:11 +08:00
feat(contrib/drivers/mssql): mssql support LastInsertId (#4051)
修复mssqlserver的InsertAndGetId方法;插入记录如果是自增主键则返回ID --------- Co-authored-by: 林孝义 <linxy@3755.com> Co-authored-by: houseme <housemecn@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: hailaz <739476267@qq.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@ -83,7 +83,6 @@ import _ "github.com/gogf/gf/contrib/drivers/mssql/v2"
|
||||
Note:
|
||||
|
||||
- It does not support `Replace` features.
|
||||
- It does not support `LastInsertId`.
|
||||
- It supports server version >= `SQL Server2005`
|
||||
- It ONLY supports datetime2 and datetimeoffset types for auto handling created_at/updated_at/deleted_at columns, because datetime type does not support microseconds precision when column value is passed as string.
|
||||
|
||||
|
||||
@ -81,7 +81,6 @@ import _ "github.com/gogf/gf/contrib/drivers/mssql/v2"
|
||||
注意:
|
||||
|
||||
- 不支持 `Replace` 功能。
|
||||
- 不支持 `LastInsertId`。
|
||||
- 仅支持服务器版本 >= `SQL Server2005`
|
||||
- 仅支持 datetime2 和 datetimeoffset 类型来自动处理 created_at/updated_at/deleted_at 列,因为 datetime 类型在将列值作为字符串传递时不支持微秒精度。
|
||||
|
||||
|
||||
191
contrib/drivers/mssql/mssql_do_exec.go
Normal file
191
contrib/drivers/mssql/mssql_do_exec.go
Normal file
@ -0,0 +1,191 @@
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
)
|
||||
|
||||
const (
|
||||
backIdInsertHeadDefault = "INSERT INTO"
|
||||
backIdInsertHeadInsertIgnore = "INSERT IGNORE INTO"
|
||||
|
||||
autoIncrementName = "identity"
|
||||
mssqlOutPutKey = "OUTPUT"
|
||||
mssqlInsertedObjName = "INSERTED"
|
||||
mssqlAffectFd = " 1 as AffectCount"
|
||||
affectCountFieldName = "AffectCount"
|
||||
mssqlPrimaryKeyName = "PRIMARY KEY"
|
||||
fdId = "ID"
|
||||
positionInsertValues = ") VALUES" // find the position of the string "VALUES" in the INSERT SQL statement to embed output code for retrieving the last inserted ID
|
||||
)
|
||||
|
||||
// DoExec commits the sql string and its arguments to underlying driver
|
||||
// through given link object and returns the execution result.
|
||||
func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sqlStr string, args ...interface{}) (result sql.Result, err error) {
|
||||
// Transaction checks.
|
||||
if link == nil {
|
||||
if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil {
|
||||
// Firstly, check and retrieve transaction link from context.
|
||||
link = &txLinkMssql{tx.GetSqlTX()}
|
||||
} else if link, err = d.Core.MasterLink(); err != nil {
|
||||
// Or else it creates one from master node.
|
||||
return nil, err
|
||||
}
|
||||
} else if !link.IsTransaction() {
|
||||
// If current link is not transaction link, it checks and retrieves transaction from context.
|
||||
if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil {
|
||||
link = &txLinkMssql{tx.GetSqlTX()}
|
||||
}
|
||||
}
|
||||
|
||||
// SQL filtering.
|
||||
sqlStr, args = d.Core.FormatSqlBeforeExecuting(sqlStr, args)
|
||||
sqlStr, args, err = d.DoFilter(ctx, link, sqlStr, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !(strings.HasPrefix(sqlStr, backIdInsertHeadDefault) || strings.HasPrefix(sqlStr, backIdInsertHeadInsertIgnore)) {
|
||||
return d.Core.DoExec(ctx, link, sqlStr, args)
|
||||
}
|
||||
// find the first pos
|
||||
pos := strings.Index(sqlStr, positionInsertValues)
|
||||
|
||||
table := d.GetTableNameFromSql(sqlStr)
|
||||
outPutSql := d.GetInsertOutputSql(ctx, table)
|
||||
// rebuild sql add output
|
||||
var (
|
||||
sqlValueBefore = sqlStr[:pos+1]
|
||||
sqlValueAfter = sqlStr[pos+1:]
|
||||
)
|
||||
|
||||
sqlStr = fmt.Sprintf("%s%s%s", sqlValueBefore, outPutSql, sqlValueAfter)
|
||||
|
||||
// fmt.Println("sql str:", sqlStr)
|
||||
// Link execution.
|
||||
var out gdb.DoCommitOutput
|
||||
out, err = d.DoCommit(ctx, gdb.DoCommitInput{
|
||||
Link: link,
|
||||
Sql: sqlStr,
|
||||
Args: args,
|
||||
Stmt: nil,
|
||||
Type: gdb.SqlTypeQueryContext,
|
||||
IsTransaction: link.IsTransaction(),
|
||||
})
|
||||
if err != nil {
|
||||
return &InsertResult{lastInsertId: 0, rowsAffected: 0, err: err}, err
|
||||
}
|
||||
var (
|
||||
aCount int64 // affect count
|
||||
lId int64 // last insert id
|
||||
)
|
||||
stdSqlResult := out.Records
|
||||
if len(stdSqlResult) == 0 {
|
||||
err = gerror.WrapCode(gcode.CodeDbOperationError, gerror.New("affectcount is zero"), `sql.Result.RowsAffected failed`)
|
||||
return &InsertResult{lastInsertId: 0, rowsAffected: 0, err: err}, err
|
||||
}
|
||||
// get affect count
|
||||
aCount = stdSqlResult[0].GMap().GetVar(affectCountFieldName).Int64()
|
||||
// get last_insert_id
|
||||
lId = stdSqlResult[0].GMap().GetVar(fdId).Int64()
|
||||
|
||||
return &InsertResult{lastInsertId: lId, rowsAffected: aCount}, err
|
||||
}
|
||||
|
||||
// GetTableNameFromSql get table name from sql statement
|
||||
// It handles table string like:
|
||||
// "user"
|
||||
// "user u"
|
||||
// "DbLog.dbo.user",
|
||||
// "user as u".
|
||||
func (d *Driver) GetTableNameFromSql(sqlStr string) (table string) {
|
||||
// INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?)
|
||||
leftChars, rightChars := d.GetChars()
|
||||
trimStr := leftChars + rightChars + "[] "
|
||||
pattern := "INTO(.+?)\\("
|
||||
regCompile := regexp.MustCompile(pattern)
|
||||
tableInfo := regCompile.FindStringSubmatch(sqlStr)
|
||||
//get the first one. after the first it may be content of the value, it's not table name.
|
||||
table = tableInfo[1]
|
||||
table = strings.Trim(table, " ")
|
||||
if strings.Contains(table, ".") {
|
||||
tmpAry := strings.Split(table, ".")
|
||||
// the last one is tablename
|
||||
table = tmpAry[len(tmpAry)-1]
|
||||
} else if strings.Contains(table, "as") || strings.Contains(table, " ") {
|
||||
tmpAry := strings.Split(table, "as")
|
||||
if len(tmpAry) < 2 {
|
||||
tmpAry = strings.Split(table, " ")
|
||||
}
|
||||
// get the first one
|
||||
table = tmpAry[0]
|
||||
}
|
||||
table = strings.Trim(table, trimStr)
|
||||
return table
|
||||
}
|
||||
|
||||
// txLink is used to implement interface Link for TX.
|
||||
type txLinkMssql struct {
|
||||
*sql.Tx
|
||||
}
|
||||
|
||||
// IsTransaction returns if current Link is a transaction.
|
||||
func (l *txLinkMssql) IsTransaction() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsOnMaster checks and returns whether current link is operated on master node.
|
||||
// Note that, transaction operation is always operated on master node.
|
||||
func (l *txLinkMssql) IsOnMaster() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// InsertResult instance of sql.Result
|
||||
type InsertResult struct {
|
||||
lastInsertId int64
|
||||
rowsAffected int64
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *InsertResult) LastInsertId() (int64, error) {
|
||||
return r.lastInsertId, r.err
|
||||
}
|
||||
|
||||
func (r *InsertResult) RowsAffected() (int64, error) {
|
||||
return r.rowsAffected, r.err
|
||||
}
|
||||
|
||||
// GetInsertOutputSql gen get last_insert_id code
|
||||
func (m *Driver) GetInsertOutputSql(ctx context.Context, table string) string {
|
||||
fds, errFd := m.GetDB().TableFields(ctx, table)
|
||||
if errFd != nil {
|
||||
return ""
|
||||
}
|
||||
extraSqlAry := make([]string, 0)
|
||||
extraSqlAry = append(extraSqlAry, fmt.Sprintf(" %s %s", mssqlOutPutKey, mssqlAffectFd))
|
||||
incrNo := 0
|
||||
if len(fds) > 0 {
|
||||
for _, fd := range fds {
|
||||
// has primary key and is auto-increment
|
||||
if fd.Extra == autoIncrementName && fd.Key == mssqlPrimaryKeyName && !fd.Null {
|
||||
incrNoStr := ""
|
||||
if incrNo == 0 { // fixed first field named id, convenient to get
|
||||
incrNoStr = fmt.Sprintf(" as %s", fdId)
|
||||
}
|
||||
|
||||
extraSqlAry = append(extraSqlAry, fmt.Sprintf("%s.%s%s", mssqlInsertedObjName, fd.Name, incrNoStr))
|
||||
incrNo++
|
||||
}
|
||||
// fmt.Printf("null:%t name:%s key:%s k:%s \n", fd.Null, fd.Name, fd.Key, k)
|
||||
}
|
||||
}
|
||||
return strings.Join(extraSqlAry, ",")
|
||||
// sql example:INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?)
|
||||
}
|
||||
@ -13,11 +13,15 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/encoding/gxml"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gctx"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/gogf/gf/v2/test/gtest"
|
||||
|
||||
"github.com/gogf/gf/contrib/drivers/mssql/v2"
|
||||
)
|
||||
|
||||
func TestTables(t *testing.T) {
|
||||
@ -148,6 +152,56 @@ func TestDoInsert(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDoInsertGetId(t *testing.T) {
|
||||
// create test table
|
||||
createInsertAndGetIdTableForTest()
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
table := "ip_to_id"
|
||||
data := map[string]interface{}{
|
||||
"ip": "192.168.179.1",
|
||||
}
|
||||
id, err := db.InsertAndGetId(gctx.New(), table, data)
|
||||
t.AssertNil(err)
|
||||
t.AssertGT(id, 0)
|
||||
// fmt.Println("id:", id)
|
||||
|
||||
// multiple insert test
|
||||
dataAry := []map[string]interface{}{{"ip": "192.168.5.9"}, {"ip": "192.168.5.10"}}
|
||||
id1, err1 := db.InsertAndGetId(gctx.New(), table, dataAry)
|
||||
t.AssertNil(err1)
|
||||
t.AssertGT(id1, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetTableFromSql(t *testing.T) {
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
okTable := "ip_to_id"
|
||||
sqlStr := "INSERT INTO \"ip_to_id\"(\"ip\") VALUES(?)"
|
||||
dbWrapper, ok := db.GetCore().GetDB().(*gdb.DriverWrapperDB)
|
||||
t.Assert(ok, true)
|
||||
dbMssql, ok := dbWrapper.DB.(*mssql.Driver)
|
||||
t.Assert(ok, true)
|
||||
table := dbMssql.GetTableNameFromSql(sqlStr)
|
||||
// fmt.Println("default table:", table)
|
||||
t.Assert(table, okTable)
|
||||
|
||||
sqlStr = "INSERT INTO \"MyLogDb\".\"dbo\".\"ip_to_id\"(\"ip\") VALUES(?)"
|
||||
table = dbMssql.GetTableNameFromSql(sqlStr)
|
||||
// fmt.Println("MyLogDb.dbo.ip_to_id table:", table)
|
||||
t.Assert(table, okTable)
|
||||
|
||||
sqlStr = "INSERT INTO \"ip_to_id\" as \"tt\" (\"ip\") VALUES(?)"
|
||||
table = dbMssql.GetTableNameFromSql(sqlStr)
|
||||
// fmt.Println("ip_to_id as tt table:", table)
|
||||
t.Assert(table, okTable)
|
||||
|
||||
sqlStr = "INSERT INTO \"ip_to_id\" \"tt\" (\"ip\") VALUES(?)"
|
||||
table = dbMssql.GetTableNameFromSql(sqlStr)
|
||||
// fmt.Println("ip_to_id tt table:", table)
|
||||
t.Assert(table, okTable)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_DB_Ping(t *testing.T) {
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
err1 := db.PingMaster()
|
||||
|
||||
@ -25,9 +25,14 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
TableSize = 10
|
||||
TestDbUser = "sa"
|
||||
TestDbPass = "LoremIpsum86"
|
||||
TableSize = 10
|
||||
TableName = "t_user"
|
||||
TestSchema1 = "test1"
|
||||
TestSchema2 = "test2"
|
||||
TableNamePrefix1 = "gf_"
|
||||
TestDbUser = "sa"
|
||||
TestDbPass = "LoremIpsum86"
|
||||
CreateTime = "2018-10-24 10:00:00"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -36,7 +41,7 @@ func init() {
|
||||
Port: "1433",
|
||||
User: TestDbUser,
|
||||
Pass: TestDbPass,
|
||||
Name: "master",
|
||||
Name: "test",
|
||||
Type: "mssql",
|
||||
Role: "master",
|
||||
Charset: "utf8",
|
||||
@ -142,3 +147,27 @@ func dropTable(table string) {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// createInsertAndGetIdTableForTest test for InsertAndGetId
|
||||
func createInsertAndGetIdTableForTest() (name string) {
|
||||
|
||||
if _, err := db.Exec(context.Background(), `
|
||||
IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='ip_to_id' and xtype='U')
|
||||
begin
|
||||
CREATE TABLE [ip_to_id](
|
||||
[id] [int] IDENTITY(1,1) NOT NULL,
|
||||
[ip] [varchar](128) NULL,
|
||||
CONSTRAINT [PK_ip_to_id] PRIMARY KEY CLUSTERED
|
||||
(
|
||||
[id] ASC
|
||||
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
|
||||
) ON [PRIMARY]
|
||||
end
|
||||
`); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
|
||||
db.Schema(db.GetConfig().Name)
|
||||
name = "ip_to_id"
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user