From a4e7cc470051b9a03724a2fe36b20aff3a6b0c3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B5=B7=E4=BA=AE?= <739476267@qq.com> Date: Thu, 3 Aug 2023 19:59:22 +0800 Subject: [PATCH] fix: psgql tx unsupport LastInsertId (#2815) --- contrib/drivers/pgsql/pgsql.go | 27 ++++++++++++------- contrib/drivers/pgsql/pgsql_z_test.go | 39 +++++++++++++++++++++++++++ database/gdb/gdb.go | 2 ++ database/gdb/gdb_core.go | 11 ++++++++ database/gdb/gdb_core_transaction.go | 25 +++++++++++++++++ database/gdb/gdb_core_underlying.go | 7 ++--- database/gdb/gdb_func.go | 11 -------- 7 files changed, 98 insertions(+), 24 deletions(-) diff --git a/contrib/drivers/pgsql/pgsql.go b/contrib/drivers/pgsql/pgsql.go index c4d0dbd5d..e3e8b8c66 100644 --- a/contrib/drivers/pgsql/pgsql.go +++ b/contrib/drivers/pgsql/pgsql.go @@ -35,8 +35,8 @@ type Driver struct { const ( internalPrimaryKeyInCtx gctx.StrKey = "primary_key" - defaultSchema = "public" - quoteChar = `"` + defaultSchema string = "public" + quoteChar string = `"` ) func init() { @@ -372,14 +372,22 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ... ) // Transaction checks. - if link != nil && link.IsTransaction() { - isUseCoreDoExec = true - } else { + if link == nil { if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil { - isUseCoreDoExec = true + // Firstly, check and retrieve transaction link from context. + link = tx + } else if link, err = d.MasterLink(); err != nil { + // Or else it creates one from master node. + return nil, err + } + } else if !link.IsTransaction() { + // If current link is not transaction link, it checks and retrieves transaction from context. + if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil { + link = tx } } + // Check if it is an insert operation with primary key. if value := ctx.Value(internalPrimaryKeyInCtx); value != nil { var ok bool pkField, ok = value.(gdb.TableField) @@ -408,8 +416,7 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ... } // Sql filtering. - // TODO: internal function formatSql - // sql, args = formatSql(sql, args) + sql, args = d.FormatSqlBeforeExecuting(sql, args) sql, args, err = d.DoFilter(ctx, link, sql, args) if err != nil { return nil, err @@ -442,10 +449,10 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ... } if out.Records[affected-1][primaryKey] != nil { - lastInsertId := out.Records[affected-1][primaryKey].Int() + lastInsertId := out.Records[affected-1][primaryKey].Int64() return Result{ affected: int64(affected), - lastInsertId: int64(lastInsertId), + lastInsertId: lastInsertId, }, nil } } diff --git a/contrib/drivers/pgsql/pgsql_z_test.go b/contrib/drivers/pgsql/pgsql_z_test.go index f5559d989..6d9a6f981 100644 --- a/contrib/drivers/pgsql/pgsql_z_test.go +++ b/contrib/drivers/pgsql/pgsql_z_test.go @@ -7,8 +7,10 @@ package pgsql_test import ( + "context" "testing" + "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/os/gctx" "github.com/gogf/gf/v2/test/gtest" @@ -45,6 +47,43 @@ func Test_LastInsertId(t *testing.T) { }) } +func Test_TxLastInsertId(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + tableName := createTable() + defer dropTable(tableName) + err := db.Transaction(context.TODO(), func(ctx context.Context, tx gdb.TX) error { + // user + res, err := tx.Model(tableName).Insert(g.List{ + {"passport": "user1", "password": "pwd", "nickname": "nickname", "create_time": CreateTime}, + {"passport": "user2", "password": "pwd", "nickname": "nickname", "create_time": CreateTime}, + {"passport": "user3", "password": "pwd", "nickname": "nickname", "create_time": CreateTime}, + }) + t.Assert(err, nil) + lastInsertId, err := res.LastInsertId() + t.Assert(err, nil) + t.AssertEQ(lastInsertId, int64(3)) + rowsAffected, err := res.RowsAffected() + t.Assert(err, nil) + t.AssertEQ(rowsAffected, int64(3)) + + res1, err := tx.Model(tableName).Insert(g.List{ + {"passport": "user4", "password": "pwd", "nickname": "nickname", "create_time": CreateTime}, + {"passport": "user5", "password": "pwd", "nickname": "nickname", "create_time": CreateTime}, + }) + t.Assert(err, nil) + lastInsertId1, err := res1.LastInsertId() + t.Assert(err, nil) + t.AssertEQ(lastInsertId1, int64(5)) + rowsAffected1, err := res1.RowsAffected() + t.Assert(err, nil) + t.AssertEQ(rowsAffected1, int64(2)) + return nil + + }) + t.Assert(err, nil) + }) +} + func Test_Driver_DoFilter(t *testing.T) { var ( ctx = gctx.New() diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 1e0c1edad..2af5fb476 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -179,6 +179,8 @@ type DB interface { // TX defines the interfaces for ORM transaction operations. type TX interface { + Link + Ctx(ctx context.Context) TX Raw(rawSql string, args ...interface{}) *Model Model(tableNameQueryOrStruct ...interface{}) *Model diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 7e2173678..823353f21 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -796,3 +796,14 @@ func (c *Core) isSoftCreatedFieldName(fieldName string) bool { } return false } + +// FormatSqlBeforeExecuting formats the sql string and its arguments before executing. +// The internal handleArguments function might be called twice during the SQL procedure, +// but do not worry about it, it's safe and efficient. +func (c *Core) FormatSqlBeforeExecuting(sql string, args []interface{}) (newSql string, newArgs []interface{}) { + // DO NOT do this as there may be multiple lines and comments in the sql. + // sql = gstr.Trim(sql) + // sql = gstr.Replace(sql, "\n", " ") + // sql, _ = gregex.ReplaceString(`\s{2,}`, ` `, sql) + return handleArguments(sql, args) +} diff --git a/database/gdb/gdb_core_transaction.go b/database/gdb/gdb_core_transaction.go index b16486b47..1fd52f2a2 100644 --- a/database/gdb/gdb_core_transaction.go +++ b/database/gdb/gdb_core_transaction.go @@ -517,3 +517,28 @@ func (tx *TXCore) Update(table string, data interface{}, condition interface{}, func (tx *TXCore) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) { return tx.Model(table).Ctx(tx.ctx).Where(condition, args...).Delete() } + +// QueryContext implements interface function Link.QueryContext. +func (tx *TXCore) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) { + return tx.tx.QueryContext(ctx, sql, args...) +} + +// ExecContext implements interface function Link.ExecContext. +func (tx *TXCore) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) { + return tx.tx.ExecContext(ctx, sql, args...) +} + +// PrepareContext implements interface function Link.PrepareContext. +func (tx *TXCore) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) { + return tx.tx.PrepareContext(ctx, sql) +} + +// IsOnMaster implements interface function Link.IsOnMaster. +func (tx *TXCore) IsOnMaster() bool { + return true +} + +// IsTransaction implements interface function Link.IsTransaction. +func (tx *TXCore) IsTransaction() bool { + return true +} diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index 1930270d9..85b0aaa84 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -10,9 +10,10 @@ package gdb import ( "context" "database/sql" + "reflect" + "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" - "reflect" "github.com/gogf/gf/v2/util/gconv" @@ -55,7 +56,7 @@ func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...inter } // Sql filtering. - sql, args = formatSql(sql, args) + sql, args = c.FormatSqlBeforeExecuting(sql, args) sql, args, err = c.db.DoFilter(ctx, link, sql, args) if err != nil { return nil, err @@ -116,7 +117,7 @@ func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interf } // SQL filtering. - sql, args = formatSql(sql, args) + sql, args = c.FormatSqlBeforeExecuting(sql, args) sql, args, err = c.db.DoFilter(ctx, link, sql, args) if err != nil { return nil, err diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index 90d02470e..26e00480e 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -373,17 +373,6 @@ func GetPrimaryKeyCondition(primary string, where ...interface{}) (newWhereCondi return where } -// formatSql formats the sql string and its arguments before executing. -// The internal handleArguments function might be called twice during the SQL procedure, -// but do not worry about it, it's safe and efficient. -func formatSql(sql string, args []interface{}) (newSql string, newArgs []interface{}) { - // DO NOT do this as there may be multiple lines and comments in the sql. - // sql = gstr.Trim(sql) - // sql = gstr.Replace(sql, "\n", " ") - // sql, _ = gregex.ReplaceString(`\s{2,}`, ` `, sql) - return handleArguments(sql, args) -} - type formatWhereHolderInput struct { WhereHolder OmitNil bool