mirror of
https://gitee.com/johng/gf
synced 2026-06-06 16:21:40 +08:00
fix: psgql tx unsupport LastInsertId (#2815)
This commit is contained in:
@ -35,8 +35,8 @@ type Driver struct {
|
||||
|
||||
const (
|
||||
internalPrimaryKeyInCtx gctx.StrKey = "primary_key"
|
||||
defaultSchema = "public"
|
||||
quoteChar = `"`
|
||||
defaultSchema string = "public"
|
||||
quoteChar string = `"`
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -372,14 +372,22 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...
|
||||
)
|
||||
|
||||
// Transaction checks.
|
||||
if link != nil && link.IsTransaction() {
|
||||
isUseCoreDoExec = true
|
||||
} else {
|
||||
if link == nil {
|
||||
if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil {
|
||||
isUseCoreDoExec = true
|
||||
// Firstly, check and retrieve transaction link from context.
|
||||
link = tx
|
||||
} else if link, err = d.MasterLink(); err != nil {
|
||||
// Or else it creates one from master node.
|
||||
return nil, err
|
||||
}
|
||||
} else if !link.IsTransaction() {
|
||||
// If current link is not transaction link, it checks and retrieves transaction from context.
|
||||
if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil {
|
||||
link = tx
|
||||
}
|
||||
}
|
||||
|
||||
// Check if it is an insert operation with primary key.
|
||||
if value := ctx.Value(internalPrimaryKeyInCtx); value != nil {
|
||||
var ok bool
|
||||
pkField, ok = value.(gdb.TableField)
|
||||
@ -408,8 +416,7 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...
|
||||
}
|
||||
|
||||
// Sql filtering.
|
||||
// TODO: internal function formatSql
|
||||
// sql, args = formatSql(sql, args)
|
||||
sql, args = d.FormatSqlBeforeExecuting(sql, args)
|
||||
sql, args, err = d.DoFilter(ctx, link, sql, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -442,10 +449,10 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...
|
||||
}
|
||||
|
||||
if out.Records[affected-1][primaryKey] != nil {
|
||||
lastInsertId := out.Records[affected-1][primaryKey].Int()
|
||||
lastInsertId := out.Records[affected-1][primaryKey].Int64()
|
||||
return Result{
|
||||
affected: int64(affected),
|
||||
lastInsertId: int64(lastInsertId),
|
||||
lastInsertId: lastInsertId,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,8 +7,10 @@
|
||||
package pgsql_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gctx"
|
||||
"github.com/gogf/gf/v2/test/gtest"
|
||||
@ -45,6 +47,43 @@ func Test_LastInsertId(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_TxLastInsertId(t *testing.T) {
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
tableName := createTable()
|
||||
defer dropTable(tableName)
|
||||
err := db.Transaction(context.TODO(), func(ctx context.Context, tx gdb.TX) error {
|
||||
// user
|
||||
res, err := tx.Model(tableName).Insert(g.List{
|
||||
{"passport": "user1", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
|
||||
{"passport": "user2", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
|
||||
{"passport": "user3", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
|
||||
})
|
||||
t.Assert(err, nil)
|
||||
lastInsertId, err := res.LastInsertId()
|
||||
t.Assert(err, nil)
|
||||
t.AssertEQ(lastInsertId, int64(3))
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
t.Assert(err, nil)
|
||||
t.AssertEQ(rowsAffected, int64(3))
|
||||
|
||||
res1, err := tx.Model(tableName).Insert(g.List{
|
||||
{"passport": "user4", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
|
||||
{"passport": "user5", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
|
||||
})
|
||||
t.Assert(err, nil)
|
||||
lastInsertId1, err := res1.LastInsertId()
|
||||
t.Assert(err, nil)
|
||||
t.AssertEQ(lastInsertId1, int64(5))
|
||||
rowsAffected1, err := res1.RowsAffected()
|
||||
t.Assert(err, nil)
|
||||
t.AssertEQ(rowsAffected1, int64(2))
|
||||
return nil
|
||||
|
||||
})
|
||||
t.Assert(err, nil)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Driver_DoFilter(t *testing.T) {
|
||||
var (
|
||||
ctx = gctx.New()
|
||||
|
||||
@ -179,6 +179,8 @@ type DB interface {
|
||||
|
||||
// TX defines the interfaces for ORM transaction operations.
|
||||
type TX interface {
|
||||
Link
|
||||
|
||||
Ctx(ctx context.Context) TX
|
||||
Raw(rawSql string, args ...interface{}) *Model
|
||||
Model(tableNameQueryOrStruct ...interface{}) *Model
|
||||
|
||||
@ -796,3 +796,14 @@ func (c *Core) isSoftCreatedFieldName(fieldName string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FormatSqlBeforeExecuting formats the sql string and its arguments before executing.
|
||||
// The internal handleArguments function might be called twice during the SQL procedure,
|
||||
// but do not worry about it, it's safe and efficient.
|
||||
func (c *Core) FormatSqlBeforeExecuting(sql string, args []interface{}) (newSql string, newArgs []interface{}) {
|
||||
// DO NOT do this as there may be multiple lines and comments in the sql.
|
||||
// sql = gstr.Trim(sql)
|
||||
// sql = gstr.Replace(sql, "\n", " ")
|
||||
// sql, _ = gregex.ReplaceString(`\s{2,}`, ` `, sql)
|
||||
return handleArguments(sql, args)
|
||||
}
|
||||
|
||||
@ -517,3 +517,28 @@ func (tx *TXCore) Update(table string, data interface{}, condition interface{},
|
||||
func (tx *TXCore) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
return tx.Model(table).Ctx(tx.ctx).Where(condition, args...).Delete()
|
||||
}
|
||||
|
||||
// QueryContext implements interface function Link.QueryContext.
|
||||
func (tx *TXCore) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) {
|
||||
return tx.tx.QueryContext(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// ExecContext implements interface function Link.ExecContext.
|
||||
func (tx *TXCore) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) {
|
||||
return tx.tx.ExecContext(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// PrepareContext implements interface function Link.PrepareContext.
|
||||
func (tx *TXCore) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) {
|
||||
return tx.tx.PrepareContext(ctx, sql)
|
||||
}
|
||||
|
||||
// IsOnMaster implements interface function Link.IsOnMaster.
|
||||
func (tx *TXCore) IsOnMaster() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsTransaction implements interface function Link.IsTransaction.
|
||||
func (tx *TXCore) IsTransaction() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
@ -10,9 +10,10 @@ package gdb
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"reflect"
|
||||
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
|
||||
@ -55,7 +56,7 @@ func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...inter
|
||||
}
|
||||
|
||||
// Sql filtering.
|
||||
sql, args = formatSql(sql, args)
|
||||
sql, args = c.FormatSqlBeforeExecuting(sql, args)
|
||||
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -116,7 +117,7 @@ func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interf
|
||||
}
|
||||
|
||||
// SQL filtering.
|
||||
sql, args = formatSql(sql, args)
|
||||
sql, args = c.FormatSqlBeforeExecuting(sql, args)
|
||||
sql, args, err = c.db.DoFilter(ctx, link, sql, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -373,17 +373,6 @@ func GetPrimaryKeyCondition(primary string, where ...interface{}) (newWhereCondi
|
||||
return where
|
||||
}
|
||||
|
||||
// formatSql formats the sql string and its arguments before executing.
|
||||
// The internal handleArguments function might be called twice during the SQL procedure,
|
||||
// but do not worry about it, it's safe and efficient.
|
||||
func formatSql(sql string, args []interface{}) (newSql string, newArgs []interface{}) {
|
||||
// DO NOT do this as there may be multiple lines and comments in the sql.
|
||||
// sql = gstr.Trim(sql)
|
||||
// sql = gstr.Replace(sql, "\n", " ")
|
||||
// sql, _ = gregex.ReplaceString(`\s{2,}`, ` `, sql)
|
||||
return handleArguments(sql, args)
|
||||
}
|
||||
|
||||
type formatWhereHolderInput struct {
|
||||
WhereHolder
|
||||
OmitNil bool
|
||||
|
||||
Reference in New Issue
Block a user