fix: psgql tx unsupport LastInsertId (#2815)

This commit is contained in:
海亮
2023-08-03 19:59:22 +08:00
committed by GitHub
parent 2fbe4125dd
commit a4e7cc4700
7 changed files with 98 additions and 24 deletions

View File

@ -35,8 +35,8 @@ type Driver struct {
const (
internalPrimaryKeyInCtx gctx.StrKey = "primary_key"
defaultSchema = "public"
quoteChar = `"`
defaultSchema string = "public"
quoteChar string = `"`
)
func init() {
@ -372,14 +372,22 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...
)
// Transaction checks.
if link != nil && link.IsTransaction() {
isUseCoreDoExec = true
} else {
if link == nil {
if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil {
isUseCoreDoExec = true
// Firstly, check and retrieve transaction link from context.
link = tx
} else if link, err = d.MasterLink(); err != nil {
// Or else it creates one from master node.
return nil, err
}
} else if !link.IsTransaction() {
// If current link is not transaction link, it checks and retrieves transaction from context.
if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil {
link = tx
}
}
// Check if it is an insert operation with primary key.
if value := ctx.Value(internalPrimaryKeyInCtx); value != nil {
var ok bool
pkField, ok = value.(gdb.TableField)
@ -408,8 +416,7 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...
}
// Sql filtering.
// TODO: internal function formatSql
// sql, args = formatSql(sql, args)
sql, args = d.FormatSqlBeforeExecuting(sql, args)
sql, args, err = d.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err
@ -442,10 +449,10 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...
}
if out.Records[affected-1][primaryKey] != nil {
lastInsertId := out.Records[affected-1][primaryKey].Int()
lastInsertId := out.Records[affected-1][primaryKey].Int64()
return Result{
affected: int64(affected),
lastInsertId: int64(lastInsertId),
lastInsertId: lastInsertId,
}, nil
}
}

View File

@ -7,8 +7,10 @@
package pgsql_test
import (
"context"
"testing"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/test/gtest"
@ -45,6 +47,43 @@ func Test_LastInsertId(t *testing.T) {
})
}
func Test_TxLastInsertId(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
tableName := createTable()
defer dropTable(tableName)
err := db.Transaction(context.TODO(), func(ctx context.Context, tx gdb.TX) error {
// user
res, err := tx.Model(tableName).Insert(g.List{
{"passport": "user1", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
{"passport": "user2", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
{"passport": "user3", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
})
t.Assert(err, nil)
lastInsertId, err := res.LastInsertId()
t.Assert(err, nil)
t.AssertEQ(lastInsertId, int64(3))
rowsAffected, err := res.RowsAffected()
t.Assert(err, nil)
t.AssertEQ(rowsAffected, int64(3))
res1, err := tx.Model(tableName).Insert(g.List{
{"passport": "user4", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
{"passport": "user5", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
})
t.Assert(err, nil)
lastInsertId1, err := res1.LastInsertId()
t.Assert(err, nil)
t.AssertEQ(lastInsertId1, int64(5))
rowsAffected1, err := res1.RowsAffected()
t.Assert(err, nil)
t.AssertEQ(rowsAffected1, int64(2))
return nil
})
t.Assert(err, nil)
})
}
func Test_Driver_DoFilter(t *testing.T) {
var (
ctx = gctx.New()

View File

@ -179,6 +179,8 @@ type DB interface {
// TX defines the interfaces for ORM transaction operations.
type TX interface {
Link
Ctx(ctx context.Context) TX
Raw(rawSql string, args ...interface{}) *Model
Model(tableNameQueryOrStruct ...interface{}) *Model

View File

@ -796,3 +796,14 @@ func (c *Core) isSoftCreatedFieldName(fieldName string) bool {
}
return false
}
// FormatSqlBeforeExecuting formats the sql string and its arguments before executing.
// The internal handleArguments function might be called twice during the SQL procedure,
// but do not worry about it, it's safe and efficient.
func (c *Core) FormatSqlBeforeExecuting(sql string, args []interface{}) (newSql string, newArgs []interface{}) {
// DO NOT do this as there may be multiple lines and comments in the sql.
// sql = gstr.Trim(sql)
// sql = gstr.Replace(sql, "\n", " ")
// sql, _ = gregex.ReplaceString(`\s{2,}`, ` `, sql)
return handleArguments(sql, args)
}

View File

@ -517,3 +517,28 @@ func (tx *TXCore) Update(table string, data interface{}, condition interface{},
func (tx *TXCore) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
return tx.Model(table).Ctx(tx.ctx).Where(condition, args...).Delete()
}
// QueryContext implements interface function Link.QueryContext.
func (tx *TXCore) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) {
return tx.tx.QueryContext(ctx, sql, args...)
}
// ExecContext implements interface function Link.ExecContext.
func (tx *TXCore) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) {
return tx.tx.ExecContext(ctx, sql, args...)
}
// PrepareContext implements interface function Link.PrepareContext.
func (tx *TXCore) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) {
return tx.tx.PrepareContext(ctx, sql)
}
// IsOnMaster implements interface function Link.IsOnMaster.
func (tx *TXCore) IsOnMaster() bool {
return true
}
// IsTransaction implements interface function Link.IsTransaction.
func (tx *TXCore) IsTransaction() bool {
return true
}

View File

@ -10,9 +10,10 @@ package gdb
import (
"context"
"database/sql"
"reflect"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"reflect"
"github.com/gogf/gf/v2/util/gconv"
@ -55,7 +56,7 @@ func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...inter
}
// Sql filtering.
sql, args = formatSql(sql, args)
sql, args = c.FormatSqlBeforeExecuting(sql, args)
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err
@ -116,7 +117,7 @@ func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interf
}
// SQL filtering.
sql, args = formatSql(sql, args)
sql, args = c.FormatSqlBeforeExecuting(sql, args)
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err

View File

@ -373,17 +373,6 @@ func GetPrimaryKeyCondition(primary string, where ...interface{}) (newWhereCondi
return where
}
// formatSql formats the sql string and its arguments before executing.
// The internal handleArguments function might be called twice during the SQL procedure,
// but do not worry about it, it's safe and efficient.
func formatSql(sql string, args []interface{}) (newSql string, newArgs []interface{}) {
// DO NOT do this as there may be multiple lines and comments in the sql.
// sql = gstr.Trim(sql)
// sql = gstr.Replace(sql, "\n", " ")
// sql, _ = gregex.ReplaceString(`\s{2,}`, ` `, sql)
return handleArguments(sql, args)
}
type formatWhereHolderInput struct {
WhereHolder
OmitNil bool