improve context and nested transaction feature for package gdb

This commit is contained in:
John Guo
2021-05-21 13:25:53 +08:00
parent 017c6e4e1f
commit 4e41d8aff8
28 changed files with 472 additions and 394 deletions

View File

@ -96,19 +96,6 @@ type DB interface {
Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) // See Core.Update.
Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) // See Core.Delete.
// ===========================================================================
// Internal APIs for CURD, which can be overwrote for custom CURD implements.
// ===========================================================================
DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) // See Core.DoQuery.
DoGetAll(link Link, sql string, args ...interface{}) (result Result, err error) // See Core.DoGetAll.
DoExec(link Link, sql string, args ...interface{}) (result sql.Result, err error) // See Core.DoExec.
DoPrepare(link Link, sql string) (*Stmt, error) // See Core.DoPrepare.
DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) // See Core.DoInsert.
DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) // See Core.DoBatchInsert.
DoUpdate(link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) // See Core.DoUpdate.
DoDelete(link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) // See Core.DoDelete.
// ===========================================================================
// Query APIs for convenience purpose.
// ===========================================================================
@ -167,40 +154,25 @@ type DB interface {
// Utility methods.
// ===========================================================================
GetCtx() context.Context // See Core.GetCtx.
GetChars() (charLeft string, charRight string) // See Core.GetChars.
GetMaster(schema ...string) (*sql.DB, error) // See Core.GetMaster.
GetSlave(schema ...string) (*sql.DB, error) // See Core.GetSlave.
QuoteWord(s string) string // See Core.QuoteWord.
QuoteString(s string) string // See Core.QuoteString.
QuotePrefixTableName(table string) string // See Core.QuotePrefixTableName.
Tables(schema ...string) (tables []string, err error) // See Core.Tables.
TableFields(link Link, table string, schema ...string) (map[string]*TableField, error) // See Core.TableFields.
HasTable(name string) (bool, error) // See Core.HasTable.
FilteredLinkInfo() string // See Core.FilteredLinkInfo.
GetCtx() context.Context // See Core.GetCtx.
GetCore() *Core // See Core.GetCore
GetChars() (charLeft string, charRight string) // See Core.GetChars.
Tables(ctx context.Context, schema ...string) (tables []string, err error) // See Core.Tables.
TableFields(ctx context.Context, link Link, table string, schema ...string) (map[string]*TableField, error) // See Core.TableFields.
FilteredLinkInfo() string // See Core.FilteredLinkInfo.
// HandleSqlBeforeCommit is a hook function, which deals with the sql string before
// it's committed to underlying driver. The parameter `link` specifies the current
// database connection operation object. You can modify the sql string `sql` and its
// arguments `args` as you wish before they're committed to driver.
// Also see Core.HandleSqlBeforeCommit.
HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{})
// ===========================================================================
// Internal methods, for internal usage purpose, you do not need consider it.
// ===========================================================================
mappingAndFilterData(schema, table string, data map[string]interface{}, filter bool) (map[string]interface{}, error) // See Core.mappingAndFilterData.
convertFieldValueToLocalValue(fieldValue interface{}, fieldType string) interface{} // See Core.convertFieldValueToLocalValue.
convertRowsToResult(rows *sql.Rows) (Result, error) // See Core.convertRowsToResult.
addSqlToTracing(ctx context.Context, sql *Sql) // See Core.addSqlToTracing.
writeSqlToLogger(v *Sql) // See Core.writeSqlToLogger.
HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{})
}
// Core is the base struct for database management.
type Core struct {
db DB // DB interface object.
ctx context.Context // Context for chaining operation only.
ctx context.Context // Context for chaining operation only. Do not set a default value in Core initialization.
group string // Configuration group name.
debug *gtype.Bool // Enable debug mode for the database, which can be changed in runtime.
cache *gcache.Cache // Cache manager, SQL result cache only.

View File

@ -15,9 +15,8 @@ import (
"strings"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/internal/utils"
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/container/gvar"
"github.com/gogf/gf/os/gtime"
@ -25,6 +24,10 @@ import (
"github.com/gogf/gf/util/gconv"
)
func (c *Core) GetCore() *Core {
return c
}
// Ctx is a chaining function, which creates and returns a new DB that is a shallow copy
// of current DB object and with given context in it.
// Note that this returned DB object can be used only once, so do not assign it to
@ -33,6 +36,11 @@ func (c *Core) Ctx(ctx context.Context) DB {
if ctx == nil {
return c.db
}
// It is already set context in previous chaining operation.
if c.ctx != nil {
return c.db
}
// It makes a shallow copy of current db and changes its context for next chaining operation.
var (
err error
newCore = &Core{}
@ -41,9 +49,10 @@ func (c *Core) Ctx(ctx context.Context) DB {
*newCore = *c
newCore.ctx = ctx
newCore.db, err = driverMap[configNode.Type].New(newCore, configNode)
// Seldom error, just log it.
if err != nil {
c.db.GetLogger().Ctx(ctx).Error(err)
// It is really a serious error here.
// Do not let it continue.
panic(err)
}
return newCore.db
}
@ -60,7 +69,7 @@ func (c *Core) GetCtx() context.Context {
// GetCtxTimeout returns the context and cancel function for specified timeout type.
func (c *Core) GetCtxTimeout(timeoutType int, ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil {
ctx = c.db.GetCtx()
ctx = c.GetCtx()
} else {
ctx = context.WithValue(ctx, "WrappedByGetCtxTimeout", nil)
}
@ -102,15 +111,14 @@ func (c *Core) Query(sql string, args ...interface{}) (rows *sql.Rows, err error
if err != nil {
return nil, err
}
return c.db.DoQuery(link, sql, args...)
return c.DoQuery(c.GetCtx(), link, sql, args...)
}
// DoQuery commits the sql string and its arguments to underlying driver
// through given link object and returns the execution result.
func (c *Core) DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) {
func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) {
sql, args = formatSql(sql, args)
sql, args = c.db.HandleSqlBeforeCommit(link, sql, args)
ctx := c.db.GetCtx()
sql, args = c.db.HandleSqlBeforeCommit(ctx, link, sql, args)
if c.GetConfig().QueryTimeout > 0 {
ctx, _ = context.WithTimeout(ctx, c.GetConfig().QueryTimeout)
}
@ -127,9 +135,10 @@ func (c *Core) DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Ro
End: mTime2,
Group: c.db.GetGroup(),
}
// Tracing and logging.
c.addSqlToTracing(ctx, sqlObj)
if c.db.GetDebug() {
c.writeSqlToLogger(sqlObj)
c.writeSqlToLogger(ctx, sqlObj)
}
if err == nil {
return rows, nil
@ -146,15 +155,14 @@ func (c *Core) Exec(sql string, args ...interface{}) (result sql.Result, err err
if err != nil {
return nil, err
}
return c.db.DoExec(link, sql, args...)
return c.DoExec(c.GetCtx(), link, sql, args...)
}
// DoExec commits the sql string and its arguments to underlying driver
// through given link object and returns the execution result.
func (c *Core) DoExec(link Link, sql string, args ...interface{}) (result sql.Result, err error) {
func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interface{}) (result sql.Result, err error) {
sql, args = formatSql(sql, args)
sql, args = c.db.HandleSqlBeforeCommit(link, sql, args)
ctx := c.db.GetCtx()
sql, args = c.db.HandleSqlBeforeCommit(ctx, link, sql, args)
if c.GetConfig().ExecTimeout > 0 {
var cancelFunc context.CancelFunc
ctx, cancelFunc = context.WithTimeout(ctx, c.GetConfig().ExecTimeout)
@ -178,9 +186,10 @@ func (c *Core) DoExec(link Link, sql string, args ...interface{}) (result sql.Re
End: mTime2,
Group: c.db.GetGroup(),
}
// Tracing and logging.
c.addSqlToTracing(ctx, sqlObj)
if c.db.GetDebug() {
c.writeSqlToLogger(sqlObj)
c.writeSqlToLogger(ctx, sqlObj)
}
return result, formatError(err, sql, args...)
}
@ -207,12 +216,11 @@ func (c *Core) Prepare(sql string, execOnMaster ...bool) (*Stmt, error) {
return nil, err
}
}
return c.db.DoPrepare(link, sql)
return c.DoPrepare(c.GetCtx(), link, sql)
}
// doPrepare calls prepare function on given link object and returns the statement object.
func (c *Core) DoPrepare(link Link, sql string) (*Stmt, error) {
ctx := c.db.GetCtx()
// DoPrepare calls prepare function on given link object and returns the statement object.
func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, error) {
if c.GetConfig().PrepareTimeout > 0 {
// DO NOT USE cancel function in prepare statement.
ctx, _ = context.WithTimeout(ctx, c.GetConfig().PrepareTimeout)
@ -232,9 +240,10 @@ func (c *Core) DoPrepare(link Link, sql string) (*Stmt, error) {
Group: c.db.GetGroup(),
}
)
// Tracing and logging.
c.addSqlToTracing(ctx, sqlObj)
if c.db.GetDebug() {
c.writeSqlToLogger(sqlObj)
c.writeSqlToLogger(ctx, sqlObj)
}
return &Stmt{
Stmt: stmt,
@ -245,23 +254,23 @@ func (c *Core) DoPrepare(link Link, sql string) (*Stmt, error) {
// GetAll queries and returns data records from database.
func (c *Core) GetAll(sql string, args ...interface{}) (Result, error) {
return c.db.DoGetAll(nil, sql, args...)
return c.DoGetAll(c.GetCtx(), nil, sql, args...)
}
// DoGetAll queries and returns data records from database.
func (c *Core) DoGetAll(link Link, sql string, args ...interface{}) (result Result, err error) {
func (c *Core) DoGetAll(ctx context.Context, link Link, sql string, args ...interface{}) (result Result, err error) {
if link == nil {
link, err = c.db.Slave()
if err != nil {
return nil, err
}
}
rows, err := c.db.DoQuery(link, sql, args...)
rows, err := c.DoQuery(ctx, link, sql, args...)
if err != nil || rows == nil {
return nil, err
}
defer rows.Close()
return c.db.convertRowsToResult(rows)
return c.convertRowsToResult(rows)
}
// GetOne queries and returns one record from database.
@ -279,7 +288,7 @@ func (c *Core) GetOne(sql string, args ...interface{}) (Record, error) {
// GetArray queries and returns data values as slice from database.
// Note that if there are multiple columns in the result, it returns just one column values randomly.
func (c *Core) GetArray(sql string, args ...interface{}) ([]Value, error) {
all, err := c.db.DoGetAll(nil, sql, args...)
all, err := c.DoGetAll(c.GetCtx(), nil, sql, args...)
if err != nil {
return nil, err
}
@ -374,91 +383,6 @@ func (c *Core) PingSlave() error {
}
}
// Begin starts and returns the transaction object.
// You should call Commit or Rollback functions of the transaction object
// if you no longer use the transaction. Commit or Rollback functions will also
// close the transaction automatically.
func (c *Core) Begin() (*TX, error) {
if master, err := c.db.Master(); err != nil {
return nil, err
} else {
//ctx := c.db.GetCtx()
//if c.GetConfig().TranTimeout > 0 {
// var cancelFunc context.CancelFunc
// ctx, cancelFunc = context.WithTimeout(ctx, c.GetConfig().TranTimeout)
// defer cancelFunc()
//}
var (
sqlStr = "BEGIN"
mTime1 = gtime.TimestampMilli()
rawTx, err = master.Begin()
mTime2 = gtime.TimestampMilli()
sqlObj = &Sql{
Sql: sqlStr,
Type: "DB.Begin",
Args: nil,
Format: sqlStr,
Error: err,
Start: mTime1,
End: mTime2,
Group: c.db.GetGroup(),
}
)
c.db.addSqlToTracing(c.db.GetCtx(), sqlObj)
if c.db.GetDebug() {
c.db.writeSqlToLogger(sqlObj)
}
if err == nil {
return &TX{
db: c.db,
tx: rawTx,
master: master,
}, nil
}
return nil, err
}
}
// Transaction wraps the transaction logic using function `f`.
// It rollbacks the transaction and returns the error from function `f` if
// it returns non-nil error. It commits the transaction and returns nil if
// function `f` returns nil.
//
// Note that, you should not Commit or Rollback the transaction in function `f`
// as it is automatically handled by this function.
func (c *Core) Transaction(ctx context.Context, f func(ctx context.Context, tx *TX) error) (err error) {
var tx *TX
// Check transaction object from context.
tx = TXFromCtx(ctx)
if tx != nil {
return tx.Transaction(ctx, f)
}
tx, err = c.db.Begin()
if err != nil {
return err
}
// Inject transaction object into context.
ctx = WithTX(ctx, tx)
defer func() {
if err == nil {
if e := recover(); e != nil {
err = fmt.Errorf("%v", e)
}
}
if err != nil {
if e := tx.Rollback(); e != nil {
err = e
}
} else {
if e := tx.Commit(); e != nil {
err = e
}
}
}()
err = f(ctx, tx)
return
}
// Insert does "INSERT INTO ..." statement for the table.
// If there's already one unique record of the data in the table, it returns error.
//
@ -536,7 +460,7 @@ func (c *Core) Save(table string, data interface{}, batch ...int) (sql.Result, e
return c.Model(table).Data(data).Save()
}
// doInsert inserts or updates data for given table.
// DoInsert inserts or updates data for given table.
// This function is usually used for custom interface definition, you do not need call it manually.
// The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc.
// Eg:
@ -548,8 +472,8 @@ func (c *Core) Save(table string, data interface{}, batch ...int) (sql.Result, e
// 1: replace: if there's unique/primary key in the data, it deletes it from table and inserts a new one;
// 2: save: if there's unique/primary key in the data, it updates it or else inserts a new one;
// 3: ignore: if there's unique/primary key in the data, it ignores the inserting;
func (c *Core) DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) {
table = c.db.QuotePrefixTableName(table)
func (c *Core) DoInsert(ctx context.Context, link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) {
table = c.QuotePrefixTableName(table)
var (
fields []string
values []string
@ -564,10 +488,10 @@ func (c *Core) DoInsert(link Link, table string, data interface{}, option int, b
}
switch reflectKind {
case reflect.Slice, reflect.Array:
return c.db.DoBatchInsert(link, table, data, option, batch...)
return c.DoBatchInsert(ctx, link, table, data, option, batch...)
case reflect.Struct:
if _, ok := data.(apiInterfaces); ok {
return c.db.DoBatchInsert(link, table, data, option, batch...)
return c.DoBatchInsert(ctx, link, table, data, option, batch...)
} else {
dataMap = ConvertDataForTableRecord(data)
}
@ -616,15 +540,11 @@ func (c *Core) DoInsert(link Link, table string, data interface{}, option int, b
return nil, err
}
}
return c.db.DoExec(
link,
fmt.Sprintf(
"%s INTO %s(%s) VALUES(%s) %s",
operation, table, strings.Join(fields, ","),
strings.Join(values, ","), updateStr,
),
params...,
)
return c.DoExec(ctx, link, fmt.Sprintf(
"%s INTO %s(%s) VALUES(%s) %s",
operation, table, strings.Join(fields, ","),
strings.Join(values, ","), updateStr,
), params...)
}
// BatchInsert batch inserts data.
@ -665,8 +585,8 @@ func (c *Core) BatchSave(table string, list interface{}, batch ...int) (sql.Resu
// DoBatchInsert batch inserts/replaces/saves data.
// This function is usually used for custom interface definition, you do not need call it manually.
func (c *Core) DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) {
table = c.db.QuotePrefixTableName(table)
func (c *Core) DoBatchInsert(ctx context.Context, link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) {
table = c.QuotePrefixTableName(table)
var (
keys []string // Field names.
values []string // Value holder string array, like: (?,?,?)
@ -777,16 +697,12 @@ func (c *Core) DoBatchInsert(link Link, table string, list interface{}, option i
}
valueHolder = append(valueHolder, "("+gstr.Join(values, ",")+")")
if len(valueHolder) == batchNum || (i == listMapLen-1 && len(valueHolder) > 0) {
r, err := c.db.DoExec(
link,
fmt.Sprintf(
"%s INTO %s(%s) VALUES%s %s",
operation, table, keysStr,
gstr.Join(valueHolder, ","),
updateStr,
),
params...,
)
r, err := c.DoExec(ctx, link, fmt.Sprintf(
"%s INTO %s(%s) VALUES%s %s",
operation, table, keysStr,
gstr.Join(valueHolder, ","),
updateStr,
), params...)
if err != nil {
return r, err
}
@ -821,10 +737,10 @@ func (c *Core) Update(table string, data interface{}, condition interface{}, arg
return c.Model(table).Data(data).Where(condition, args...).Update()
}
// doUpdate does "UPDATE ... " statement for the table.
// DoUpdate does "UPDATE ... " statement for the table.
// This function is usually used for custom interface definition, you do not need call it manually.
func (c *Core) DoUpdate(link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) {
table = c.db.QuotePrefixTableName(table)
func (c *Core) DoUpdate(ctx context.Context, link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) {
table = c.QuotePrefixTableName(table)
var (
rv = reflect.ValueOf(data)
kind = rv.Kind()
@ -849,7 +765,7 @@ func (c *Core) DoUpdate(link Link, table string, data interface{}, condition str
if value.Value != 0 {
column := k
if value.Field != "" {
column = c.db.QuoteWord(value.Field)
column = c.QuoteWord(value.Field)
}
fields = append(fields, fmt.Sprintf("%s=%s+?", column, column))
params = append(params, value.Value)
@ -858,16 +774,16 @@ func (c *Core) DoUpdate(link Link, table string, data interface{}, condition str
if value.Value != 0 {
column := k
if value.Field != "" {
column = c.db.QuoteWord(value.Field)
column = c.QuoteWord(value.Field)
}
fields = append(fields, fmt.Sprintf("%s=%s+?", column, column))
params = append(params, value.Value)
}
default:
if s, ok := v.(Raw); ok {
fields = append(fields, c.db.QuoteWord(k)+"="+gconv.String(s))
fields = append(fields, c.QuoteWord(k)+"="+gconv.String(s))
} else {
fields = append(fields, c.db.QuoteWord(k)+"=?")
fields = append(fields, c.QuoteWord(k)+"=?")
params = append(params, v)
}
}
@ -888,11 +804,7 @@ func (c *Core) DoUpdate(link Link, table string, data interface{}, condition str
return nil, err
}
}
return c.db.DoExec(
link,
fmt.Sprintf("UPDATE %s SET %s%s", table, updates, condition),
args...,
)
return c.DoExec(ctx, link, fmt.Sprintf("UPDATE %s SET %s%s", table, updates, condition), args...)
}
// Delete does "DELETE FROM ... " statement for the table.
@ -912,14 +824,14 @@ func (c *Core) Delete(table string, condition interface{}, args ...interface{})
// DoDelete does "DELETE FROM ... " statement for the table.
// This function is usually used for custom interface definition, you do not need call it manually.
func (c *Core) DoDelete(link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) {
func (c *Core) DoDelete(ctx context.Context, link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) {
if link == nil {
if link, err = c.db.Master(); err != nil {
return nil, err
}
}
table = c.db.QuotePrefixTableName(table)
return c.db.DoExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...)
table = c.QuotePrefixTableName(table)
return c.DoExec(ctx, link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...)
}
// convertRowsToResult converts underlying data record type sql.Rows to Result type.
@ -955,7 +867,7 @@ func (c *Core) convertRowsToResult(rows *sql.Rows) (Result, error) {
if value == nil {
row[columnNames[i]] = gvar.New(nil)
} else {
row[columnNames[i]] = gvar.New(c.db.convertFieldValueToLocalValue(value, columnTypes[i]))
row[columnNames[i]] = gvar.New(c.convertFieldValueToLocalValue(value, columnTypes[i]))
}
}
records = append(records, row)
@ -977,19 +889,23 @@ func (c *Core) MarshalJSON() ([]byte, error) {
// writeSqlToLogger outputs the sql object to logger.
// It is enabled only if configuration "debug" is true.
func (c *Core) writeSqlToLogger(v *Sql) {
s := fmt.Sprintf("[%3d ms] [%s] %s", v.End-v.Start, v.Group, v.Format)
if v.Error != nil {
s += "\nError: " + v.Error.Error()
c.logger.Ctx(c.db.GetCtx()).Error(s)
func (c *Core) writeSqlToLogger(ctx context.Context, sql *Sql) {
var transactionIdStr string
if v := ctx.Value(transactionIdForLoggerCtx); v != nil {
transactionIdStr = fmt.Sprintf(`[%d] `, v.(uint64))
}
s := fmt.Sprintf("[%3d ms] [%s] %s%s", sql.End-sql.Start, sql.Group, transactionIdStr, sql.Format)
if sql.Error != nil {
s += "\nError: " + sql.Error.Error()
c.logger.Ctx(ctx).Error(s)
} else {
c.logger.Ctx(c.db.GetCtx()).Debug(s)
c.logger.Ctx(ctx).Debug(s)
}
}
// HasTable determine whether the table name exists in the database.
func (c *Core) HasTable(name string) (bool, error) {
tableList, err := c.db.Tables()
tableList, err := c.db.Tables(c.GetCtx())
if err != nil {
return false, err
}

View File

@ -8,15 +8,12 @@ package gdb
import (
"fmt"
"github.com/gogf/gf/os/gcache"
"sync"
"time"
"github.com/gogf/gf/os/glog"
)
"github.com/gogf/gf/os/gcache"
const (
DefaultGroupName = "default" // Default group name.
"github.com/gogf/gf/os/glog"
)
// Config is the configuration management object.
@ -53,6 +50,10 @@ type ConfigNode struct {
TimeMaintainDisabled bool `json:"timeMaintainDisabled"` // (Optional) Disable the automatic time maintaining feature.
}
const (
DefaultGroupName = "default" // Default group name.
)
// configs is internal used configuration object.
var configs struct {
sync.RWMutex

View File

@ -7,10 +7,11 @@
package gdb
import (
"github.com/gogf/gf/util/gutil"
"strings"
"time"
"github.com/gogf/gf/util/gutil"
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/os/gtime"
@ -149,7 +150,7 @@ func (c *Core) convertFieldValueToLocalValue(fieldValue interface{}, fieldType s
// mappingAndFilterData automatically mappings the map key to table field and removes
// all key-value pairs that are not the field of given table.
func (c *Core) mappingAndFilterData(schema, table string, data map[string]interface{}, filter bool) (map[string]interface{}, error) {
if fieldsMap, err := c.db.TableFields(nil, table, schema); err == nil {
if fieldsMap, err := c.db.TableFields(c.GetCtx(), nil, table, schema); err == nil {
fieldsKeyMap := make(map[string]interface{}, len(fieldsMap))
for k, _ := range fieldsMap {
fieldsKeyMap[k] = nil

View File

@ -12,12 +12,14 @@
package gdb
import (
"context"
"database/sql"
"fmt"
"github.com/gogf/gf/errors/gerror"
"strconv"
"strings"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/internal/intlog"
"github.com/gogf/gf/text/gstr"
@ -77,7 +79,7 @@ func (d *DriverMssql) GetChars() (charLeft string, charRight string) {
}
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverMssql) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) {
func (d *DriverMssql) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) {
var index int
// Convert place holder char '?' to string "@px".
str, _ := gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
@ -183,14 +185,14 @@ func (d *DriverMssql) parseSql(sql string) string {
// Tables retrieves and returns the tables of current schema.
// It's mainly used in cli tool chain for automatically generating the models.
func (d *DriverMssql) Tables(schema ...string) (tables []string, err error) {
func (d *DriverMssql) Tables(ctx context.Context, schema ...string) (tables []string, err error) {
var result Result
link, err := d.db.GetSlave(schema...)
link, err := d.GetSlave(schema...)
if err != nil {
return nil, err
}
result, err = d.db.DoGetAll(link, `SELECT NAME FROM SYSOBJECTS WHERE XTYPE='U' AND STATUS >= 0 ORDER BY NAME`)
result, err = d.DoGetAll(ctx, link, `SELECT NAME FROM SYSOBJECTS WHERE XTYPE='U' AND STATUS >= 0 ORDER BY NAME`)
if err != nil {
return
}
@ -205,7 +207,7 @@ func (d *DriverMssql) Tables(schema ...string) (tables []string, err error) {
// TableFields retrieves and returns the fields information of specified table of current schema.
//
// Also see DriverMysql.TableFields.
func (d *DriverMssql) TableFields(link Link, table string, schema ...string) (fields map[string]*TableField, err error) {
func (d *DriverMssql) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) {
charL, charR := d.GetChars()
table = gstr.Trim(table, charL+charR)
if gstr.Contains(table, " ") {
@ -224,7 +226,7 @@ func (d *DriverMssql) TableFields(link Link, table string, schema ...string) (fi
result Result
)
if link == nil {
link, err = d.db.GetSlave(checkSchema)
link, err = d.GetSlave(checkSchema)
if err != nil {
return nil
}
@ -260,7 +262,7 @@ ORDER BY a.id,a.colorder`,
strings.ToUpper(table),
)
structureSql, _ = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(structureSql))
result, err = d.db.DoGetAll(link, structureSql)
result, err = d.DoGetAll(ctx, link, structureSql)
if err != nil {
return nil
}

View File

@ -7,8 +7,10 @@
package gdb
import (
"context"
"database/sql"
"fmt"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/internal/intlog"
"github.com/gogf/gf/text/gregex"
@ -75,19 +77,19 @@ func (d *DriverMysql) GetChars() (charLeft string, charRight string) {
}
// HandleSqlBeforeCommit handles the sql before posts it to database.
func (d *DriverMysql) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) {
func (d *DriverMysql) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) {
return sql, args
}
// Tables retrieves and returns the tables of current schema.
// It's mainly used in cli tool chain for automatically generating the models.
func (d *DriverMysql) Tables(schema ...string) (tables []string, err error) {
func (d *DriverMysql) Tables(ctx context.Context, schema ...string) (tables []string, err error) {
var result Result
link, err := d.db.GetSlave(schema...)
link, err := d.GetSlave(schema...)
if err != nil {
return nil, err
}
result, err = d.db.DoGetAll(link, `SHOW TABLES`)
result, err = d.DoGetAll(ctx, link, `SHOW TABLES`)
if err != nil {
return
}
@ -111,7 +113,7 @@ func (d *DriverMysql) Tables(schema ...string) (tables []string, err error) {
//
// It's using cache feature to enhance the performance, which is never expired util the
// process restarts.
func (d *DriverMysql) TableFields(link Link, table string, schema ...string) (fields map[string]*TableField, err error) {
func (d *DriverMysql) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) {
charL, charR := d.GetChars()
table = gstr.Trim(table, charL+charR)
if gstr.Contains(table, " ") {
@ -130,14 +132,13 @@ func (d *DriverMysql) TableFields(link Link, table string, schema ...string) (fi
result Result
)
if link == nil {
link, err = d.db.GetSlave(checkSchema)
link, err = d.GetSlave(checkSchema)
if err != nil {
return nil
}
}
result, err = d.db.DoGetAll(
link,
fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.db.QuoteWord(table)),
result, err = d.DoGetAll(ctx, link,
fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.QuoteWord(table)),
)
if err != nil {
return nil

View File

@ -12,17 +12,19 @@
package gdb
import (
"context"
"database/sql"
"fmt"
"reflect"
"strconv"
"strings"
"time"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/internal/intlog"
"github.com/gogf/gf/text/gregex"
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/util/gconv"
"reflect"
"strconv"
"strings"
"time"
)
// DriverOracle is the driver for oracle database.
@ -83,7 +85,7 @@ func (d *DriverOracle) GetChars() (charLeft string, charRight string) {
}
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverOracle) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}) {
func (d *DriverOracle) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}) {
var index int
// Convert place holder char '?' to string ":vx".
newSql, _ = gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
@ -164,9 +166,9 @@ func (d *DriverOracle) parseSql(sql string) string {
// Tables retrieves and returns the tables of current schema.
// It's mainly used in cli tool chain for automatically generating the models.
// Note that it ignores the parameter `schema` in oracle database, as it is not necessary.
func (d *DriverOracle) Tables(schema ...string) (tables []string, err error) {
func (d *DriverOracle) Tables(ctx context.Context, schema ...string) (tables []string, err error) {
var result Result
result, err = d.db.DoGetAll(nil, "SELECT TABLE_NAME FROM USER_TABLES ORDER BY TABLE_NAME")
result, err = d.DoGetAll(ctx, nil, "SELECT TABLE_NAME FROM USER_TABLES ORDER BY TABLE_NAME")
if err != nil {
return
}
@ -181,7 +183,7 @@ func (d *DriverOracle) Tables(schema ...string) (tables []string, err error) {
// TableFields retrieves and returns the fields information of specified table of current schema.
//
// Also see DriverMysql.TableFields.
func (d *DriverOracle) TableFields(link Link, table string, schema ...string) (fields map[string]*TableField, err error) {
func (d *DriverOracle) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) {
charL, charR := d.GetChars()
table = gstr.Trim(table, charL+charR)
if gstr.Contains(table, " ") {
@ -211,12 +213,12 @@ FROM USER_TAB_COLUMNS WHERE TABLE_NAME = '%s' ORDER BY COLUMN_ID`,
)
structureSql, _ = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(structureSql))
if link == nil {
link, err = d.db.GetSlave(checkSchema)
link, err = d.GetSlave(checkSchema)
if err != nil {
return nil
}
}
result, err = d.db.DoGetAll(link, structureSql)
result, err = d.DoGetAll(ctx, link, structureSql)
if err != nil {
return nil
}
@ -264,7 +266,7 @@ func (d *DriverOracle) getTableUniqueIndex(table string) (fields map[string]map[
return
}
func (d *DriverOracle) DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) {
func (d *DriverOracle) DoInsert(ctx context.Context, link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) {
var (
fields []string
values []string
@ -279,7 +281,7 @@ func (d *DriverOracle) DoInsert(link Link, table string, data interface{}, optio
}
switch kind {
case reflect.Slice, reflect.Array:
return d.db.DoBatchInsert(link, table, data, option, batch...)
return d.DoBatchInsert(ctx, link, table, data, option, batch...)
case reflect.Map:
fallthrough
case reflect.Struct:
@ -353,20 +355,17 @@ func (d *DriverOracle) DoInsert(link Link, table string, data interface{}, optio
table, tableAlias1, strings.Join(subSqlStr, ","), tableAlias2,
strings.Join(onStr, "AND"), strings.Join(updateStr, ","), strings.Join(fields, ","), strings.Join(values, ","),
)
return d.db.DoExec(link, tmp, params...)
return d.DoExec(ctx, link, tmp, params...)
case insertOptionIgnore:
return d.db.DoExec(link,
fmt.Sprintf(
"INSERT /*+ IGNORE_ROW_ON_DUPKEY_INDEX(%s(%s)) */ INTO %s(%s) VALUES(%s)",
table, strings.Join(indexes, ","), table, strings.Join(fields, ","), strings.Join(values, ","),
),
params...)
return d.DoExec(ctx, link, fmt.Sprintf(
"INSERT /*+ IGNORE_ROW_ON_DUPKEY_INDEX(%s(%s)) */ INTO %s(%s) VALUES(%s)",
table, strings.Join(indexes, ","), table, strings.Join(fields, ","), strings.Join(values, ","),
), params...)
}
}
return d.db.DoExec(
link,
return d.DoExec(ctx, link,
fmt.Sprintf(
"INSERT INTO %s(%s) VALUES(%s)",
table, strings.Join(fields, ","), strings.Join(values, ","),
@ -374,7 +373,7 @@ func (d *DriverOracle) DoInsert(link Link, table string, data interface{}, optio
params...)
}
func (d *DriverOracle) DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) {
func (d *DriverOracle) DoBatchInsert(ctx context.Context, link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) {
var (
keys []string
values []string
@ -435,7 +434,7 @@ func (d *DriverOracle) DoBatchInsert(link Link, table string, list interface{},
)
if option != insertOptionDefault {
for _, v := range listMap {
r, err := d.db.DoInsert(link, table, v, option, 1)
r, err := d.DoInsert(ctx, link, table, v, option, 1)
if err != nil {
return r, err
}
@ -463,7 +462,7 @@ func (d *DriverOracle) DoBatchInsert(link Link, table string, list interface{},
values = append(values, valueHolderStr)
intoStr = append(intoStr, fmt.Sprintf(" INTO %s(%s) VALUES(%s) ", table, keyStr, valueHolderStr))
if len(intoStr) == batchNum {
r, err := d.db.DoExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...)
r, err := d.DoExec(ctx, link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...)
if err != nil {
return r, err
}
@ -479,7 +478,7 @@ func (d *DriverOracle) DoBatchInsert(link Link, table string, list interface{},
}
// The leftover data.
if len(intoStr) > 0 {
r, err := d.db.DoExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...)
r, err := d.DoExec(ctx, link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...)
if err != nil {
return r, err
}

View File

@ -12,12 +12,14 @@
package gdb
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/internal/intlog"
"github.com/gogf/gf/text/gstr"
"strings"
"github.com/gogf/gf/text/gregex"
)
@ -75,7 +77,7 @@ func (d *DriverPgsql) GetChars() (charLeft string, charRight string) {
}
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverPgsql) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) {
func (d *DriverPgsql) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) {
var index int
// Convert place holder char '?' to string "$x".
sql, _ = gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
@ -88,9 +90,9 @@ func (d *DriverPgsql) HandleSqlBeforeCommit(link Link, sql string, args []interf
// Tables retrieves and returns the tables of current schema.
// It's mainly used in cli tool chain for automatically generating the models.
func (d *DriverPgsql) Tables(schema ...string) (tables []string, err error) {
func (d *DriverPgsql) Tables(ctx context.Context, schema ...string) (tables []string, err error) {
var result Result
link, err := d.db.GetSlave(schema...)
link, err := d.GetSlave(schema...)
if err != nil {
return nil, err
}
@ -98,7 +100,7 @@ func (d *DriverPgsql) Tables(schema ...string) (tables []string, err error) {
if len(schema) > 0 && schema[0] != "" {
query = fmt.Sprintf("SELECT TABLENAME FROM PG_TABLES WHERE SCHEMANAME = '%s' ORDER BY TABLENAME", schema[0])
}
result, err = d.db.DoGetAll(link, query)
result, err = d.DoGetAll(ctx, link, query)
if err != nil {
return
}
@ -113,7 +115,7 @@ func (d *DriverPgsql) Tables(schema ...string) (tables []string, err error) {
// TableFields retrieves and returns the fields information of specified table of current schema.
//
// Also see DriverMysql.TableFields.
func (d *DriverPgsql) TableFields(link Link, table string, schema ...string) (fields map[string]*TableField, err error) {
func (d *DriverPgsql) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) {
charL, charR := d.GetChars()
table = gstr.Trim(table, charL+charR)
if gstr.Contains(table, " ") {
@ -141,12 +143,12 @@ ORDER BY a.attnum`,
)
structureSql, _ = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(structureSql))
if link == nil {
link, err = d.db.GetSlave(checkSchema)
link, err = d.GetSlave(checkSchema)
if err != nil {
return nil
}
}
result, err = d.db.DoGetAll(link, structureSql)
result, err = d.DoGetAll(ctx, link, structureSql)
if err != nil {
return nil
}

View File

@ -11,13 +11,15 @@
package gdb
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/internal/intlog"
"github.com/gogf/gf/os/gfile"
"github.com/gogf/gf/text/gstr"
"strings"
)
// DriverSqlite is the driver for sqlite database.
@ -67,20 +69,20 @@ func (d *DriverSqlite) GetChars() (charLeft string, charRight string) {
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
// TODO 需要增加对Save方法的支持可使用正则来实现替换
// TODO 将ON DUPLICATE KEY UPDATE触发器修改为两条SQL语句(INSERT OR IGNORE & UPDATE)
func (d *DriverSqlite) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) {
func (d *DriverSqlite) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) {
return sql, args
}
// Tables retrieves and returns the tables of current schema.
// It's mainly used in cli tool chain for automatically generating the models.
func (d *DriverSqlite) Tables(schema ...string) (tables []string, err error) {
func (d *DriverSqlite) Tables(ctx context.Context, schema ...string) (tables []string, err error) {
var result Result
link, err := d.db.GetSlave(schema...)
link, err := d.GetSlave(schema...)
if err != nil {
return nil, err
}
result, err = d.db.DoGetAll(link, `SELECT NAME FROM SQLITE_MASTER WHERE TYPE='table' ORDER BY NAME`)
result, err = d.DoGetAll(ctx, link, `SELECT NAME FROM SQLITE_MASTER WHERE TYPE='table' ORDER BY NAME`)
if err != nil {
return
}
@ -95,7 +97,7 @@ func (d *DriverSqlite) Tables(schema ...string) (tables []string, err error) {
// TableFields retrieves and returns the fields information of specified table of current schema.
//
// Also see DriverMysql.TableFields.
func (d *DriverSqlite) TableFields(link Link, table string, schema ...string) (fields map[string]*TableField, err error) {
func (d *DriverSqlite) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) {
charL, charR := d.GetChars()
table = gstr.Trim(table, charL+charR)
if gstr.Contains(table, " ") {
@ -114,12 +116,12 @@ func (d *DriverSqlite) TableFields(link Link, table string, schema ...string) (f
result Result
)
if link == nil {
link, err = d.db.GetSlave(checkSchema)
link, err = d.GetSlave(checkSchema)
if err != nil {
return nil
}
}
result, err = d.db.DoGetAll(link, fmt.Sprintf(`PRAGMA TABLE_INFO(%s)`, table))
result, err = d.DoGetAll(ctx, link, fmt.Sprintf(`PRAGMA TABLE_INFO(%s)`, table))
if err != nil {
return nil
}

View File

@ -9,6 +9,11 @@ package gdb
import (
"bytes"
"fmt"
"reflect"
"regexp"
"strings"
"time"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/internal/empty"
"github.com/gogf/gf/internal/json"
@ -16,10 +21,6 @@ import (
"github.com/gogf/gf/os/gtime"
"github.com/gogf/gf/util/gmeta"
"github.com/gogf/gf/util/gutil"
"reflect"
"regexp"
"strings"
"time"
"github.com/gogf/gf/internal/structs"
@ -504,7 +505,7 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) (
// Eg: Where/And/Or("uid>=", 1)
newWhere += "?"
} else if gregex.IsMatchString(regularFieldNameRegPattern, newWhere) {
newWhere = db.QuoteString(newWhere)
newWhere = db.GetCore().QuoteString(newWhere)
if len(newArgs) > 0 {
if utils.IsArray(newArgs[0]) {
// Eg:
@ -543,9 +544,9 @@ func formatWhereInterfaces(db DB, where []interface{}, buffer *bytes.Buffer, new
for i := 0; i < len(where); i += 2 {
str = gconv.String(where[i])
if buffer.Len() > 0 {
buffer.WriteString(" AND " + db.QuoteWord(str) + "=?")
buffer.WriteString(" AND " + db.GetCore().QuoteWord(str) + "=?")
} else {
buffer.WriteString(db.QuoteWord(str) + "=?")
buffer.WriteString(db.GetCore().QuoteWord(str) + "=?")
}
if s, ok := where[i+1].(Raw); ok {
buffer.WriteString(gconv.String(s))
@ -558,7 +559,7 @@ func formatWhereInterfaces(db DB, where []interface{}, buffer *bytes.Buffer, new
// formatWhereKeyValue handles each key-value pair of the parameter map.
func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key string, value interface{}) []interface{} {
quotedKey := db.QuoteWord(key)
quotedKey := db.GetCore().QuoteWord(key)
if buffer.Len() > 0 {
buffer.WriteString(" AND ")
}

View File

@ -9,9 +9,10 @@ package gdb
import (
"context"
"fmt"
"github.com/gogf/gf/text/gregex"
"time"
"github.com/gogf/gf/text/gregex"
"github.com/gogf/gf/text/gstr"
)
@ -98,10 +99,10 @@ func (c *Core) Model(tableNameOrStruct ...interface{}) *Model {
if len(tableNames) > 1 {
tableStr = fmt.Sprintf(
`%s AS %s`, c.db.QuotePrefixTableName(tableNames[0]), c.db.QuoteWord(tableNames[1]),
`%s AS %s`, c.QuotePrefixTableName(tableNames[0]), c.QuoteWord(tableNames[1]),
)
} else if len(tableNames) == 1 {
tableStr = c.db.QuotePrefixTableName(tableNames[0])
tableStr = c.QuotePrefixTableName(tableNames[0])
}
return &Model{
db: c.db,
@ -148,9 +149,21 @@ func (m *Model) Ctx(ctx context.Context) *Model {
}
model := m.getModel()
model.db = model.db.Ctx(ctx)
if m.tx != nil {
model.tx = model.tx.Ctx(ctx)
}
return model
}
// GetCtx returns the context for current Model.
// It returns `context.Background()` is there's no context previously set.
func (m *Model) GetCtx() context.Context {
if m.tx != nil && m.tx.ctx != nil {
return m.tx.ctx
}
return m.db.GetCtx()
}
// As sets an alias name for current table.
func (m *Model) As(as string) *Model {
if m.tables != "" {

View File

@ -38,6 +38,6 @@ func (m *Model) Cache(duration time.Duration, name ...string) *Model {
// cache feature is enabled.
func (m *Model) checkAndRemoveCache() {
if m.cacheEnabled && m.cacheDuration < 0 && len(m.cacheName) > 0 {
m.db.GetCache().Ctx(m.db.GetCtx()).Remove(m.cacheName)
m.db.GetCache().Ctx(m.GetCtx()).Remove(m.cacheName)
}
}

View File

@ -61,53 +61,53 @@ func (m *Model) WherePri(where interface{}, args ...interface{}) *Model {
// WhereBetween builds `xxx BETWEEN x AND y` statement.
func (m *Model) WhereBetween(column string, min, max interface{}) *Model {
return m.Where(fmt.Sprintf(`%s BETWEEN ? AND ?`, m.db.QuoteWord(column)), min, max)
return m.Where(fmt.Sprintf(`%s BETWEEN ? AND ?`, m.db.GetCore().QuoteWord(column)), min, max)
}
// WhereLike builds `xxx LIKE x` statement.
func (m *Model) WhereLike(column string, like interface{}) *Model {
return m.Where(fmt.Sprintf(`%s LIKE ?`, m.db.QuoteWord(column)), like)
return m.Where(fmt.Sprintf(`%s LIKE ?`, m.db.GetCore().QuoteWord(column)), like)
}
// WhereIn builds `xxx IN (x)` statement.
func (m *Model) WhereIn(column string, in interface{}) *Model {
return m.Where(fmt.Sprintf(`%s IN (?)`, m.db.QuoteWord(column)), in)
return m.Where(fmt.Sprintf(`%s IN (?)`, m.db.GetCore().QuoteWord(column)), in)
}
// WhereNull builds `xxx IS NULL` statement.
func (m *Model) WhereNull(columns ...string) *Model {
model := m
for _, column := range columns {
model = m.Where(fmt.Sprintf(`%s IS NULL`, m.db.QuoteWord(column)))
model = m.Where(fmt.Sprintf(`%s IS NULL`, m.db.GetCore().QuoteWord(column)))
}
return model
}
// WhereNotBetween builds `xxx NOT BETWEEN x AND y` statement.
func (m *Model) WhereNotBetween(column string, min, max interface{}) *Model {
return m.Where(fmt.Sprintf(`%s NOT BETWEEN ? AND ?`, m.db.QuoteWord(column)), min, max)
return m.Where(fmt.Sprintf(`%s NOT BETWEEN ? AND ?`, m.db.GetCore().QuoteWord(column)), min, max)
}
// WhereNotLike builds `xxx NOT LIKE x` statement.
func (m *Model) WhereNotLike(column string, like interface{}) *Model {
return m.Where(fmt.Sprintf(`%s NOT LIKE ?`, m.db.QuoteWord(column)), like)
return m.Where(fmt.Sprintf(`%s NOT LIKE ?`, m.db.GetCore().QuoteWord(column)), like)
}
// WhereNot builds `xxx != x` statement.
func (m *Model) WhereNot(column string, value interface{}) *Model {
return m.Where(fmt.Sprintf(`%s != ?`, m.db.QuoteWord(column)), value)
return m.Where(fmt.Sprintf(`%s != ?`, m.db.GetCore().QuoteWord(column)), value)
}
// WhereNotIn builds `xxx NOT IN (x)` statement.
func (m *Model) WhereNotIn(column string, in interface{}) *Model {
return m.Where(fmt.Sprintf(`%s NOT IN (?)`, m.db.QuoteWord(column)), in)
return m.Where(fmt.Sprintf(`%s NOT IN (?)`, m.db.GetCore().QuoteWord(column)), in)
}
// WhereNotNull builds `xxx IS NOT NULL` statement.
func (m *Model) WhereNotNull(columns ...string) *Model {
model := m
for _, column := range columns {
model = m.Where(fmt.Sprintf(`%s IS NOT NULL`, m.db.QuoteWord(column)))
model = m.Where(fmt.Sprintf(`%s IS NOT NULL`, m.db.GetCore().QuoteWord(column)))
}
return model
}
@ -128,48 +128,48 @@ func (m *Model) WhereOr(where interface{}, args ...interface{}) *Model {
// WhereOrBetween builds `xxx BETWEEN x AND y` statement in `OR` conditions.
func (m *Model) WhereOrBetween(column string, min, max interface{}) *Model {
return m.WhereOr(fmt.Sprintf(`%s BETWEEN ? AND ?`, m.db.QuoteWord(column)), min, max)
return m.WhereOr(fmt.Sprintf(`%s BETWEEN ? AND ?`, m.db.GetCore().QuoteWord(column)), min, max)
}
// WhereOrLike builds `xxx LIKE x` statement in `OR` conditions.
func (m *Model) WhereOrLike(column string, like interface{}) *Model {
return m.WhereOr(fmt.Sprintf(`%s LIKE ?`, m.db.QuoteWord(column)), like)
return m.WhereOr(fmt.Sprintf(`%s LIKE ?`, m.db.GetCore().QuoteWord(column)), like)
}
// WhereOrIn builds `xxx IN (x)` statement in `OR` conditions.
func (m *Model) WhereOrIn(column string, in interface{}) *Model {
return m.WhereOr(fmt.Sprintf(`%s IN (?)`, m.db.QuoteWord(column)), in)
return m.WhereOr(fmt.Sprintf(`%s IN (?)`, m.db.GetCore().QuoteWord(column)), in)
}
// WhereOrNull builds `xxx IS NULL` statement in `OR` conditions.
func (m *Model) WhereOrNull(columns ...string) *Model {
model := m
for _, column := range columns {
model = m.WhereOr(fmt.Sprintf(`%s IS NULL`, m.db.QuoteWord(column)))
model = m.WhereOr(fmt.Sprintf(`%s IS NULL`, m.db.GetCore().QuoteWord(column)))
}
return model
}
// WhereOrNotBetween builds `xxx NOT BETWEEN x AND y` statement in `OR` conditions.
func (m *Model) WhereOrNotBetween(column string, min, max interface{}) *Model {
return m.WhereOr(fmt.Sprintf(`%s NOT BETWEEN ? AND ?`, m.db.QuoteWord(column)), min, max)
return m.WhereOr(fmt.Sprintf(`%s NOT BETWEEN ? AND ?`, m.db.GetCore().QuoteWord(column)), min, max)
}
// WhereOrNotLike builds `xxx NOT LIKE x` statement in `OR` conditions.
func (m *Model) WhereOrNotLike(column string, like interface{}) *Model {
return m.WhereOr(fmt.Sprintf(`%s NOT LIKE ?`, m.db.QuoteWord(column)), like)
return m.WhereOr(fmt.Sprintf(`%s NOT LIKE ?`, m.db.GetCore().QuoteWord(column)), like)
}
// WhereOrNotIn builds `xxx NOT IN (x)` statement.
func (m *Model) WhereOrNotIn(column string, in interface{}) *Model {
return m.WhereOr(fmt.Sprintf(`%s NOT IN (?)`, m.db.QuoteWord(column)), in)
return m.WhereOr(fmt.Sprintf(`%s NOT IN (?)`, m.db.GetCore().QuoteWord(column)), in)
}
// WhereOrNotNull builds `xxx IS NOT NULL` statement in `OR` conditions.
func (m *Model) WhereOrNotNull(columns ...string) *Model {
model := m
for _, column := range columns {
model = m.WhereOr(fmt.Sprintf(`%s IS NOT NULL`, m.db.QuoteWord(column)))
model = m.WhereOr(fmt.Sprintf(`%s IS NOT NULL`, m.db.GetCore().QuoteWord(column)))
}
return model
}
@ -177,7 +177,7 @@ func (m *Model) WhereOrNotNull(columns ...string) *Model {
// Group sets the "GROUP BY" statement for the model.
func (m *Model) Group(groupBy string) *Model {
model := m.getModel()
model.groupBy = m.db.QuoteString(groupBy)
model.groupBy = m.db.GetCore().QuoteString(groupBy)
return model
}
@ -215,7 +215,7 @@ func (m *Model) Order(orderBy ...string) *Model {
return m
}
model := m.getModel()
model.orderBy = m.db.QuoteString(strings.Join(orderBy, " "))
model.orderBy = m.db.GetCore().QuoteString(strings.Join(orderBy, " "))
return model
}
@ -225,7 +225,7 @@ func (m *Model) OrderAsc(column string) *Model {
return m
}
model := m.getModel()
model.orderBy = m.db.QuoteWord(column) + " ASC"
model.orderBy = m.db.GetCore().QuoteWord(column) + " ASC"
return model
}
@ -235,7 +235,7 @@ func (m *Model) OrderDesc(column string) *Model {
return m
}
model := m.getModel()
model.orderBy = m.db.QuoteWord(column) + " DESC"
model.orderBy = m.db.GetCore().QuoteWord(column) + " DESC"
return model
}

View File

@ -9,6 +9,7 @@ package gdb
import (
"database/sql"
"fmt"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/os/gtime"
"github.com/gogf/gf/text/gstr"
@ -32,10 +33,11 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) {
)
// Soft deleting.
if !m.unscoped && fieldNameDelete != "" {
return m.db.DoUpdate(
return m.db.GetCore().DoUpdate(
m.GetCtx(),
m.getLink(true),
m.tables,
fmt.Sprintf(`%s=?`, m.db.QuoteString(fieldNameDelete)),
fmt.Sprintf(`%s=?`, m.db.GetCore().QuoteString(fieldNameDelete)),
conditionWhere+conditionExtra,
append([]interface{}{gtime.Now().String()}, conditionArgs...),
)
@ -44,5 +46,5 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) {
if !gstr.ContainsI(conditionStr, " WHERE ") {
return nil, gerror.New("there should be WHERE condition statement for DELETE operation")
}
return m.db.DoDelete(m.getLink(true), m.tables, conditionStr, conditionArgs...)
return m.db.GetCore().DoDelete(m.GetCtx(), m.getLink(true), m.tables, conditionStr, conditionArgs...)
}

View File

@ -114,7 +114,7 @@ func (m *Model) GetFieldsStr(prefix ...string) string {
}
newFields += prefixStr + k
}
newFields = m.db.QuoteString(newFields)
newFields = m.db.GetCore().QuoteString(newFields)
return newFields
}
@ -158,7 +158,7 @@ func (m *Model) GetFieldsExStr(fields string, prefix ...string) string {
}
newFields += prefixStr + k
}
newFields = m.db.QuoteString(newFields)
newFields = m.db.GetCore().QuoteString(newFields)
return newFields
}

View File

@ -8,12 +8,13 @@ package gdb
import (
"database/sql"
"reflect"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/os/gtime"
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/util/gconv"
"github.com/gogf/gf/util/gutil"
"reflect"
)
// Batch sets the batch operation number for the model.
@ -194,7 +195,8 @@ func (m *Model) doInsertWithOption(option int) (result sql.Result, err error) {
list[k] = v
}
}
return m.db.DoBatchInsert(
return m.db.GetCore().DoBatchInsert(
m.GetCtx(),
m.getLink(true),
m.tables,
newData,
@ -219,7 +221,8 @@ func (m *Model) doInsertWithOption(option int) (result sql.Result, err error) {
data[fieldNameUpdate] = nowString
}
}
return m.db.DoInsert(
return m.db.GetCore().DoInsert(
m.GetCtx(),
m.getLink(true),
m.tables,
newData,

View File

@ -8,6 +8,7 @@ package gdb
import (
"fmt"
"github.com/gogf/gf/text/gstr"
)
@ -72,13 +73,13 @@ func (m *Model) doJoin(operator string, table ...string) *Model {
joinStr = "(" + joinStr + ")"
}
} else {
joinStr = m.db.QuotePrefixTableName(table[0])
joinStr = m.db.GetCore().QuotePrefixTableName(table[0])
}
}
if len(table) > 2 {
model.tables += fmt.Sprintf(
" %s JOIN %s AS %s ON (%s)",
operator, joinStr, m.db.QuoteWord(table[1]), table[2],
operator, joinStr, m.db.GetCore().QuoteWord(table[1]), table[2],
)
} else if len(table) == 2 {
model.tables += fmt.Sprintf(

View File

@ -8,13 +8,14 @@ package gdb
import (
"fmt"
"reflect"
"github.com/gogf/gf/container/gset"
"github.com/gogf/gf/container/gvar"
"github.com/gogf/gf/internal/intlog"
"github.com/gogf/gf/internal/json"
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/util/gconv"
"reflect"
)
// Select is alias of Model.All.
@ -66,7 +67,7 @@ func (m *Model) getFieldsFiltered() string {
if m.fieldsEx == "" {
// No filtering.
if !gstr.Contains(m.fields, ".") && !gstr.Contains(m.fields, " ") {
return m.db.QuoteString(m.fields)
return m.db.GetCore().QuoteString(m.fields)
}
return m.fields
}
@ -105,7 +106,7 @@ func (m *Model) getFieldsFiltered() string {
if len(newFields) > 0 {
newFields += ","
}
newFields += m.db.QuoteWord(k)
newFields += m.db.GetCore().QuoteWord(k)
}
return newFields
}
@ -367,7 +368,7 @@ func (m *Model) Min(column string) (float64, error) {
if len(column) == 0 {
return 0, nil
}
value, err := m.Fields(fmt.Sprintf(`MIN(%s)`, m.db.QuoteWord(column))).Value()
value, err := m.Fields(fmt.Sprintf(`MIN(%s)`, m.db.GetCore().QuoteWord(column))).Value()
if err != nil {
return 0, err
}
@ -379,7 +380,7 @@ func (m *Model) Max(column string) (float64, error) {
if len(column) == 0 {
return 0, nil
}
value, err := m.Fields(fmt.Sprintf(`MAX(%s)`, m.db.QuoteWord(column))).Value()
value, err := m.Fields(fmt.Sprintf(`MAX(%s)`, m.db.GetCore().QuoteWord(column))).Value()
if err != nil {
return 0, err
}
@ -391,7 +392,7 @@ func (m *Model) Avg(column string) (float64, error) {
if len(column) == 0 {
return 0, nil
}
value, err := m.Fields(fmt.Sprintf(`AVG(%s)`, m.db.QuoteWord(column))).Value()
value, err := m.Fields(fmt.Sprintf(`AVG(%s)`, m.db.GetCore().QuoteWord(column))).Value()
if err != nil {
return 0, err
}
@ -403,7 +404,7 @@ func (m *Model) Sum(column string) (float64, error) {
if len(column) == 0 {
return 0, nil
}
value, err := m.Fields(fmt.Sprintf(`SUM(%s)`, m.db.QuoteWord(column))).Value()
value, err := m.Fields(fmt.Sprintf(`SUM(%s)`, m.db.GetCore().QuoteWord(column))).Value()
if err != nil {
return 0, err
}
@ -474,7 +475,7 @@ func (m *Model) FindScan(pointer interface{}, where ...interface{}) error {
// doGetAllBySql does the select statement on the database.
func (m *Model) doGetAllBySql(sql string, args ...interface{}) (result Result, err error) {
cacheKey := ""
cacheObj := m.db.GetCache().Ctx(m.db.GetCtx())
cacheObj := m.db.GetCache().Ctx(m.GetCtx())
// Retrieve from cache.
if m.cacheEnabled && m.tx == nil {
cacheKey = m.cacheName
@ -496,7 +497,7 @@ func (m *Model) doGetAllBySql(sql string, args ...interface{}) (result Result, e
}
}
}
result, err = m.db.DoGetAll(m.getLink(false), sql, m.mergeArguments(args)...)
result, err = m.db.GetCore().DoGetAll(m.GetCtx(), m.getLink(false), sql, m.mergeArguments(args)...)
// Cache the result.
if cacheKey != "" && err == nil {
if m.cacheDuration < 0 {

View File

@ -140,7 +140,7 @@ func (m *Model) getConditionForSoftDeleting() string {
}
// Only one table.
if fieldName := m.getSoftFieldNameDeleted(); fieldName != "" {
return fmt.Sprintf(`%s IS NULL`, m.db.QuoteWord(fieldName))
return fmt.Sprintf(`%s IS NULL`, m.db.GetCore().QuoteWord(fieldName))
}
return ""
}
@ -163,12 +163,12 @@ func (m *Model) getConditionOfTableStringForSoftDeleting(s string) string {
return ""
}
if len(array1) >= 3 {
return fmt.Sprintf(`%s.%s IS NULL`, m.db.QuoteWord(array1[2]), m.db.QuoteWord(field))
return fmt.Sprintf(`%s.%s IS NULL`, m.db.GetCore().QuoteWord(array1[2]), m.db.GetCore().QuoteWord(field))
}
if len(array1) >= 2 {
return fmt.Sprintf(`%s.%s IS NULL`, m.db.QuoteWord(array1[1]), m.db.QuoteWord(field))
return fmt.Sprintf(`%s.%s IS NULL`, m.db.GetCore().QuoteWord(array1[1]), m.db.GetCore().QuoteWord(field))
}
return fmt.Sprintf(`%s.%s IS NULL`, m.db.QuoteWord(table), m.db.QuoteWord(field))
return fmt.Sprintf(`%s.%s IS NULL`, m.db.GetCore().QuoteWord(table), m.db.GetCore().QuoteWord(field))
}
// getPrimaryTableName parses and returns the primary table name.

View File

@ -9,12 +9,13 @@ package gdb
import (
"database/sql"
"fmt"
"reflect"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/os/gtime"
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/util/gconv"
"github.com/gogf/gf/util/gutil"
"reflect"
)
// Update does "UPDATE ... " statement for the model.
@ -81,7 +82,8 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro
if !gstr.ContainsI(conditionStr, " WHERE ") {
return nil, gerror.New("there should be WHERE condition statement for UPDATE operation")
}
return m.db.DoUpdate(
return m.db.GetCore().DoUpdate(
m.GetCtx(),
m.getLink(true),
m.tables,
newData,

View File

@ -8,6 +8,8 @@ package gdb
import (
"fmt"
"time"
"github.com/gogf/gf/container/gset"
"github.com/gogf/gf/internal/empty"
"github.com/gogf/gf/os/gtime"
@ -15,7 +17,6 @@ import (
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/util/gconv"
"github.com/gogf/gf/util/gutil"
"time"
)
// TableFields retrieves and returns the fields information of specified table of current
@ -29,12 +30,12 @@ func (m *Model) TableFields(table string, schema ...string) (fields map[string]*
if m.tx != nil {
link = m.tx.tx
} else {
link, err = m.db.GetSlave(schema...)
link, err = m.db.GetCore().GetSlave(schema...)
if err != nil {
return
}
}
return m.db.TableFields(link, table, schema...)
return m.db.TableFields(m.GetCtx(), link, table, schema...)
}
// getModel creates and returns a cloned model of current model if `safe` is true, or else it returns
@ -111,7 +112,7 @@ func (m *Model) filterDataForInsertOrUpdate(data interface{}) (interface{}, erro
// Note that, it does not filter list item, which is also type of map, for "omit empty" feature.
func (m *Model) doMappingAndFilterForInsertOrUpdateDataMap(data Map, allowOmitEmpty bool) (Map, error) {
var err error
data, err = m.db.mappingAndFilterData(m.schema, m.tables, data, m.filter)
data, err = m.db.GetCore().mappingAndFilterData(m.schema, m.tables, data, m.filter)
if err != nil {
return nil, err
}
@ -187,13 +188,13 @@ func (m *Model) getLink(master bool) Link {
}
switch linkType {
case linkTypeMaster:
link, err := m.db.GetMaster(m.schema)
link, err := m.db.GetCore().GetMaster(m.schema)
if err != nil {
panic(err)
}
return link
case linkTypeSlave:
link, err := m.db.GetSlave(m.schema)
link, err := m.db.GetCore().GetSlave(m.schema)
if err != nil {
panic(err)
}

View File

@ -8,12 +8,13 @@ package gdb
import (
"fmt"
"reflect"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/internal/structs"
"github.com/gogf/gf/internal/utils"
"github.com/gogf/gf/text/gregex"
"github.com/gogf/gf/text/gstr"
"reflect"
)
// With creates and returns an ORM model based on meta data of given object.
@ -38,7 +39,7 @@ func (m *Model) With(objects ...interface{}) *Model {
model := m.getModel()
for _, object := range objects {
if m.tables == "" {
m.tables = m.db.QuotePrefixTableName(getTableNameFromOrmTag(object))
m.tables = m.db.GetCore().QuotePrefixTableName(getTableNameFromOrmTag(object))
return model
}
model.withArray = append(model.withArray, object)

View File

@ -9,6 +9,7 @@ package gdb
import (
"context"
"database/sql"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/os/gtime"
)
@ -72,9 +73,10 @@ func (s *Stmt) doStmtCommit(stmtType string, ctx context.Context, args ...interf
Group: s.core.db.GetGroup(),
}
)
// Tracing and logging.
s.core.addSqlToTracing(ctx, sqlObj)
if s.core.db.GetDebug() {
s.core.writeSqlToLogger(sqlObj)
s.core.writeSqlToLogger(ctx, sqlObj)
}
return result, err
}

View File

@ -12,50 +12,184 @@ import (
"fmt"
"reflect"
"github.com/gogf/gf/container/gtype"
"github.com/gogf/gf/os/gtime"
"github.com/gogf/gf/util/gconv"
"github.com/gogf/gf/util/guid"
"github.com/gogf/gf/text/gregex"
)
// TX is the struct for transaction management.
type TX struct {
db DB // db is the current gdb database manager.
tx *sql.Tx // tx is the raw and underlying transaction manager.
master *sql.DB // master is the raw and underlying database manager.
transactionCount int // transactionCount marks the times that Begins.
db DB // db is the current gdb database manager.
tx *sql.Tx // tx is the raw and underlying transaction manager.
ctx context.Context // ctx is the context for this transaction only.
master *sql.DB // master is the raw and underlying database manager.
transactionId string // transactionId is an unique id generated by this object for this transaction.
transactionCount int // transactionCount marks the times that Begins.
}
const (
transactionPointerPrefix = "transaction"
contextTransactionKey = "TransactionObject"
transactionPointerPrefix = "transaction"
contextTransactionKeyPrefix = "TransactionObjectForGroup_"
transactionIdForLoggerCtx = "TransactionId"
)
var (
transactionIdGenerator = gtype.NewUint64()
)
// Begin starts and returns the transaction object.
// You should call Commit or Rollback functions of the transaction object
// if you no longer use the transaction. Commit or Rollback functions will also
// close the transaction automatically.
func (c *Core) Begin() (tx *TX, err error) {
return c.doBeginCtx(c.GetCtx())
}
func (c *Core) doBeginCtx(ctx context.Context) (*TX, error) {
if master, err := c.db.Master(); err != nil {
return nil, err
} else {
var (
tx *TX
sqlStr = "BEGIN"
mTime1 = gtime.TimestampMilli()
rawTx, err = master.Begin()
mTime2 = gtime.TimestampMilli()
sqlObj = &Sql{
Sql: sqlStr,
Type: "DB.Begin",
Args: nil,
Format: sqlStr,
Error: err,
Start: mTime1,
End: mTime2,
Group: c.db.GetGroup(),
}
)
if err == nil {
tx = &TX{
db: c.db,
tx: rawTx,
ctx: context.WithValue(ctx, transactionIdForLoggerCtx, transactionIdGenerator.Add(1)),
master: master,
transactionId: guid.S(),
}
ctx = tx.ctx
}
// Tracing and logging.
c.addSqlToTracing(ctx, sqlObj)
if c.db.GetDebug() {
c.writeSqlToLogger(ctx, sqlObj)
}
return tx, err
}
}
// Transaction wraps the transaction logic using function `f`.
// It rollbacks the transaction and returns the error from function `f` if
// it returns non-nil error. It commits the transaction and returns nil if
// function `f` returns nil.
//
// Note that, you should not Commit or Rollback the transaction in function `f`
// as it is automatically handled by this function.
func (c *Core) Transaction(ctx context.Context, f func(ctx context.Context, tx *TX) error) (err error) {
var (
tx *TX
)
if ctx == nil {
ctx = c.GetCtx()
}
// Check transaction object from context.
tx = TXFromCtx(ctx, c.db.GetGroup())
if tx != nil {
return tx.Transaction(ctx, f)
}
tx, err = c.doBeginCtx(ctx)
if err != nil {
return err
}
// Inject transaction object into context.
tx.ctx = WithTX(tx.ctx, tx)
defer func() {
if err == nil {
if e := recover(); e != nil {
err = fmt.Errorf("%v", e)
}
}
if err != nil {
if e := tx.Rollback(); e != nil {
err = e
}
} else {
if e := tx.Commit(); e != nil {
err = e
}
}
}()
err = f(tx.ctx, tx)
return
}
// WithTX injects given transaction object into context and returns a new context.
func WithTX(ctx context.Context, tx *TX) context.Context {
return context.WithValue(ctx, contextTransactionKey, tx)
if tx == nil {
return ctx
}
// Check repeat injection from given.
group := tx.db.GetGroup()
if tx := TXFromCtx(ctx, group); tx != nil && tx.db.GetGroup() == group {
return ctx
}
dbCtx := tx.db.GetCtx()
if tx := TXFromCtx(dbCtx, group); tx != nil && tx.db.GetGroup() == group {
return dbCtx
}
// Inject transaction object and id into context.
ctx = context.WithValue(ctx, transactionKeyForContext(group), tx)
return ctx
}
// TXFromCtx retrieves and returns transaction object from context.
// It is usually used in nested transaction feature, and it returns nil if it is not set previously.
func TXFromCtx(ctx context.Context) *TX {
func TXFromCtx(ctx context.Context, group string) *TX {
if ctx == nil {
return nil
}
v := ctx.Value(contextTransactionKey)
v := ctx.Value(transactionKeyForContext(group))
if v != nil {
return v.(*TX)
tx := v.(*TX)
tx.ctx = ctx
return tx
}
return nil
}
// transactionKeyForContext forms and returns a string for storing transaction object of certain database group into context.
func transactionKeyForContext(group string) string {
return contextTransactionKeyPrefix + group
}
// transactionKeyForNestedPoint forms and returns the transaction key at current save point.
func (tx *TX) transactionKeyForNestedPoint() string {
return tx.db.GetCore().QuoteWord(transactionPointerPrefix + gconv.String(tx.transactionCount))
}
// Ctx sets the context for current transaction.
func (tx *TX) Ctx(ctx context.Context) *TX {
tx.ctx = ctx
return tx
}
// Commit commits current transaction.
// Note that it releases previous saved transaction point if it's in a nested transaction procedure,
// or else it commits the hole transaction.
func (tx *TX) Commit() error {
if tx.transactionCount > 0 {
tx.transactionCount--
_, err := tx.Exec("RELEASE SAVEPOINT " + tx.transactionKey())
_, err := tx.Exec("RELEASE SAVEPOINT " + tx.transactionKeyForNestedPoint())
return err
}
var (
@ -74,9 +208,9 @@ func (tx *TX) Commit() error {
Group: tx.db.GetGroup(),
}
)
tx.db.addSqlToTracing(tx.db.GetCtx(), sqlObj)
tx.db.GetCore().addSqlToTracing(tx.ctx, sqlObj)
if tx.db.GetDebug() {
tx.db.writeSqlToLogger(sqlObj)
tx.db.GetCore().writeSqlToLogger(tx.ctx, sqlObj)
}
return err
}
@ -87,7 +221,7 @@ func (tx *TX) Commit() error {
func (tx *TX) Rollback() error {
if tx.transactionCount > 0 {
tx.transactionCount--
_, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.transactionKey())
_, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.transactionKeyForNestedPoint())
return err
}
var (
@ -106,16 +240,16 @@ func (tx *TX) Rollback() error {
Group: tx.db.GetGroup(),
}
)
tx.db.addSqlToTracing(tx.db.GetCtx(), sqlObj)
tx.db.GetCore().addSqlToTracing(tx.ctx, sqlObj)
if tx.db.GetDebug() {
tx.db.writeSqlToLogger(sqlObj)
tx.db.GetCore().writeSqlToLogger(tx.ctx, sqlObj)
}
return err
}
// Begin starts a nested transaction procedure.
func (tx *TX) Begin() error {
_, err := tx.Exec("SAVEPOINT " + tx.transactionKey())
_, err := tx.Exec("SAVEPOINT " + tx.transactionKeyForNestedPoint())
if err != nil {
return err
}
@ -126,22 +260,17 @@ func (tx *TX) Begin() error {
// SavePoint performs `SAVEPOINT xxx` SQL statement that saves transaction at current point.
// The parameter `point` specifies the point name that will be saved to server.
func (tx *TX) SavePoint(point string) error {
_, err := tx.Exec("SAVEPOINT " + tx.db.QuoteWord(point))
_, err := tx.Exec("SAVEPOINT " + tx.db.GetCore().QuoteWord(point))
return err
}
// RollbackTo performs `ROLLBACK TO SAVEPOINT xxx` SQL statement that rollbacks to specified saved transaction.
// The parameter `point` specifies the point name that was saved previously.
func (tx *TX) RollbackTo(point string) error {
_, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.db.QuoteWord(point))
_, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.db.GetCore().QuoteWord(point))
return err
}
// transactionKey forms and returns the transaction key at current save point.
func (tx *TX) transactionKey() string {
return tx.db.QuoteWord(transactionPointerPrefix + gconv.String(tx.transactionCount))
}
// Transaction wraps the transaction logic using function `f`.
// It rollbacks the transaction and returns the error from function `f` if
// it returns non-nil error. It commits the transaction and returns nil if
@ -150,10 +279,13 @@ func (tx *TX) transactionKey() string {
// Note that, you should not Commit or Rollback the transaction in function `f`
// as it is automatically handled by this function.
func (tx *TX) Transaction(ctx context.Context, f func(ctx context.Context, tx *TX) error) (err error) {
if ctx != nil {
tx.ctx = ctx
}
// Check transaction object from context.
if TXFromCtx(ctx) == nil {
if TXFromCtx(tx.ctx, tx.db.GetGroup()) == nil {
// Inject transaction object into context.
ctx = WithTX(ctx, tx)
tx.ctx = WithTX(tx.ctx, tx)
}
err = tx.Begin()
if err != nil {
@ -175,20 +307,20 @@ func (tx *TX) Transaction(ctx context.Context, f func(ctx context.Context, tx *T
}
}
}()
err = f(ctx, tx)
err = f(tx.ctx, tx)
return
}
// Query does query operation on transaction.
// See Core.Query.
func (tx *TX) Query(sql string, args ...interface{}) (rows *sql.Rows, err error) {
return tx.db.DoQuery(tx.tx, sql, args...)
return tx.db.GetCore().DoQuery(tx.ctx, tx.tx, sql, args...)
}
// Exec does none query operation on transaction.
// See Core.Exec.
func (tx *TX) Exec(sql string, args ...interface{}) (sql.Result, error) {
return tx.db.DoExec(tx.tx, sql, args...)
return tx.db.GetCore().DoExec(tx.ctx, tx.tx, sql, args...)
}
// Prepare creates a prepared statement for later queries or executions.
@ -197,7 +329,7 @@ func (tx *TX) Exec(sql string, args ...interface{}) (sql.Result, error) {
// The caller must call the statement's Close method
// when the statement is no longer needed.
func (tx *TX) Prepare(sql string) (*Stmt, error) {
return tx.db.DoPrepare(tx.tx, sql)
return tx.db.GetCore().DoPrepare(tx.ctx, tx.tx, sql)
}
// GetAll queries and returns data records from database.
@ -207,7 +339,7 @@ func (tx *TX) GetAll(sql string, args ...interface{}) (Result, error) {
return nil, err
}
defer rows.Close()
return tx.db.convertRowsToResult(rows)
return tx.db.GetCore().convertRowsToResult(rows)
}
// GetOne queries and returns one record from database.
@ -248,8 +380,8 @@ func (tx *TX) GetStructs(objPointerSlice interface{}, sql string, args ...interf
// If parameter `pointer` is type of struct pointer, it calls GetStruct internally for
// the conversion. If parameter `pointer` is type of slice, it calls GetStructs internally
// for conversion.
func (tx *TX) GetScan(objPointer interface{}, sql string, args ...interface{}) error {
t := reflect.TypeOf(objPointer)
func (tx *TX) GetScan(pointer interface{}, sql string, args ...interface{}) error {
t := reflect.TypeOf(pointer)
k := t.Kind()
if k != reflect.Ptr {
return fmt.Errorf("params should be type of pointer, but got: %v", k)
@ -257,9 +389,9 @@ func (tx *TX) GetScan(objPointer interface{}, sql string, args ...interface{}) e
k = t.Elem().Kind()
switch k {
case reflect.Array, reflect.Slice:
return tx.db.GetStructs(objPointer, sql, args...)
return tx.GetStructs(pointer, sql, args...)
case reflect.Struct:
return tx.db.GetStruct(objPointer, sql, args...)
return tx.GetStruct(pointer, sql, args...)
default:
return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k)
}
@ -302,9 +434,9 @@ func (tx *TX) GetCount(sql string, args ...interface{}) (int, error) {
// The parameter `batch` specifies the batch operation count when given data is slice.
func (tx *TX) Insert(table string, data interface{}, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return tx.Model(table).Data(data).Batch(batch[0]).Insert()
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Insert()
}
return tx.Model(table).Data(data).Insert()
return tx.Model(table).Ctx(tx.ctx).Data(data).Insert()
}
// InsertIgnore does "INSERT IGNORE INTO ..." statement for the table.
@ -318,17 +450,17 @@ func (tx *TX) Insert(table string, data interface{}, batch ...int) (sql.Result,
// The parameter `batch` specifies the batch operation count when given data is slice.
func (tx *TX) InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return tx.Model(table).Data(data).Batch(batch[0]).InsertIgnore()
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).InsertIgnore()
}
return tx.Model(table).Data(data).InsertIgnore()
return tx.Model(table).Ctx(tx.ctx).Data(data).InsertIgnore()
}
// InsertAndGetId performs action Insert and returns the last insert id that automatically generated.
func (tx *TX) InsertAndGetId(table string, data interface{}, batch ...int) (int64, error) {
if len(batch) > 0 {
return tx.Model(table).Data(data).Batch(batch[0]).InsertAndGetId()
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).InsertAndGetId()
}
return tx.Model(table).Data(data).InsertAndGetId()
return tx.Model(table).Ctx(tx.ctx).Data(data).InsertAndGetId()
}
// Replace does "REPLACE INTO ..." statement for the table.
@ -345,9 +477,9 @@ func (tx *TX) InsertAndGetId(table string, data interface{}, batch ...int) (int6
// `batch` specifies the batch operation count.
func (tx *TX) Replace(table string, data interface{}, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return tx.Model(table).Data(data).Batch(batch[0]).Replace()
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Replace()
}
return tx.Model(table).Data(data).Replace()
return tx.Model(table).Ctx(tx.ctx).Data(data).Replace()
}
// Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table.
@ -363,45 +495,45 @@ func (tx *TX) Replace(table string, data interface{}, batch ...int) (sql.Result,
// `batch` specifies the batch operation count.
func (tx *TX) Save(table string, data interface{}, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return tx.Model(table).Data(data).Batch(batch[0]).Save()
return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Save()
}
return tx.Model(table).Data(data).Save()
return tx.Model(table).Ctx(tx.ctx).Data(data).Save()
}
// BatchInsert batch inserts data.
// The parameter `list` must be type of slice of map or struct.
func (tx *TX) BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return tx.Model(table).Data(list).Batch(batch[0]).Insert()
return tx.Model(table).Ctx(tx.ctx).Data(list).Batch(batch[0]).Insert()
}
return tx.Model(table).Data(list).Insert()
return tx.Model(table).Ctx(tx.ctx).Data(list).Insert()
}
// BatchInsertIgnore batch inserts data with ignore option.
// The parameter `list` must be type of slice of map or struct.
func (tx *TX) BatchInsertIgnore(table string, list interface{}, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return tx.Model(table).Data(list).Batch(batch[0]).InsertIgnore()
return tx.Model(table).Ctx(tx.ctx).Data(list).Batch(batch[0]).InsertIgnore()
}
return tx.Model(table).Data(list).InsertIgnore()
return tx.Model(table).Ctx(tx.ctx).Data(list).InsertIgnore()
}
// BatchReplace batch replaces data.
// The parameter `list` must be type of slice of map or struct.
func (tx *TX) BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return tx.Model(table).Data(list).Batch(batch[0]).Replace()
return tx.Model(table).Ctx(tx.ctx).Data(list).Batch(batch[0]).Replace()
}
return tx.Model(table).Data(list).Replace()
return tx.Model(table).Ctx(tx.ctx).Data(list).Replace()
}
// BatchSave batch replaces data.
// The parameter `list` must be type of slice of map or struct.
func (tx *TX) BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) {
if len(batch) > 0 {
return tx.Model(table).Data(list).Batch(batch[0]).Save()
return tx.Model(table).Ctx(tx.ctx).Data(list).Batch(batch[0]).Save()
}
return tx.Model(table).Data(list).Save()
return tx.Model(table).Ctx(tx.ctx).Data(list).Save()
}
// Update does "UPDATE ... " statement for the table.
@ -419,7 +551,7 @@ func (tx *TX) BatchSave(table string, list interface{}, batch ...int) (sql.Resul
// "age IN(?,?)", 18, 50
// User{ Id : 1, UserName : "john"}
func (tx *TX) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
return tx.Model(table).Data(data).Where(condition, args...).Update()
return tx.Model(table).Ctx(tx.ctx).Data(data).Where(condition, args...).Update()
}
// Delete does "DELETE FROM ... " statement for the table.
@ -434,5 +566,5 @@ func (tx *TX) Update(table string, data interface{}, condition interface{}, args
// "age IN(?,?)", 18, 50
// User{ Id : 1, UserName : "john"}
func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
return tx.Model(table).Where(condition, args...).Delete()
return tx.Model(table).Ctx(tx.ctx).Where(condition, args...).Delete()
}

View File

@ -7,11 +7,13 @@
package gdb_test
import (
"context"
"testing"
"github.com/gogf/gf/container/gtype"
"github.com/gogf/gf/database/gdb"
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/test/gtest"
"testing"
)
// MyDriver is a custom database driver, which is used for testing only.
@ -41,9 +43,9 @@ func (d *MyDriver) New(core *gdb.Core, node *gdb.ConfigNode) (gdb.DB, error) {
// HandleSqlBeforeCommit handles the sql before posts it to database.
// It here overwrites the same method of gdb.DriverMysql and makes some custom changes.
func (d *MyDriver) HandleSqlBeforeCommit(link gdb.Link, sql string, args []interface{}) (string, []interface{}) {
func (d *MyDriver) HandleSqlBeforeCommit(ctx context.Context, link gdb.Link, sql string, args []interface{}) (string, []interface{}) {
latestSqlString.Set(sql)
return d.DriverMysql.HandleSqlBeforeCommit(link, sql, args)
return d.DriverMysql.HandleSqlBeforeCommit(ctx, link, sql, args)
}
func init() {

View File

@ -2910,13 +2910,13 @@ func Test_Model_HasTable(t *testing.T) {
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
result, err := db.HasTable(table)
result, err := db.GetCore().HasTable(table)
t.Assert(result, true)
t.AssertNil(err)
})
gtest.C(t, func(t *gtest.T) {
result, err := db.HasTable("table12321")
result, err := db.GetCore().HasTable("table12321")
t.Assert(result, false)
t.AssertNil(err)
})

View File

@ -833,7 +833,7 @@ func Test_Transaction_Nested_Begin_Rollback_Commit(t *testing.T) {
func Test_Transaction_Nested_TX_Transaction_UseTX(t *testing.T) {
table := createTable()
defer dropTable(table)
db.SetDebug(true)
gtest.C(t, func(t *gtest.T) {
var (
err error
@ -897,7 +897,7 @@ func Test_Transaction_Nested_TX_Transaction_UseTX(t *testing.T) {
func Test_Transaction_Nested_TX_Transaction_UseDB(t *testing.T) {
table := createTable()
defer dropTable(table)
db.SetDebug(true)
gtest.C(t, func(t *gtest.T) {
var (
err error
@ -910,7 +910,7 @@ func Test_Transaction_Nested_TX_Transaction_UseDB(t *testing.T) {
err = db.Transaction(ctx, func(ctx context.Context, tx *gdb.TX) error {
err = db.Transaction(ctx, func(ctx context.Context, tx *gdb.TX) error {
err = db.Transaction(ctx, func(ctx context.Context, tx *gdb.TX) error {
_, err = db.Model(table).Data(g.Map{
_, err = db.Model(table).Ctx(ctx).Data(g.Map{
"id": 1,
"passport": "USER_1",
"password": "PASS_1",
@ -933,9 +933,10 @@ func Test_Transaction_Nested_TX_Transaction_UseDB(t *testing.T) {
return err
})
t.AssertNil(err)
// rollback
err = db.Transaction(ctx, func(ctx context.Context, tx *gdb.TX) error {
_, err = tx.Model(table).Data(g.Map{
_, err = tx.Model(table).Ctx(ctx).Data(g.Map{
"id": 2,
"passport": "USER_2",
"password": "PASS_2",
@ -943,6 +944,7 @@ func Test_Transaction_Nested_TX_Transaction_UseDB(t *testing.T) {
"create_time": gtime.Now().String(),
}).Insert()
t.AssertNil(err)
// panic makes this transaction rollback.
panic("error")
return err
})

View File

@ -9,14 +9,15 @@ package glog
import (
"errors"
"fmt"
"io"
"strings"
"time"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/internal/intlog"
"github.com/gogf/gf/os/gfile"
"github.com/gogf/gf/util/gconv"
"github.com/gogf/gf/util/gutil"
"io"
"strings"
"time"
)
// Config is the configuration object for logger.
@ -165,6 +166,24 @@ func (l *Logger) SetCtxKeys(keys ...interface{}) {
l.config.CtxKeys = keys
}
// AppendCtxKeys appends extra keys to logger.
// It ignores the key if it is already appended to the logger previously.
func (l *Logger) AppendCtxKeys(keys ...interface{}) {
var isExist bool
for _, key := range keys {
isExist = false
for _, ctxKey := range l.config.CtxKeys {
if ctxKey == key {
isExist = true
break
}
}
if !isExist {
l.config.CtxKeys = append(l.config.CtxKeys, key)
}
}
}
// GetCtxKeys retrieves and returns the context keys for logging.
func (l *Logger) GetCtxKeys() []interface{} {
return l.config.CtxKeys