From 7e9715ab1d2511ab31cf317f6f65179e162b64cf Mon Sep 17 00:00:00 2001 From: lxy1151 <316543569@qq.com> Date: Sun, 28 Sep 2025 17:55:08 +0800 Subject: [PATCH] feat(contrib/drivers/mssql): mssql support LastInsertId (#4051) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复mssqlserver的InsertAndGetId方法;插入记录如果是自增主键则返回ID --------- Co-authored-by: 林孝义 Co-authored-by: houseme Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: hailaz <739476267@qq.com> Co-authored-by: github-actions[bot] --- contrib/drivers/README.MD | 1 - contrib/drivers/README.zh_CN.MD | 1 - contrib/drivers/mssql/mssql_do_exec.go | 191 ++++++++++++++++++ .../drivers/mssql/mssql_z_unit_basic_test.go | 54 +++++ .../drivers/mssql/mssql_z_unit_init_test.go | 37 +++- 5 files changed, 278 insertions(+), 6 deletions(-) create mode 100644 contrib/drivers/mssql/mssql_do_exec.go diff --git a/contrib/drivers/README.MD b/contrib/drivers/README.MD index eaf4966f7..da22dc56c 100644 --- a/contrib/drivers/README.MD +++ b/contrib/drivers/README.MD @@ -83,7 +83,6 @@ import _ "github.com/gogf/gf/contrib/drivers/mssql/v2" Note: - It does not support `Replace` features. -- It does not support `LastInsertId`. - It supports server version >= `SQL Server2005` - It ONLY supports datetime2 and datetimeoffset types for auto handling created_at/updated_at/deleted_at columns, because datetime type does not support microseconds precision when column value is passed as string. diff --git a/contrib/drivers/README.zh_CN.MD b/contrib/drivers/README.zh_CN.MD index 27699c0b5..4c7b4f08a 100644 --- a/contrib/drivers/README.zh_CN.MD +++ b/contrib/drivers/README.zh_CN.MD @@ -81,7 +81,6 @@ import _ "github.com/gogf/gf/contrib/drivers/mssql/v2" 注意: - 不支持 `Replace` 功能。 -- 不支持 `LastInsertId`。 - 仅支持服务器版本 >= `SQL Server2005` - 仅支持 datetime2 和 datetimeoffset 类型来自动处理 created_at/updated_at/deleted_at 列,因为 datetime 类型在将列值作为字符串传递时不支持微秒精度。 diff --git a/contrib/drivers/mssql/mssql_do_exec.go b/contrib/drivers/mssql/mssql_do_exec.go new file mode 100644 index 000000000..a0d3fde39 --- /dev/null +++ b/contrib/drivers/mssql/mssql_do_exec.go @@ -0,0 +1,191 @@ +package mssql + +import ( + "context" + "database/sql" + "fmt" + "regexp" + "strings" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" +) + +const ( + backIdInsertHeadDefault = "INSERT INTO" + backIdInsertHeadInsertIgnore = "INSERT IGNORE INTO" + + autoIncrementName = "identity" + mssqlOutPutKey = "OUTPUT" + mssqlInsertedObjName = "INSERTED" + mssqlAffectFd = " 1 as AffectCount" + affectCountFieldName = "AffectCount" + mssqlPrimaryKeyName = "PRIMARY KEY" + fdId = "ID" + positionInsertValues = ") VALUES" // find the position of the string "VALUES" in the INSERT SQL statement to embed output code for retrieving the last inserted ID +) + +// DoExec commits the sql string and its arguments to underlying driver +// through given link object and returns the execution result. +func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sqlStr string, args ...interface{}) (result sql.Result, err error) { + // Transaction checks. + if link == nil { + if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil { + // Firstly, check and retrieve transaction link from context. + link = &txLinkMssql{tx.GetSqlTX()} + } else if link, err = d.Core.MasterLink(); err != nil { + // Or else it creates one from master node. + return nil, err + } + } else if !link.IsTransaction() { + // If current link is not transaction link, it checks and retrieves transaction from context. + if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil { + link = &txLinkMssql{tx.GetSqlTX()} + } + } + + // SQL filtering. + sqlStr, args = d.Core.FormatSqlBeforeExecuting(sqlStr, args) + sqlStr, args, err = d.DoFilter(ctx, link, sqlStr, args) + if err != nil { + return nil, err + } + + if !(strings.HasPrefix(sqlStr, backIdInsertHeadDefault) || strings.HasPrefix(sqlStr, backIdInsertHeadInsertIgnore)) { + return d.Core.DoExec(ctx, link, sqlStr, args) + } + // find the first pos + pos := strings.Index(sqlStr, positionInsertValues) + + table := d.GetTableNameFromSql(sqlStr) + outPutSql := d.GetInsertOutputSql(ctx, table) + // rebuild sql add output + var ( + sqlValueBefore = sqlStr[:pos+1] + sqlValueAfter = sqlStr[pos+1:] + ) + + sqlStr = fmt.Sprintf("%s%s%s", sqlValueBefore, outPutSql, sqlValueAfter) + + // fmt.Println("sql str:", sqlStr) + // Link execution. + var out gdb.DoCommitOutput + out, err = d.DoCommit(ctx, gdb.DoCommitInput{ + Link: link, + Sql: sqlStr, + Args: args, + Stmt: nil, + Type: gdb.SqlTypeQueryContext, + IsTransaction: link.IsTransaction(), + }) + if err != nil { + return &InsertResult{lastInsertId: 0, rowsAffected: 0, err: err}, err + } + var ( + aCount int64 // affect count + lId int64 // last insert id + ) + stdSqlResult := out.Records + if len(stdSqlResult) == 0 { + err = gerror.WrapCode(gcode.CodeDbOperationError, gerror.New("affectcount is zero"), `sql.Result.RowsAffected failed`) + return &InsertResult{lastInsertId: 0, rowsAffected: 0, err: err}, err + } + // get affect count + aCount = stdSqlResult[0].GMap().GetVar(affectCountFieldName).Int64() + // get last_insert_id + lId = stdSqlResult[0].GMap().GetVar(fdId).Int64() + + return &InsertResult{lastInsertId: lId, rowsAffected: aCount}, err +} + +// GetTableNameFromSql get table name from sql statement +// It handles table string like: +// "user" +// "user u" +// "DbLog.dbo.user", +// "user as u". +func (d *Driver) GetTableNameFromSql(sqlStr string) (table string) { + // INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?) + leftChars, rightChars := d.GetChars() + trimStr := leftChars + rightChars + "[] " + pattern := "INTO(.+?)\\(" + regCompile := regexp.MustCompile(pattern) + tableInfo := regCompile.FindStringSubmatch(sqlStr) + //get the first one. after the first it may be content of the value, it's not table name. + table = tableInfo[1] + table = strings.Trim(table, " ") + if strings.Contains(table, ".") { + tmpAry := strings.Split(table, ".") + // the last one is tablename + table = tmpAry[len(tmpAry)-1] + } else if strings.Contains(table, "as") || strings.Contains(table, " ") { + tmpAry := strings.Split(table, "as") + if len(tmpAry) < 2 { + tmpAry = strings.Split(table, " ") + } + // get the first one + table = tmpAry[0] + } + table = strings.Trim(table, trimStr) + return table +} + +// txLink is used to implement interface Link for TX. +type txLinkMssql struct { + *sql.Tx +} + +// IsTransaction returns if current Link is a transaction. +func (l *txLinkMssql) IsTransaction() bool { + return true +} + +// IsOnMaster checks and returns whether current link is operated on master node. +// Note that, transaction operation is always operated on master node. +func (l *txLinkMssql) IsOnMaster() bool { + return true +} + +// InsertResult instance of sql.Result +type InsertResult struct { + lastInsertId int64 + rowsAffected int64 + err error +} + +func (r *InsertResult) LastInsertId() (int64, error) { + return r.lastInsertId, r.err +} + +func (r *InsertResult) RowsAffected() (int64, error) { + return r.rowsAffected, r.err +} + +// GetInsertOutputSql gen get last_insert_id code +func (m *Driver) GetInsertOutputSql(ctx context.Context, table string) string { + fds, errFd := m.GetDB().TableFields(ctx, table) + if errFd != nil { + return "" + } + extraSqlAry := make([]string, 0) + extraSqlAry = append(extraSqlAry, fmt.Sprintf(" %s %s", mssqlOutPutKey, mssqlAffectFd)) + incrNo := 0 + if len(fds) > 0 { + for _, fd := range fds { + // has primary key and is auto-increment + if fd.Extra == autoIncrementName && fd.Key == mssqlPrimaryKeyName && !fd.Null { + incrNoStr := "" + if incrNo == 0 { // fixed first field named id, convenient to get + incrNoStr = fmt.Sprintf(" as %s", fdId) + } + + extraSqlAry = append(extraSqlAry, fmt.Sprintf("%s.%s%s", mssqlInsertedObjName, fd.Name, incrNoStr)) + incrNo++ + } + // fmt.Printf("null:%t name:%s key:%s k:%s \n", fd.Null, fd.Name, fd.Key, k) + } + } + return strings.Join(extraSqlAry, ",") + // sql example:INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?) +} diff --git a/contrib/drivers/mssql/mssql_z_unit_basic_test.go b/contrib/drivers/mssql/mssql_z_unit_basic_test.go index d39116a3c..3c7ed7235 100644 --- a/contrib/drivers/mssql/mssql_z_unit_basic_test.go +++ b/contrib/drivers/mssql/mssql_z_unit_basic_test.go @@ -13,11 +13,15 @@ import ( "testing" "time" + "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/encoding/gxml" "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/os/gctx" "github.com/gogf/gf/v2/os/gtime" "github.com/gogf/gf/v2/test/gtest" + + "github.com/gogf/gf/contrib/drivers/mssql/v2" ) func TestTables(t *testing.T) { @@ -148,6 +152,56 @@ func TestDoInsert(t *testing.T) { }) } +func TestDoInsertGetId(t *testing.T) { + // create test table + createInsertAndGetIdTableForTest() + gtest.C(t, func(t *gtest.T) { + table := "ip_to_id" + data := map[string]interface{}{ + "ip": "192.168.179.1", + } + id, err := db.InsertAndGetId(gctx.New(), table, data) + t.AssertNil(err) + t.AssertGT(id, 0) + // fmt.Println("id:", id) + + // multiple insert test + dataAry := []map[string]interface{}{{"ip": "192.168.5.9"}, {"ip": "192.168.5.10"}} + id1, err1 := db.InsertAndGetId(gctx.New(), table, dataAry) + t.AssertNil(err1) + t.AssertGT(id1, 0) + }) +} + +func TestGetTableFromSql(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + okTable := "ip_to_id" + sqlStr := "INSERT INTO \"ip_to_id\"(\"ip\") VALUES(?)" + dbWrapper, ok := db.GetCore().GetDB().(*gdb.DriverWrapperDB) + t.Assert(ok, true) + dbMssql, ok := dbWrapper.DB.(*mssql.Driver) + t.Assert(ok, true) + table := dbMssql.GetTableNameFromSql(sqlStr) + // fmt.Println("default table:", table) + t.Assert(table, okTable) + + sqlStr = "INSERT INTO \"MyLogDb\".\"dbo\".\"ip_to_id\"(\"ip\") VALUES(?)" + table = dbMssql.GetTableNameFromSql(sqlStr) + // fmt.Println("MyLogDb.dbo.ip_to_id table:", table) + t.Assert(table, okTable) + + sqlStr = "INSERT INTO \"ip_to_id\" as \"tt\" (\"ip\") VALUES(?)" + table = dbMssql.GetTableNameFromSql(sqlStr) + // fmt.Println("ip_to_id as tt table:", table) + t.Assert(table, okTable) + + sqlStr = "INSERT INTO \"ip_to_id\" \"tt\" (\"ip\") VALUES(?)" + table = dbMssql.GetTableNameFromSql(sqlStr) + // fmt.Println("ip_to_id tt table:", table) + t.Assert(table, okTable) + }) +} + func Test_DB_Ping(t *testing.T) { gtest.C(t, func(t *gtest.T) { err1 := db.PingMaster() diff --git a/contrib/drivers/mssql/mssql_z_unit_init_test.go b/contrib/drivers/mssql/mssql_z_unit_init_test.go index 08e925c5a..27a4db067 100644 --- a/contrib/drivers/mssql/mssql_z_unit_init_test.go +++ b/contrib/drivers/mssql/mssql_z_unit_init_test.go @@ -25,9 +25,14 @@ var ( ) const ( - TableSize = 10 - TestDbUser = "sa" - TestDbPass = "LoremIpsum86" + TableSize = 10 + TableName = "t_user" + TestSchema1 = "test1" + TestSchema2 = "test2" + TableNamePrefix1 = "gf_" + TestDbUser = "sa" + TestDbPass = "LoremIpsum86" + CreateTime = "2018-10-24 10:00:00" ) func init() { @@ -36,7 +41,7 @@ func init() { Port: "1433", User: TestDbUser, Pass: TestDbPass, - Name: "master", + Name: "test", Type: "mssql", Role: "master", Charset: "utf8", @@ -142,3 +147,27 @@ func dropTable(table string) { gtest.Fatal(err) } } + +// createInsertAndGetIdTableForTest test for InsertAndGetId +func createInsertAndGetIdTableForTest() (name string) { + + if _, err := db.Exec(context.Background(), ` +IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='ip_to_id' and xtype='U') +begin + CREATE TABLE [ip_to_id]( + [id] [int] IDENTITY(1,1) NOT NULL, + [ip] [varchar](128) NULL, + CONSTRAINT [PK_ip_to_id] PRIMARY KEY CLUSTERED + ( + [id] ASC + )WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY] + ) ON [PRIMARY] +end + `); err != nil { + gtest.Fatal(err) + } + + db.Schema(db.GetConfig().Name) + name = "ip_to_id" + return +}