add function Transaction for package gdb

This commit is contained in:
John
2020-04-26 17:47:19 +08:00
parent e01bfa05c3
commit f69da3ace1
3 changed files with 80 additions and 2 deletions

View File

@ -64,6 +64,7 @@ type DB interface {
// Transaction.
Begin() (*TX, error)
Transaction(f func(tx *TX) error) (err error)
Insert(table string, data interface{}, batch ...int) (sql.Result, error)
InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error)

View File

@ -310,6 +310,34 @@ func (c *Core) Begin() (*TX, error) {
}
}
// Transaction wraps the transaction logic using function <f>.
// It rollbacks the transaction and returns the error from function <f> if
// it returns non-nil error. It commits the transaction and returns nil if
// function <f> returns nil.
//
// Note that, you should not Commit or Rollback the transaction in function <f>
// as it is automatically handled by this function.
func (c *Core) Transaction(f func(tx *TX) error) (err error) {
var tx *TX
tx, err = c.DB.Begin()
if err != nil {
return err
}
defer func() {
if err != nil {
if e := tx.Rollback(); e != nil {
err = e
}
} else {
if e := tx.Commit(); e != nil {
err = e
}
}
}()
err = f(tx)
return
}
// Insert does "INSERT INTO ..." statement for the table.
// If there's already one unique record of the data in the table, it returns error.
//

View File

@ -7,7 +7,9 @@
package gdb_test
import (
"errors"
"fmt"
"github.com/gogf/gf/database/gdb"
"testing"
"github.com/gogf/gf/frame/g"
@ -300,7 +302,6 @@ func Test_TX_Replace(t *testing.T) {
t.Assert(value.String(), "name_1")
}
})
}
func Test_TX_Save(t *testing.T) {
@ -713,5 +714,53 @@ func Test_TX_Delete(t *testing.T) {
t.AssertNE(n, 0)
}
})
}
func Test_Transaction(t *testing.T) {
table := createInitTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
err := db.Transaction(func(tx *gdb.TX) error {
if _, err := tx.Replace(table, g.Map{
"id": 1,
"passport": "USER_1",
"password": "PASS_1",
"nickname": "NAME_1",
"create_time": gtime.Now().String(),
}); err != nil {
t.Error(err)
}
return errors.New("error")
})
t.AssertNE(err, nil)
if value, err := db.Table(table).Fields("nickname").Where("id", 1).Value(); err != nil {
gtest.Error(err)
} else {
t.Assert(value.String(), "name_1")
}
})
gtest.C(t, func(t *gtest.T) {
err := db.Transaction(func(tx *gdb.TX) error {
if _, err := tx.Replace(table, g.Map{
"id": 1,
"passport": "USER_1",
"password": "PASS_1",
"nickname": "NAME_1",
"create_time": gtime.Now().String(),
}); err != nil {
t.Error(err)
}
return nil
})
t.Assert(err, nil)
if value, err := db.Table(table).Fields("nickname").Where("id", 1).Value(); err != nil {
gtest.Error(err)
} else {
t.Assert(value.String(), "NAME_1")
}
})
}