diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index efa029572..af41f6d21 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -64,6 +64,7 @@ type DB interface { // Transaction. Begin() (*TX, error) + Transaction(f func(tx *TX) error) (err error) Insert(table string, data interface{}, batch ...int) (sql.Result, error) InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 891754557..289b2a658 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -310,6 +310,34 @@ func (c *Core) Begin() (*TX, error) { } } +// Transaction wraps the transaction logic using function . +// It rollbacks the transaction and returns the error from function if +// it returns non-nil error. It commits the transaction and returns nil if +// function returns nil. +// +// Note that, you should not Commit or Rollback the transaction in function +// as it is automatically handled by this function. +func (c *Core) Transaction(f func(tx *TX) error) (err error) { + var tx *TX + tx, err = c.DB.Begin() + if err != nil { + return err + } + defer func() { + if err != nil { + if e := tx.Rollback(); e != nil { + err = e + } + } else { + if e := tx.Commit(); e != nil { + err = e + } + } + }() + err = f(tx) + return +} + // Insert does "INSERT INTO ..." statement for the table. // If there's already one unique record of the data in the table, it returns error. // diff --git a/database/gdb/gdb_unit_z_mysql_transaction_test.go b/database/gdb/gdb_unit_z_mysql_transaction_test.go index 28d817fb2..a76cc942a 100644 --- a/database/gdb/gdb_unit_z_mysql_transaction_test.go +++ b/database/gdb/gdb_unit_z_mysql_transaction_test.go @@ -7,7 +7,9 @@ package gdb_test import ( + "errors" "fmt" + "github.com/gogf/gf/database/gdb" "testing" "github.com/gogf/gf/frame/g" @@ -300,7 +302,6 @@ func Test_TX_Replace(t *testing.T) { t.Assert(value.String(), "name_1") } }) - } func Test_TX_Save(t *testing.T) { @@ -713,5 +714,53 @@ func Test_TX_Delete(t *testing.T) { t.AssertNE(n, 0) } }) - +} + +func Test_Transaction(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + err := db.Transaction(func(tx *gdb.TX) error { + if _, err := tx.Replace(table, g.Map{ + "id": 1, + "passport": "USER_1", + "password": "PASS_1", + "nickname": "NAME_1", + "create_time": gtime.Now().String(), + }); err != nil { + t.Error(err) + } + return errors.New("error") + }) + t.AssertNE(err, nil) + + if value, err := db.Table(table).Fields("nickname").Where("id", 1).Value(); err != nil { + gtest.Error(err) + } else { + t.Assert(value.String(), "name_1") + } + }) + + gtest.C(t, func(t *gtest.T) { + err := db.Transaction(func(tx *gdb.TX) error { + if _, err := tx.Replace(table, g.Map{ + "id": 1, + "passport": "USER_1", + "password": "PASS_1", + "nickname": "NAME_1", + "create_time": gtime.Now().String(), + }); err != nil { + t.Error(err) + } + return nil + }) + t.Assert(err, nil) + + if value, err := db.Table(table).Fields("nickname").Where("id", 1).Value(); err != nil { + gtest.Error(err) + } else { + t.Assert(value.String(), "NAME_1") + } + }) }