diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 7b76f5dca..58519e927 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -38,10 +38,6 @@ func (c *Core) Ctx(ctx context.Context) DB { if ctx == nil { return c.db } - // It is already set context in previous chaining operation. - if c.ctx != nil { - return c.db - } ctx = context.WithValue(ctx, ctxStrictKeyName, 1) // It makes a shallow copy of current db and changes its context for next chaining operation. var ( diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index 8e88d0938..036ed3494 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -27,10 +27,15 @@ func (c *Core) Query(sql string, args ...interface{}) (rows *sql.Rows, err error func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) { // Transaction checks. if link == nil { - if link, err = c.SlaveLink(); err != nil { + if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil { + // Firstly, check and retrieve transaction link from context. + link = &txLink{tx.tx} + } else if link, err = c.SlaveLink(); err != nil { + // Or else it creates one from master node. return nil, err } } else if !link.IsTransaction() { + // If current link is not transaction link, it checks and retrieves transaction from context. if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil { link = &txLink{tx.tx} } @@ -84,10 +89,15 @@ func (c *Core) Exec(sql string, args ...interface{}) (result sql.Result, err err func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interface{}) (result sql.Result, err error) { // Transaction checks. if link == nil { - if link, err = c.MasterLink(); err != nil { + if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil { + // Firstly, check and retrieve transaction link from context. + link = &txLink{tx.tx} + } else if link, err = c.MasterLink(); err != nil { + // Or else it creates one from master node. return nil, err } } else if !link.IsTransaction() { + // If current link is not transaction link, it checks and retrieves transaction from context. if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil { link = &txLink{tx.tx} } @@ -170,11 +180,25 @@ func (c *Core) Prepare(sql string, execOnMaster ...bool) (*Stmt, error) { // DoPrepare calls prepare function on given link object and returns the statement object. func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, error) { - if link != nil && !link.IsTransaction() { + // Transaction checks. + if link == nil { + if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil { + // Firstly, check and retrieve transaction link from context. + link = &txLink{tx.tx} + } else { + // Or else it creates one from master node. + var err error + if link, err = c.MasterLink(); err != nil { + return nil, err + } + } + } else if !link.IsTransaction() { + // If current link is not transaction link, it checks and retrieves transaction from context. if tx := TXFromCtx(ctx, c.db.GetGroup()); tx != nil { link = &txLink{tx.tx} } } + if c.GetConfig().PrepareTimeout > 0 { // DO NOT USE cancel function in prepare statement. ctx, _ = context.WithTimeout(ctx, c.GetConfig().PrepareTimeout) diff --git a/database/gdb/gdb_z_mysql_transaction_test.go b/database/gdb/gdb_z_mysql_transaction_test.go index 669055145..d12ff51e4 100644 --- a/database/gdb/gdb_z_mysql_transaction_test.go +++ b/database/gdb/gdb_z_mysql_transaction_test.go @@ -9,6 +9,7 @@ package gdb_test import ( "context" "fmt" + "github.com/gogf/gf/os/gctx" "testing" "github.com/gogf/gf/database/gdb" @@ -1116,3 +1117,34 @@ func Test_Transaction_Nested_SavePoint_RollbackTo(t *testing.T) { t.Assert(all[0]["id"], 1) }) } + +func Test_Transaction_Method(t *testing.T) { + table := createTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + var err error + err = db.Transaction(gctx.New(), func(ctx context.Context, tx *gdb.TX) error { + _, err = db.Model(table).Ctx(ctx).Data(g.Map{ + "id": 1, + "passport": "t1", + "password": "25d55ad283aa400af464c76d713c07ad", + "nickname": "T1", + "create_time": gtime.Now().String(), + }).Insert() + t.AssertNil(err) + + _, err = db.Ctx(ctx).Exec(fmt.Sprintf( + "insert into %s(`passport`,`password`,`nickname`,`create_time`,`id`) "+ + "VALUES('t2','25d55ad283aa400af464c76d713c07ad','T2','2021-08-25 21:53:00',2) ", + table)) + t.AssertNil(err) + return gerror.New("rollback") + }) + t.AssertNE(err, nil) + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, 0) + }) +}