diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 8ccb29a03..8d634359e 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -96,19 +96,6 @@ type DB interface { Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) // See Core.Update. Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) // See Core.Delete. - // =========================================================================== - // Internal APIs for CURD, which can be overwrote for custom CURD implements. - // =========================================================================== - - DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) // See Core.DoQuery. - DoGetAll(link Link, sql string, args ...interface{}) (result Result, err error) // See Core.DoGetAll. - DoExec(link Link, sql string, args ...interface{}) (result sql.Result, err error) // See Core.DoExec. - DoPrepare(link Link, sql string) (*Stmt, error) // See Core.DoPrepare. - DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) // See Core.DoInsert. - DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) // See Core.DoBatchInsert. - DoUpdate(link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) // See Core.DoUpdate. - DoDelete(link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) // See Core.DoDelete. - // =========================================================================== // Query APIs for convenience purpose. // =========================================================================== @@ -167,40 +154,25 @@ type DB interface { // Utility methods. // =========================================================================== - GetCtx() context.Context // See Core.GetCtx. - GetChars() (charLeft string, charRight string) // See Core.GetChars. - GetMaster(schema ...string) (*sql.DB, error) // See Core.GetMaster. - GetSlave(schema ...string) (*sql.DB, error) // See Core.GetSlave. - QuoteWord(s string) string // See Core.QuoteWord. - QuoteString(s string) string // See Core.QuoteString. - QuotePrefixTableName(table string) string // See Core.QuotePrefixTableName. - Tables(schema ...string) (tables []string, err error) // See Core.Tables. - TableFields(link Link, table string, schema ...string) (map[string]*TableField, error) // See Core.TableFields. - HasTable(name string) (bool, error) // See Core.HasTable. - FilteredLinkInfo() string // See Core.FilteredLinkInfo. + GetCtx() context.Context // See Core.GetCtx. + GetCore() *Core // See Core.GetCore + GetChars() (charLeft string, charRight string) // See Core.GetChars. + Tables(ctx context.Context, schema ...string) (tables []string, err error) // See Core.Tables. + TableFields(ctx context.Context, link Link, table string, schema ...string) (map[string]*TableField, error) // See Core.TableFields. + FilteredLinkInfo() string // See Core.FilteredLinkInfo. // HandleSqlBeforeCommit is a hook function, which deals with the sql string before // it's committed to underlying driver. The parameter `link` specifies the current // database connection operation object. You can modify the sql string `sql` and its // arguments `args` as you wish before they're committed to driver. // Also see Core.HandleSqlBeforeCommit. - HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) - - // =========================================================================== - // Internal methods, for internal usage purpose, you do not need consider it. - // =========================================================================== - - mappingAndFilterData(schema, table string, data map[string]interface{}, filter bool) (map[string]interface{}, error) // See Core.mappingAndFilterData. - convertFieldValueToLocalValue(fieldValue interface{}, fieldType string) interface{} // See Core.convertFieldValueToLocalValue. - convertRowsToResult(rows *sql.Rows) (Result, error) // See Core.convertRowsToResult. - addSqlToTracing(ctx context.Context, sql *Sql) // See Core.addSqlToTracing. - writeSqlToLogger(v *Sql) // See Core.writeSqlToLogger. + HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) } // Core is the base struct for database management. type Core struct { db DB // DB interface object. - ctx context.Context // Context for chaining operation only. + ctx context.Context // Context for chaining operation only. Do not set a default value in Core initialization. group string // Configuration group name. debug *gtype.Bool // Enable debug mode for the database, which can be changed in runtime. cache *gcache.Cache // Cache manager, SQL result cache only. diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 9dbb37ad3..14a14917c 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -15,9 +15,8 @@ import ( "strings" "github.com/gogf/gf/errors/gerror" - "github.com/gogf/gf/text/gstr" - "github.com/gogf/gf/internal/utils" + "github.com/gogf/gf/text/gstr" "github.com/gogf/gf/container/gvar" "github.com/gogf/gf/os/gtime" @@ -25,6 +24,10 @@ import ( "github.com/gogf/gf/util/gconv" ) +func (c *Core) GetCore() *Core { + return c +} + // Ctx is a chaining function, which creates and returns a new DB that is a shallow copy // of current DB object and with given context in it. // Note that this returned DB object can be used only once, so do not assign it to @@ -33,6 +36,11 @@ func (c *Core) Ctx(ctx context.Context) DB { if ctx == nil { return c.db } + // It is already set context in previous chaining operation. + if c.ctx != nil { + return c.db + } + // It makes a shallow copy of current db and changes its context for next chaining operation. var ( err error newCore = &Core{} @@ -41,9 +49,10 @@ func (c *Core) Ctx(ctx context.Context) DB { *newCore = *c newCore.ctx = ctx newCore.db, err = driverMap[configNode.Type].New(newCore, configNode) - // Seldom error, just log it. if err != nil { - c.db.GetLogger().Ctx(ctx).Error(err) + // It is really a serious error here. + // Do not let it continue. + panic(err) } return newCore.db } @@ -60,7 +69,7 @@ func (c *Core) GetCtx() context.Context { // GetCtxTimeout returns the context and cancel function for specified timeout type. func (c *Core) GetCtxTimeout(timeoutType int, ctx context.Context) (context.Context, context.CancelFunc) { if ctx == nil { - ctx = c.db.GetCtx() + ctx = c.GetCtx() } else { ctx = context.WithValue(ctx, "WrappedByGetCtxTimeout", nil) } @@ -102,15 +111,14 @@ func (c *Core) Query(sql string, args ...interface{}) (rows *sql.Rows, err error if err != nil { return nil, err } - return c.db.DoQuery(link, sql, args...) + return c.DoQuery(c.GetCtx(), link, sql, args...) } // DoQuery commits the sql string and its arguments to underlying driver // through given link object and returns the execution result. -func (c *Core) DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) { +func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) { sql, args = formatSql(sql, args) - sql, args = c.db.HandleSqlBeforeCommit(link, sql, args) - ctx := c.db.GetCtx() + sql, args = c.db.HandleSqlBeforeCommit(ctx, link, sql, args) if c.GetConfig().QueryTimeout > 0 { ctx, _ = context.WithTimeout(ctx, c.GetConfig().QueryTimeout) } @@ -127,9 +135,10 @@ func (c *Core) DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Ro End: mTime2, Group: c.db.GetGroup(), } + // Tracing and logging. c.addSqlToTracing(ctx, sqlObj) if c.db.GetDebug() { - c.writeSqlToLogger(sqlObj) + c.writeSqlToLogger(ctx, sqlObj) } if err == nil { return rows, nil @@ -146,15 +155,14 @@ func (c *Core) Exec(sql string, args ...interface{}) (result sql.Result, err err if err != nil { return nil, err } - return c.db.DoExec(link, sql, args...) + return c.DoExec(c.GetCtx(), link, sql, args...) } // DoExec commits the sql string and its arguments to underlying driver // through given link object and returns the execution result. -func (c *Core) DoExec(link Link, sql string, args ...interface{}) (result sql.Result, err error) { +func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interface{}) (result sql.Result, err error) { sql, args = formatSql(sql, args) - sql, args = c.db.HandleSqlBeforeCommit(link, sql, args) - ctx := c.db.GetCtx() + sql, args = c.db.HandleSqlBeforeCommit(ctx, link, sql, args) if c.GetConfig().ExecTimeout > 0 { var cancelFunc context.CancelFunc ctx, cancelFunc = context.WithTimeout(ctx, c.GetConfig().ExecTimeout) @@ -178,9 +186,10 @@ func (c *Core) DoExec(link Link, sql string, args ...interface{}) (result sql.Re End: mTime2, Group: c.db.GetGroup(), } + // Tracing and logging. c.addSqlToTracing(ctx, sqlObj) if c.db.GetDebug() { - c.writeSqlToLogger(sqlObj) + c.writeSqlToLogger(ctx, sqlObj) } return result, formatError(err, sql, args...) } @@ -207,12 +216,11 @@ func (c *Core) Prepare(sql string, execOnMaster ...bool) (*Stmt, error) { return nil, err } } - return c.db.DoPrepare(link, sql) + return c.DoPrepare(c.GetCtx(), link, sql) } -// doPrepare calls prepare function on given link object and returns the statement object. -func (c *Core) DoPrepare(link Link, sql string) (*Stmt, error) { - ctx := c.db.GetCtx() +// DoPrepare calls prepare function on given link object and returns the statement object. +func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, error) { if c.GetConfig().PrepareTimeout > 0 { // DO NOT USE cancel function in prepare statement. ctx, _ = context.WithTimeout(ctx, c.GetConfig().PrepareTimeout) @@ -232,9 +240,10 @@ func (c *Core) DoPrepare(link Link, sql string) (*Stmt, error) { Group: c.db.GetGroup(), } ) + // Tracing and logging. c.addSqlToTracing(ctx, sqlObj) if c.db.GetDebug() { - c.writeSqlToLogger(sqlObj) + c.writeSqlToLogger(ctx, sqlObj) } return &Stmt{ Stmt: stmt, @@ -245,23 +254,23 @@ func (c *Core) DoPrepare(link Link, sql string) (*Stmt, error) { // GetAll queries and returns data records from database. func (c *Core) GetAll(sql string, args ...interface{}) (Result, error) { - return c.db.DoGetAll(nil, sql, args...) + return c.DoGetAll(c.GetCtx(), nil, sql, args...) } // DoGetAll queries and returns data records from database. -func (c *Core) DoGetAll(link Link, sql string, args ...interface{}) (result Result, err error) { +func (c *Core) DoGetAll(ctx context.Context, link Link, sql string, args ...interface{}) (result Result, err error) { if link == nil { link, err = c.db.Slave() if err != nil { return nil, err } } - rows, err := c.db.DoQuery(link, sql, args...) + rows, err := c.DoQuery(ctx, link, sql, args...) if err != nil || rows == nil { return nil, err } defer rows.Close() - return c.db.convertRowsToResult(rows) + return c.convertRowsToResult(rows) } // GetOne queries and returns one record from database. @@ -279,7 +288,7 @@ func (c *Core) GetOne(sql string, args ...interface{}) (Record, error) { // GetArray queries and returns data values as slice from database. // Note that if there are multiple columns in the result, it returns just one column values randomly. func (c *Core) GetArray(sql string, args ...interface{}) ([]Value, error) { - all, err := c.db.DoGetAll(nil, sql, args...) + all, err := c.DoGetAll(c.GetCtx(), nil, sql, args...) if err != nil { return nil, err } @@ -374,91 +383,6 @@ func (c *Core) PingSlave() error { } } -// Begin starts and returns the transaction object. -// You should call Commit or Rollback functions of the transaction object -// if you no longer use the transaction. Commit or Rollback functions will also -// close the transaction automatically. -func (c *Core) Begin() (*TX, error) { - if master, err := c.db.Master(); err != nil { - return nil, err - } else { - //ctx := c.db.GetCtx() - //if c.GetConfig().TranTimeout > 0 { - // var cancelFunc context.CancelFunc - // ctx, cancelFunc = context.WithTimeout(ctx, c.GetConfig().TranTimeout) - // defer cancelFunc() - //} - var ( - sqlStr = "BEGIN" - mTime1 = gtime.TimestampMilli() - rawTx, err = master.Begin() - mTime2 = gtime.TimestampMilli() - sqlObj = &Sql{ - Sql: sqlStr, - Type: "DB.Begin", - Args: nil, - Format: sqlStr, - Error: err, - Start: mTime1, - End: mTime2, - Group: c.db.GetGroup(), - } - ) - c.db.addSqlToTracing(c.db.GetCtx(), sqlObj) - if c.db.GetDebug() { - c.db.writeSqlToLogger(sqlObj) - } - if err == nil { - return &TX{ - db: c.db, - tx: rawTx, - master: master, - }, nil - } - return nil, err - } -} - -// Transaction wraps the transaction logic using function `f`. -// It rollbacks the transaction and returns the error from function `f` if -// it returns non-nil error. It commits the transaction and returns nil if -// function `f` returns nil. -// -// Note that, you should not Commit or Rollback the transaction in function `f` -// as it is automatically handled by this function. -func (c *Core) Transaction(ctx context.Context, f func(ctx context.Context, tx *TX) error) (err error) { - var tx *TX - // Check transaction object from context. - tx = TXFromCtx(ctx) - if tx != nil { - return tx.Transaction(ctx, f) - } - tx, err = c.db.Begin() - if err != nil { - return err - } - // Inject transaction object into context. - ctx = WithTX(ctx, tx) - defer func() { - if err == nil { - if e := recover(); e != nil { - err = fmt.Errorf("%v", e) - } - } - if err != nil { - if e := tx.Rollback(); e != nil { - err = e - } - } else { - if e := tx.Commit(); e != nil { - err = e - } - } - }() - err = f(ctx, tx) - return -} - // Insert does "INSERT INTO ..." statement for the table. // If there's already one unique record of the data in the table, it returns error. // @@ -536,7 +460,7 @@ func (c *Core) Save(table string, data interface{}, batch ...int) (sql.Result, e return c.Model(table).Data(data).Save() } -// doInsert inserts or updates data for given table. +// DoInsert inserts or updates data for given table. // This function is usually used for custom interface definition, you do not need call it manually. // The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc. // Eg: @@ -548,8 +472,8 @@ func (c *Core) Save(table string, data interface{}, batch ...int) (sql.Result, e // 1: replace: if there's unique/primary key in the data, it deletes it from table and inserts a new one; // 2: save: if there's unique/primary key in the data, it updates it or else inserts a new one; // 3: ignore: if there's unique/primary key in the data, it ignores the inserting; -func (c *Core) DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) { - table = c.db.QuotePrefixTableName(table) +func (c *Core) DoInsert(ctx context.Context, link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) { + table = c.QuotePrefixTableName(table) var ( fields []string values []string @@ -564,10 +488,10 @@ func (c *Core) DoInsert(link Link, table string, data interface{}, option int, b } switch reflectKind { case reflect.Slice, reflect.Array: - return c.db.DoBatchInsert(link, table, data, option, batch...) + return c.DoBatchInsert(ctx, link, table, data, option, batch...) case reflect.Struct: if _, ok := data.(apiInterfaces); ok { - return c.db.DoBatchInsert(link, table, data, option, batch...) + return c.DoBatchInsert(ctx, link, table, data, option, batch...) } else { dataMap = ConvertDataForTableRecord(data) } @@ -616,15 +540,11 @@ func (c *Core) DoInsert(link Link, table string, data interface{}, option int, b return nil, err } } - return c.db.DoExec( - link, - fmt.Sprintf( - "%s INTO %s(%s) VALUES(%s) %s", - operation, table, strings.Join(fields, ","), - strings.Join(values, ","), updateStr, - ), - params..., - ) + return c.DoExec(ctx, link, fmt.Sprintf( + "%s INTO %s(%s) VALUES(%s) %s", + operation, table, strings.Join(fields, ","), + strings.Join(values, ","), updateStr, + ), params...) } // BatchInsert batch inserts data. @@ -665,8 +585,8 @@ func (c *Core) BatchSave(table string, list interface{}, batch ...int) (sql.Resu // DoBatchInsert batch inserts/replaces/saves data. // This function is usually used for custom interface definition, you do not need call it manually. -func (c *Core) DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) { - table = c.db.QuotePrefixTableName(table) +func (c *Core) DoBatchInsert(ctx context.Context, link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) { + table = c.QuotePrefixTableName(table) var ( keys []string // Field names. values []string // Value holder string array, like: (?,?,?) @@ -777,16 +697,12 @@ func (c *Core) DoBatchInsert(link Link, table string, list interface{}, option i } valueHolder = append(valueHolder, "("+gstr.Join(values, ",")+")") if len(valueHolder) == batchNum || (i == listMapLen-1 && len(valueHolder) > 0) { - r, err := c.db.DoExec( - link, - fmt.Sprintf( - "%s INTO %s(%s) VALUES%s %s", - operation, table, keysStr, - gstr.Join(valueHolder, ","), - updateStr, - ), - params..., - ) + r, err := c.DoExec(ctx, link, fmt.Sprintf( + "%s INTO %s(%s) VALUES%s %s", + operation, table, keysStr, + gstr.Join(valueHolder, ","), + updateStr, + ), params...) if err != nil { return r, err } @@ -821,10 +737,10 @@ func (c *Core) Update(table string, data interface{}, condition interface{}, arg return c.Model(table).Data(data).Where(condition, args...).Update() } -// doUpdate does "UPDATE ... " statement for the table. +// DoUpdate does "UPDATE ... " statement for the table. // This function is usually used for custom interface definition, you do not need call it manually. -func (c *Core) DoUpdate(link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) { - table = c.db.QuotePrefixTableName(table) +func (c *Core) DoUpdate(ctx context.Context, link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) { + table = c.QuotePrefixTableName(table) var ( rv = reflect.ValueOf(data) kind = rv.Kind() @@ -849,7 +765,7 @@ func (c *Core) DoUpdate(link Link, table string, data interface{}, condition str if value.Value != 0 { column := k if value.Field != "" { - column = c.db.QuoteWord(value.Field) + column = c.QuoteWord(value.Field) } fields = append(fields, fmt.Sprintf("%s=%s+?", column, column)) params = append(params, value.Value) @@ -858,16 +774,16 @@ func (c *Core) DoUpdate(link Link, table string, data interface{}, condition str if value.Value != 0 { column := k if value.Field != "" { - column = c.db.QuoteWord(value.Field) + column = c.QuoteWord(value.Field) } fields = append(fields, fmt.Sprintf("%s=%s+?", column, column)) params = append(params, value.Value) } default: if s, ok := v.(Raw); ok { - fields = append(fields, c.db.QuoteWord(k)+"="+gconv.String(s)) + fields = append(fields, c.QuoteWord(k)+"="+gconv.String(s)) } else { - fields = append(fields, c.db.QuoteWord(k)+"=?") + fields = append(fields, c.QuoteWord(k)+"=?") params = append(params, v) } } @@ -888,11 +804,7 @@ func (c *Core) DoUpdate(link Link, table string, data interface{}, condition str return nil, err } } - return c.db.DoExec( - link, - fmt.Sprintf("UPDATE %s SET %s%s", table, updates, condition), - args..., - ) + return c.DoExec(ctx, link, fmt.Sprintf("UPDATE %s SET %s%s", table, updates, condition), args...) } // Delete does "DELETE FROM ... " statement for the table. @@ -912,14 +824,14 @@ func (c *Core) Delete(table string, condition interface{}, args ...interface{}) // DoDelete does "DELETE FROM ... " statement for the table. // This function is usually used for custom interface definition, you do not need call it manually. -func (c *Core) DoDelete(link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) { +func (c *Core) DoDelete(ctx context.Context, link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) { if link == nil { if link, err = c.db.Master(); err != nil { return nil, err } } - table = c.db.QuotePrefixTableName(table) - return c.db.DoExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...) + table = c.QuotePrefixTableName(table) + return c.DoExec(ctx, link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...) } // convertRowsToResult converts underlying data record type sql.Rows to Result type. @@ -955,7 +867,7 @@ func (c *Core) convertRowsToResult(rows *sql.Rows) (Result, error) { if value == nil { row[columnNames[i]] = gvar.New(nil) } else { - row[columnNames[i]] = gvar.New(c.db.convertFieldValueToLocalValue(value, columnTypes[i])) + row[columnNames[i]] = gvar.New(c.convertFieldValueToLocalValue(value, columnTypes[i])) } } records = append(records, row) @@ -977,19 +889,23 @@ func (c *Core) MarshalJSON() ([]byte, error) { // writeSqlToLogger outputs the sql object to logger. // It is enabled only if configuration "debug" is true. -func (c *Core) writeSqlToLogger(v *Sql) { - s := fmt.Sprintf("[%3d ms] [%s] %s", v.End-v.Start, v.Group, v.Format) - if v.Error != nil { - s += "\nError: " + v.Error.Error() - c.logger.Ctx(c.db.GetCtx()).Error(s) +func (c *Core) writeSqlToLogger(ctx context.Context, sql *Sql) { + var transactionIdStr string + if v := ctx.Value(transactionIdForLoggerCtx); v != nil { + transactionIdStr = fmt.Sprintf(`[%d] `, v.(uint64)) + } + s := fmt.Sprintf("[%3d ms] [%s] %s%s", sql.End-sql.Start, sql.Group, transactionIdStr, sql.Format) + if sql.Error != nil { + s += "\nError: " + sql.Error.Error() + c.logger.Ctx(ctx).Error(s) } else { - c.logger.Ctx(c.db.GetCtx()).Debug(s) + c.logger.Ctx(ctx).Debug(s) } } // HasTable determine whether the table name exists in the database. func (c *Core) HasTable(name string) (bool, error) { - tableList, err := c.db.Tables() + tableList, err := c.db.Tables(c.GetCtx()) if err != nil { return false, err } diff --git a/database/gdb/gdb_core_config.go b/database/gdb/gdb_core_config.go index 1f70a0164..8ce85dc8e 100644 --- a/database/gdb/gdb_core_config.go +++ b/database/gdb/gdb_core_config.go @@ -8,15 +8,12 @@ package gdb import ( "fmt" - "github.com/gogf/gf/os/gcache" "sync" "time" - "github.com/gogf/gf/os/glog" -) + "github.com/gogf/gf/os/gcache" -const ( - DefaultGroupName = "default" // Default group name. + "github.com/gogf/gf/os/glog" ) // Config is the configuration management object. @@ -53,6 +50,10 @@ type ConfigNode struct { TimeMaintainDisabled bool `json:"timeMaintainDisabled"` // (Optional) Disable the automatic time maintaining feature. } +const ( + DefaultGroupName = "default" // Default group name. +) + // configs is internal used configuration object. var configs struct { sync.RWMutex diff --git a/database/gdb/gdb_core_structure.go b/database/gdb/gdb_core_structure.go index 8e28f239b..02b9d11ac 100644 --- a/database/gdb/gdb_core_structure.go +++ b/database/gdb/gdb_core_structure.go @@ -7,10 +7,11 @@ package gdb import ( - "github.com/gogf/gf/util/gutil" "strings" "time" + "github.com/gogf/gf/util/gutil" + "github.com/gogf/gf/text/gstr" "github.com/gogf/gf/os/gtime" @@ -149,7 +150,7 @@ func (c *Core) convertFieldValueToLocalValue(fieldValue interface{}, fieldType s // mappingAndFilterData automatically mappings the map key to table field and removes // all key-value pairs that are not the field of given table. func (c *Core) mappingAndFilterData(schema, table string, data map[string]interface{}, filter bool) (map[string]interface{}, error) { - if fieldsMap, err := c.db.TableFields(nil, table, schema); err == nil { + if fieldsMap, err := c.db.TableFields(c.GetCtx(), nil, table, schema); err == nil { fieldsKeyMap := make(map[string]interface{}, len(fieldsMap)) for k, _ := range fieldsMap { fieldsKeyMap[k] = nil diff --git a/database/gdb/gdb_driver_mssql.go b/database/gdb/gdb_driver_mssql.go index f288187c9..fde84e762 100644 --- a/database/gdb/gdb_driver_mssql.go +++ b/database/gdb/gdb_driver_mssql.go @@ -12,12 +12,14 @@ package gdb import ( + "context" "database/sql" "fmt" - "github.com/gogf/gf/errors/gerror" "strconv" "strings" + "github.com/gogf/gf/errors/gerror" + "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/text/gstr" @@ -77,7 +79,7 @@ func (d *DriverMssql) GetChars() (charLeft string, charRight string) { } // HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver. -func (d *DriverMssql) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) { +func (d *DriverMssql) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) { var index int // Convert place holder char '?' to string "@px". str, _ := gregex.ReplaceStringFunc("\\?", sql, func(s string) string { @@ -183,14 +185,14 @@ func (d *DriverMssql) parseSql(sql string) string { // Tables retrieves and returns the tables of current schema. // It's mainly used in cli tool chain for automatically generating the models. -func (d *DriverMssql) Tables(schema ...string) (tables []string, err error) { +func (d *DriverMssql) Tables(ctx context.Context, schema ...string) (tables []string, err error) { var result Result - link, err := d.db.GetSlave(schema...) + link, err := d.GetSlave(schema...) if err != nil { return nil, err } - result, err = d.db.DoGetAll(link, `SELECT NAME FROM SYSOBJECTS WHERE XTYPE='U' AND STATUS >= 0 ORDER BY NAME`) + result, err = d.DoGetAll(ctx, link, `SELECT NAME FROM SYSOBJECTS WHERE XTYPE='U' AND STATUS >= 0 ORDER BY NAME`) if err != nil { return } @@ -205,7 +207,7 @@ func (d *DriverMssql) Tables(schema ...string) (tables []string, err error) { // TableFields retrieves and returns the fields information of specified table of current schema. // // Also see DriverMysql.TableFields. -func (d *DriverMssql) TableFields(link Link, table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverMssql) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) { charL, charR := d.GetChars() table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { @@ -224,7 +226,7 @@ func (d *DriverMssql) TableFields(link Link, table string, schema ...string) (fi result Result ) if link == nil { - link, err = d.db.GetSlave(checkSchema) + link, err = d.GetSlave(checkSchema) if err != nil { return nil } @@ -260,7 +262,7 @@ ORDER BY a.id,a.colorder`, strings.ToUpper(table), ) structureSql, _ = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(structureSql)) - result, err = d.db.DoGetAll(link, structureSql) + result, err = d.DoGetAll(ctx, link, structureSql) if err != nil { return nil } diff --git a/database/gdb/gdb_driver_mysql.go b/database/gdb/gdb_driver_mysql.go index 9f6f6bf45..ee87ecd94 100644 --- a/database/gdb/gdb_driver_mysql.go +++ b/database/gdb/gdb_driver_mysql.go @@ -7,8 +7,10 @@ package gdb import ( + "context" "database/sql" "fmt" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/text/gregex" @@ -75,19 +77,19 @@ func (d *DriverMysql) GetChars() (charLeft string, charRight string) { } // HandleSqlBeforeCommit handles the sql before posts it to database. -func (d *DriverMysql) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) { +func (d *DriverMysql) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) { return sql, args } // Tables retrieves and returns the tables of current schema. // It's mainly used in cli tool chain for automatically generating the models. -func (d *DriverMysql) Tables(schema ...string) (tables []string, err error) { +func (d *DriverMysql) Tables(ctx context.Context, schema ...string) (tables []string, err error) { var result Result - link, err := d.db.GetSlave(schema...) + link, err := d.GetSlave(schema...) if err != nil { return nil, err } - result, err = d.db.DoGetAll(link, `SHOW TABLES`) + result, err = d.DoGetAll(ctx, link, `SHOW TABLES`) if err != nil { return } @@ -111,7 +113,7 @@ func (d *DriverMysql) Tables(schema ...string) (tables []string, err error) { // // It's using cache feature to enhance the performance, which is never expired util the // process restarts. -func (d *DriverMysql) TableFields(link Link, table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverMysql) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) { charL, charR := d.GetChars() table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { @@ -130,14 +132,13 @@ func (d *DriverMysql) TableFields(link Link, table string, schema ...string) (fi result Result ) if link == nil { - link, err = d.db.GetSlave(checkSchema) + link, err = d.GetSlave(checkSchema) if err != nil { return nil } } - result, err = d.db.DoGetAll( - link, - fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.db.QuoteWord(table)), + result, err = d.DoGetAll(ctx, link, + fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.QuoteWord(table)), ) if err != nil { return nil diff --git a/database/gdb/gdb_driver_oracle.go b/database/gdb/gdb_driver_oracle.go index f8011b91a..b067de9b4 100644 --- a/database/gdb/gdb_driver_oracle.go +++ b/database/gdb/gdb_driver_oracle.go @@ -12,17 +12,19 @@ package gdb import ( + "context" "database/sql" "fmt" + "reflect" + "strconv" + "strings" + "time" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/text/gregex" "github.com/gogf/gf/text/gstr" "github.com/gogf/gf/util/gconv" - "reflect" - "strconv" - "strings" - "time" ) // DriverOracle is the driver for oracle database. @@ -83,7 +85,7 @@ func (d *DriverOracle) GetChars() (charLeft string, charRight string) { } // HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver. -func (d *DriverOracle) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}) { +func (d *DriverOracle) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}) { var index int // Convert place holder char '?' to string ":vx". newSql, _ = gregex.ReplaceStringFunc("\\?", sql, func(s string) string { @@ -164,9 +166,9 @@ func (d *DriverOracle) parseSql(sql string) string { // Tables retrieves and returns the tables of current schema. // It's mainly used in cli tool chain for automatically generating the models. // Note that it ignores the parameter `schema` in oracle database, as it is not necessary. -func (d *DriverOracle) Tables(schema ...string) (tables []string, err error) { +func (d *DriverOracle) Tables(ctx context.Context, schema ...string) (tables []string, err error) { var result Result - result, err = d.db.DoGetAll(nil, "SELECT TABLE_NAME FROM USER_TABLES ORDER BY TABLE_NAME") + result, err = d.DoGetAll(ctx, nil, "SELECT TABLE_NAME FROM USER_TABLES ORDER BY TABLE_NAME") if err != nil { return } @@ -181,7 +183,7 @@ func (d *DriverOracle) Tables(schema ...string) (tables []string, err error) { // TableFields retrieves and returns the fields information of specified table of current schema. // // Also see DriverMysql.TableFields. -func (d *DriverOracle) TableFields(link Link, table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverOracle) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) { charL, charR := d.GetChars() table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { @@ -211,12 +213,12 @@ FROM USER_TAB_COLUMNS WHERE TABLE_NAME = '%s' ORDER BY COLUMN_ID`, ) structureSql, _ = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(structureSql)) if link == nil { - link, err = d.db.GetSlave(checkSchema) + link, err = d.GetSlave(checkSchema) if err != nil { return nil } } - result, err = d.db.DoGetAll(link, structureSql) + result, err = d.DoGetAll(ctx, link, structureSql) if err != nil { return nil } @@ -264,7 +266,7 @@ func (d *DriverOracle) getTableUniqueIndex(table string) (fields map[string]map[ return } -func (d *DriverOracle) DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) { +func (d *DriverOracle) DoInsert(ctx context.Context, link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) { var ( fields []string values []string @@ -279,7 +281,7 @@ func (d *DriverOracle) DoInsert(link Link, table string, data interface{}, optio } switch kind { case reflect.Slice, reflect.Array: - return d.db.DoBatchInsert(link, table, data, option, batch...) + return d.DoBatchInsert(ctx, link, table, data, option, batch...) case reflect.Map: fallthrough case reflect.Struct: @@ -353,20 +355,17 @@ func (d *DriverOracle) DoInsert(link Link, table string, data interface{}, optio table, tableAlias1, strings.Join(subSqlStr, ","), tableAlias2, strings.Join(onStr, "AND"), strings.Join(updateStr, ","), strings.Join(fields, ","), strings.Join(values, ","), ) - return d.db.DoExec(link, tmp, params...) + return d.DoExec(ctx, link, tmp, params...) case insertOptionIgnore: - return d.db.DoExec(link, - fmt.Sprintf( - "INSERT /*+ IGNORE_ROW_ON_DUPKEY_INDEX(%s(%s)) */ INTO %s(%s) VALUES(%s)", - table, strings.Join(indexes, ","), table, strings.Join(fields, ","), strings.Join(values, ","), - ), - params...) + return d.DoExec(ctx, link, fmt.Sprintf( + "INSERT /*+ IGNORE_ROW_ON_DUPKEY_INDEX(%s(%s)) */ INTO %s(%s) VALUES(%s)", + table, strings.Join(indexes, ","), table, strings.Join(fields, ","), strings.Join(values, ","), + ), params...) } } - return d.db.DoExec( - link, + return d.DoExec(ctx, link, fmt.Sprintf( "INSERT INTO %s(%s) VALUES(%s)", table, strings.Join(fields, ","), strings.Join(values, ","), @@ -374,7 +373,7 @@ func (d *DriverOracle) DoInsert(link Link, table string, data interface{}, optio params...) } -func (d *DriverOracle) DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) { +func (d *DriverOracle) DoBatchInsert(ctx context.Context, link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) { var ( keys []string values []string @@ -435,7 +434,7 @@ func (d *DriverOracle) DoBatchInsert(link Link, table string, list interface{}, ) if option != insertOptionDefault { for _, v := range listMap { - r, err := d.db.DoInsert(link, table, v, option, 1) + r, err := d.DoInsert(ctx, link, table, v, option, 1) if err != nil { return r, err } @@ -463,7 +462,7 @@ func (d *DriverOracle) DoBatchInsert(link Link, table string, list interface{}, values = append(values, valueHolderStr) intoStr = append(intoStr, fmt.Sprintf(" INTO %s(%s) VALUES(%s) ", table, keyStr, valueHolderStr)) if len(intoStr) == batchNum { - r, err := d.db.DoExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...) + r, err := d.DoExec(ctx, link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...) if err != nil { return r, err } @@ -479,7 +478,7 @@ func (d *DriverOracle) DoBatchInsert(link Link, table string, list interface{}, } // The leftover data. if len(intoStr) > 0 { - r, err := d.db.DoExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...) + r, err := d.DoExec(ctx, link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...) if err != nil { return r, err } diff --git a/database/gdb/gdb_driver_pgsql.go b/database/gdb/gdb_driver_pgsql.go index 95cf65087..2d0bd76f6 100644 --- a/database/gdb/gdb_driver_pgsql.go +++ b/database/gdb/gdb_driver_pgsql.go @@ -12,12 +12,14 @@ package gdb import ( + "context" "database/sql" "fmt" + "strings" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/text/gstr" - "strings" "github.com/gogf/gf/text/gregex" ) @@ -75,7 +77,7 @@ func (d *DriverPgsql) GetChars() (charLeft string, charRight string) { } // HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver. -func (d *DriverPgsql) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) { +func (d *DriverPgsql) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) { var index int // Convert place holder char '?' to string "$x". sql, _ = gregex.ReplaceStringFunc("\\?", sql, func(s string) string { @@ -88,9 +90,9 @@ func (d *DriverPgsql) HandleSqlBeforeCommit(link Link, sql string, args []interf // Tables retrieves and returns the tables of current schema. // It's mainly used in cli tool chain for automatically generating the models. -func (d *DriverPgsql) Tables(schema ...string) (tables []string, err error) { +func (d *DriverPgsql) Tables(ctx context.Context, schema ...string) (tables []string, err error) { var result Result - link, err := d.db.GetSlave(schema...) + link, err := d.GetSlave(schema...) if err != nil { return nil, err } @@ -98,7 +100,7 @@ func (d *DriverPgsql) Tables(schema ...string) (tables []string, err error) { if len(schema) > 0 && schema[0] != "" { query = fmt.Sprintf("SELECT TABLENAME FROM PG_TABLES WHERE SCHEMANAME = '%s' ORDER BY TABLENAME", schema[0]) } - result, err = d.db.DoGetAll(link, query) + result, err = d.DoGetAll(ctx, link, query) if err != nil { return } @@ -113,7 +115,7 @@ func (d *DriverPgsql) Tables(schema ...string) (tables []string, err error) { // TableFields retrieves and returns the fields information of specified table of current schema. // // Also see DriverMysql.TableFields. -func (d *DriverPgsql) TableFields(link Link, table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverPgsql) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) { charL, charR := d.GetChars() table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { @@ -141,12 +143,12 @@ ORDER BY a.attnum`, ) structureSql, _ = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(structureSql)) if link == nil { - link, err = d.db.GetSlave(checkSchema) + link, err = d.GetSlave(checkSchema) if err != nil { return nil } } - result, err = d.db.DoGetAll(link, structureSql) + result, err = d.DoGetAll(ctx, link, structureSql) if err != nil { return nil } diff --git a/database/gdb/gdb_driver_sqlite.go b/database/gdb/gdb_driver_sqlite.go index ba000fbbb..c78a3eb27 100644 --- a/database/gdb/gdb_driver_sqlite.go +++ b/database/gdb/gdb_driver_sqlite.go @@ -11,13 +11,15 @@ package gdb import ( + "context" "database/sql" "fmt" + "strings" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/os/gfile" "github.com/gogf/gf/text/gstr" - "strings" ) // DriverSqlite is the driver for sqlite database. @@ -67,20 +69,20 @@ func (d *DriverSqlite) GetChars() (charLeft string, charRight string) { // HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver. // TODO 需要增加对Save方法的支持,可使用正则来实现替换, // TODO 将ON DUPLICATE KEY UPDATE触发器修改为两条SQL语句(INSERT OR IGNORE & UPDATE) -func (d *DriverSqlite) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) { +func (d *DriverSqlite) HandleSqlBeforeCommit(ctx context.Context, link Link, sql string, args []interface{}) (string, []interface{}) { return sql, args } // Tables retrieves and returns the tables of current schema. // It's mainly used in cli tool chain for automatically generating the models. -func (d *DriverSqlite) Tables(schema ...string) (tables []string, err error) { +func (d *DriverSqlite) Tables(ctx context.Context, schema ...string) (tables []string, err error) { var result Result - link, err := d.db.GetSlave(schema...) + link, err := d.GetSlave(schema...) if err != nil { return nil, err } - result, err = d.db.DoGetAll(link, `SELECT NAME FROM SQLITE_MASTER WHERE TYPE='table' ORDER BY NAME`) + result, err = d.DoGetAll(ctx, link, `SELECT NAME FROM SQLITE_MASTER WHERE TYPE='table' ORDER BY NAME`) if err != nil { return } @@ -95,7 +97,7 @@ func (d *DriverSqlite) Tables(schema ...string) (tables []string, err error) { // TableFields retrieves and returns the fields information of specified table of current schema. // // Also see DriverMysql.TableFields. -func (d *DriverSqlite) TableFields(link Link, table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverSqlite) TableFields(ctx context.Context, link Link, table string, schema ...string) (fields map[string]*TableField, err error) { charL, charR := d.GetChars() table = gstr.Trim(table, charL+charR) if gstr.Contains(table, " ") { @@ -114,12 +116,12 @@ func (d *DriverSqlite) TableFields(link Link, table string, schema ...string) (f result Result ) if link == nil { - link, err = d.db.GetSlave(checkSchema) + link, err = d.GetSlave(checkSchema) if err != nil { return nil } } - result, err = d.db.DoGetAll(link, fmt.Sprintf(`PRAGMA TABLE_INFO(%s)`, table)) + result, err = d.DoGetAll(ctx, link, fmt.Sprintf(`PRAGMA TABLE_INFO(%s)`, table)) if err != nil { return nil } diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index ec2b7d7dc..13db19d09 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -9,6 +9,11 @@ package gdb import ( "bytes" "fmt" + "reflect" + "regexp" + "strings" + "time" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/empty" "github.com/gogf/gf/internal/json" @@ -16,10 +21,6 @@ import ( "github.com/gogf/gf/os/gtime" "github.com/gogf/gf/util/gmeta" "github.com/gogf/gf/util/gutil" - "reflect" - "regexp" - "strings" - "time" "github.com/gogf/gf/internal/structs" @@ -504,7 +505,7 @@ func formatWhere(db DB, where interface{}, args []interface{}, omitEmpty bool) ( // Eg: Where/And/Or("uid>=", 1) newWhere += "?" } else if gregex.IsMatchString(regularFieldNameRegPattern, newWhere) { - newWhere = db.QuoteString(newWhere) + newWhere = db.GetCore().QuoteString(newWhere) if len(newArgs) > 0 { if utils.IsArray(newArgs[0]) { // Eg: @@ -543,9 +544,9 @@ func formatWhereInterfaces(db DB, where []interface{}, buffer *bytes.Buffer, new for i := 0; i < len(where); i += 2 { str = gconv.String(where[i]) if buffer.Len() > 0 { - buffer.WriteString(" AND " + db.QuoteWord(str) + "=?") + buffer.WriteString(" AND " + db.GetCore().QuoteWord(str) + "=?") } else { - buffer.WriteString(db.QuoteWord(str) + "=?") + buffer.WriteString(db.GetCore().QuoteWord(str) + "=?") } if s, ok := where[i+1].(Raw); ok { buffer.WriteString(gconv.String(s)) @@ -558,7 +559,7 @@ func formatWhereInterfaces(db DB, where []interface{}, buffer *bytes.Buffer, new // formatWhereKeyValue handles each key-value pair of the parameter map. func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key string, value interface{}) []interface{} { - quotedKey := db.QuoteWord(key) + quotedKey := db.GetCore().QuoteWord(key) if buffer.Len() > 0 { buffer.WriteString(" AND ") } diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index ba834a457..45084a106 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -9,9 +9,10 @@ package gdb import ( "context" "fmt" - "github.com/gogf/gf/text/gregex" "time" + "github.com/gogf/gf/text/gregex" + "github.com/gogf/gf/text/gstr" ) @@ -98,10 +99,10 @@ func (c *Core) Model(tableNameOrStruct ...interface{}) *Model { if len(tableNames) > 1 { tableStr = fmt.Sprintf( - `%s AS %s`, c.db.QuotePrefixTableName(tableNames[0]), c.db.QuoteWord(tableNames[1]), + `%s AS %s`, c.QuotePrefixTableName(tableNames[0]), c.QuoteWord(tableNames[1]), ) } else if len(tableNames) == 1 { - tableStr = c.db.QuotePrefixTableName(tableNames[0]) + tableStr = c.QuotePrefixTableName(tableNames[0]) } return &Model{ db: c.db, @@ -148,9 +149,21 @@ func (m *Model) Ctx(ctx context.Context) *Model { } model := m.getModel() model.db = model.db.Ctx(ctx) + if m.tx != nil { + model.tx = model.tx.Ctx(ctx) + } return model } +// GetCtx returns the context for current Model. +// It returns `context.Background()` is there's no context previously set. +func (m *Model) GetCtx() context.Context { + if m.tx != nil && m.tx.ctx != nil { + return m.tx.ctx + } + return m.db.GetCtx() +} + // As sets an alias name for current table. func (m *Model) As(as string) *Model { if m.tables != "" { diff --git a/database/gdb/gdb_model_cache.go b/database/gdb/gdb_model_cache.go index 0529bf5f6..b9ba72873 100644 --- a/database/gdb/gdb_model_cache.go +++ b/database/gdb/gdb_model_cache.go @@ -38,6 +38,6 @@ func (m *Model) Cache(duration time.Duration, name ...string) *Model { // cache feature is enabled. func (m *Model) checkAndRemoveCache() { if m.cacheEnabled && m.cacheDuration < 0 && len(m.cacheName) > 0 { - m.db.GetCache().Ctx(m.db.GetCtx()).Remove(m.cacheName) + m.db.GetCache().Ctx(m.GetCtx()).Remove(m.cacheName) } } diff --git a/database/gdb/gdb_model_condition.go b/database/gdb/gdb_model_condition.go index 775858ced..641fee5d7 100644 --- a/database/gdb/gdb_model_condition.go +++ b/database/gdb/gdb_model_condition.go @@ -61,53 +61,53 @@ func (m *Model) WherePri(where interface{}, args ...interface{}) *Model { // WhereBetween builds `xxx BETWEEN x AND y` statement. func (m *Model) WhereBetween(column string, min, max interface{}) *Model { - return m.Where(fmt.Sprintf(`%s BETWEEN ? AND ?`, m.db.QuoteWord(column)), min, max) + return m.Where(fmt.Sprintf(`%s BETWEEN ? AND ?`, m.db.GetCore().QuoteWord(column)), min, max) } // WhereLike builds `xxx LIKE x` statement. func (m *Model) WhereLike(column string, like interface{}) *Model { - return m.Where(fmt.Sprintf(`%s LIKE ?`, m.db.QuoteWord(column)), like) + return m.Where(fmt.Sprintf(`%s LIKE ?`, m.db.GetCore().QuoteWord(column)), like) } // WhereIn builds `xxx IN (x)` statement. func (m *Model) WhereIn(column string, in interface{}) *Model { - return m.Where(fmt.Sprintf(`%s IN (?)`, m.db.QuoteWord(column)), in) + return m.Where(fmt.Sprintf(`%s IN (?)`, m.db.GetCore().QuoteWord(column)), in) } // WhereNull builds `xxx IS NULL` statement. func (m *Model) WhereNull(columns ...string) *Model { model := m for _, column := range columns { - model = m.Where(fmt.Sprintf(`%s IS NULL`, m.db.QuoteWord(column))) + model = m.Where(fmt.Sprintf(`%s IS NULL`, m.db.GetCore().QuoteWord(column))) } return model } // WhereNotBetween builds `xxx NOT BETWEEN x AND y` statement. func (m *Model) WhereNotBetween(column string, min, max interface{}) *Model { - return m.Where(fmt.Sprintf(`%s NOT BETWEEN ? AND ?`, m.db.QuoteWord(column)), min, max) + return m.Where(fmt.Sprintf(`%s NOT BETWEEN ? AND ?`, m.db.GetCore().QuoteWord(column)), min, max) } // WhereNotLike builds `xxx NOT LIKE x` statement. func (m *Model) WhereNotLike(column string, like interface{}) *Model { - return m.Where(fmt.Sprintf(`%s NOT LIKE ?`, m.db.QuoteWord(column)), like) + return m.Where(fmt.Sprintf(`%s NOT LIKE ?`, m.db.GetCore().QuoteWord(column)), like) } // WhereNot builds `xxx != x` statement. func (m *Model) WhereNot(column string, value interface{}) *Model { - return m.Where(fmt.Sprintf(`%s != ?`, m.db.QuoteWord(column)), value) + return m.Where(fmt.Sprintf(`%s != ?`, m.db.GetCore().QuoteWord(column)), value) } // WhereNotIn builds `xxx NOT IN (x)` statement. func (m *Model) WhereNotIn(column string, in interface{}) *Model { - return m.Where(fmt.Sprintf(`%s NOT IN (?)`, m.db.QuoteWord(column)), in) + return m.Where(fmt.Sprintf(`%s NOT IN (?)`, m.db.GetCore().QuoteWord(column)), in) } // WhereNotNull builds `xxx IS NOT NULL` statement. func (m *Model) WhereNotNull(columns ...string) *Model { model := m for _, column := range columns { - model = m.Where(fmt.Sprintf(`%s IS NOT NULL`, m.db.QuoteWord(column))) + model = m.Where(fmt.Sprintf(`%s IS NOT NULL`, m.db.GetCore().QuoteWord(column))) } return model } @@ -128,48 +128,48 @@ func (m *Model) WhereOr(where interface{}, args ...interface{}) *Model { // WhereOrBetween builds `xxx BETWEEN x AND y` statement in `OR` conditions. func (m *Model) WhereOrBetween(column string, min, max interface{}) *Model { - return m.WhereOr(fmt.Sprintf(`%s BETWEEN ? AND ?`, m.db.QuoteWord(column)), min, max) + return m.WhereOr(fmt.Sprintf(`%s BETWEEN ? AND ?`, m.db.GetCore().QuoteWord(column)), min, max) } // WhereOrLike builds `xxx LIKE x` statement in `OR` conditions. func (m *Model) WhereOrLike(column string, like interface{}) *Model { - return m.WhereOr(fmt.Sprintf(`%s LIKE ?`, m.db.QuoteWord(column)), like) + return m.WhereOr(fmt.Sprintf(`%s LIKE ?`, m.db.GetCore().QuoteWord(column)), like) } // WhereOrIn builds `xxx IN (x)` statement in `OR` conditions. func (m *Model) WhereOrIn(column string, in interface{}) *Model { - return m.WhereOr(fmt.Sprintf(`%s IN (?)`, m.db.QuoteWord(column)), in) + return m.WhereOr(fmt.Sprintf(`%s IN (?)`, m.db.GetCore().QuoteWord(column)), in) } // WhereOrNull builds `xxx IS NULL` statement in `OR` conditions. func (m *Model) WhereOrNull(columns ...string) *Model { model := m for _, column := range columns { - model = m.WhereOr(fmt.Sprintf(`%s IS NULL`, m.db.QuoteWord(column))) + model = m.WhereOr(fmt.Sprintf(`%s IS NULL`, m.db.GetCore().QuoteWord(column))) } return model } // WhereOrNotBetween builds `xxx NOT BETWEEN x AND y` statement in `OR` conditions. func (m *Model) WhereOrNotBetween(column string, min, max interface{}) *Model { - return m.WhereOr(fmt.Sprintf(`%s NOT BETWEEN ? AND ?`, m.db.QuoteWord(column)), min, max) + return m.WhereOr(fmt.Sprintf(`%s NOT BETWEEN ? AND ?`, m.db.GetCore().QuoteWord(column)), min, max) } // WhereOrNotLike builds `xxx NOT LIKE x` statement in `OR` conditions. func (m *Model) WhereOrNotLike(column string, like interface{}) *Model { - return m.WhereOr(fmt.Sprintf(`%s NOT LIKE ?`, m.db.QuoteWord(column)), like) + return m.WhereOr(fmt.Sprintf(`%s NOT LIKE ?`, m.db.GetCore().QuoteWord(column)), like) } // WhereOrNotIn builds `xxx NOT IN (x)` statement. func (m *Model) WhereOrNotIn(column string, in interface{}) *Model { - return m.WhereOr(fmt.Sprintf(`%s NOT IN (?)`, m.db.QuoteWord(column)), in) + return m.WhereOr(fmt.Sprintf(`%s NOT IN (?)`, m.db.GetCore().QuoteWord(column)), in) } // WhereOrNotNull builds `xxx IS NOT NULL` statement in `OR` conditions. func (m *Model) WhereOrNotNull(columns ...string) *Model { model := m for _, column := range columns { - model = m.WhereOr(fmt.Sprintf(`%s IS NOT NULL`, m.db.QuoteWord(column))) + model = m.WhereOr(fmt.Sprintf(`%s IS NOT NULL`, m.db.GetCore().QuoteWord(column))) } return model } @@ -177,7 +177,7 @@ func (m *Model) WhereOrNotNull(columns ...string) *Model { // Group sets the "GROUP BY" statement for the model. func (m *Model) Group(groupBy string) *Model { model := m.getModel() - model.groupBy = m.db.QuoteString(groupBy) + model.groupBy = m.db.GetCore().QuoteString(groupBy) return model } @@ -215,7 +215,7 @@ func (m *Model) Order(orderBy ...string) *Model { return m } model := m.getModel() - model.orderBy = m.db.QuoteString(strings.Join(orderBy, " ")) + model.orderBy = m.db.GetCore().QuoteString(strings.Join(orderBy, " ")) return model } @@ -225,7 +225,7 @@ func (m *Model) OrderAsc(column string) *Model { return m } model := m.getModel() - model.orderBy = m.db.QuoteWord(column) + " ASC" + model.orderBy = m.db.GetCore().QuoteWord(column) + " ASC" return model } @@ -235,7 +235,7 @@ func (m *Model) OrderDesc(column string) *Model { return m } model := m.getModel() - model.orderBy = m.db.QuoteWord(column) + " DESC" + model.orderBy = m.db.GetCore().QuoteWord(column) + " DESC" return model } diff --git a/database/gdb/gdb_model_delete.go b/database/gdb/gdb_model_delete.go index 4c56cdcd4..2ce680444 100644 --- a/database/gdb/gdb_model_delete.go +++ b/database/gdb/gdb_model_delete.go @@ -9,6 +9,7 @@ package gdb import ( "database/sql" "fmt" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/os/gtime" "github.com/gogf/gf/text/gstr" @@ -32,10 +33,11 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) { ) // Soft deleting. if !m.unscoped && fieldNameDelete != "" { - return m.db.DoUpdate( + return m.db.GetCore().DoUpdate( + m.GetCtx(), m.getLink(true), m.tables, - fmt.Sprintf(`%s=?`, m.db.QuoteString(fieldNameDelete)), + fmt.Sprintf(`%s=?`, m.db.GetCore().QuoteString(fieldNameDelete)), conditionWhere+conditionExtra, append([]interface{}{gtime.Now().String()}, conditionArgs...), ) @@ -44,5 +46,5 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) { if !gstr.ContainsI(conditionStr, " WHERE ") { return nil, gerror.New("there should be WHERE condition statement for DELETE operation") } - return m.db.DoDelete(m.getLink(true), m.tables, conditionStr, conditionArgs...) + return m.db.GetCore().DoDelete(m.GetCtx(), m.getLink(true), m.tables, conditionStr, conditionArgs...) } diff --git a/database/gdb/gdb_model_fields.go b/database/gdb/gdb_model_fields.go index be082b070..46d7af223 100644 --- a/database/gdb/gdb_model_fields.go +++ b/database/gdb/gdb_model_fields.go @@ -114,7 +114,7 @@ func (m *Model) GetFieldsStr(prefix ...string) string { } newFields += prefixStr + k } - newFields = m.db.QuoteString(newFields) + newFields = m.db.GetCore().QuoteString(newFields) return newFields } @@ -158,7 +158,7 @@ func (m *Model) GetFieldsExStr(fields string, prefix ...string) string { } newFields += prefixStr + k } - newFields = m.db.QuoteString(newFields) + newFields = m.db.GetCore().QuoteString(newFields) return newFields } diff --git a/database/gdb/gdb_model_insert.go b/database/gdb/gdb_model_insert.go index d67978b72..3062354fa 100644 --- a/database/gdb/gdb_model_insert.go +++ b/database/gdb/gdb_model_insert.go @@ -8,12 +8,13 @@ package gdb import ( "database/sql" + "reflect" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/os/gtime" "github.com/gogf/gf/text/gstr" "github.com/gogf/gf/util/gconv" "github.com/gogf/gf/util/gutil" - "reflect" ) // Batch sets the batch operation number for the model. @@ -194,7 +195,8 @@ func (m *Model) doInsertWithOption(option int) (result sql.Result, err error) { list[k] = v } } - return m.db.DoBatchInsert( + return m.db.GetCore().DoBatchInsert( + m.GetCtx(), m.getLink(true), m.tables, newData, @@ -219,7 +221,8 @@ func (m *Model) doInsertWithOption(option int) (result sql.Result, err error) { data[fieldNameUpdate] = nowString } } - return m.db.DoInsert( + return m.db.GetCore().DoInsert( + m.GetCtx(), m.getLink(true), m.tables, newData, diff --git a/database/gdb/gdb_model_join.go b/database/gdb/gdb_model_join.go index 847c11abd..10e4c9255 100644 --- a/database/gdb/gdb_model_join.go +++ b/database/gdb/gdb_model_join.go @@ -8,6 +8,7 @@ package gdb import ( "fmt" + "github.com/gogf/gf/text/gstr" ) @@ -72,13 +73,13 @@ func (m *Model) doJoin(operator string, table ...string) *Model { joinStr = "(" + joinStr + ")" } } else { - joinStr = m.db.QuotePrefixTableName(table[0]) + joinStr = m.db.GetCore().QuotePrefixTableName(table[0]) } } if len(table) > 2 { model.tables += fmt.Sprintf( " %s JOIN %s AS %s ON (%s)", - operator, joinStr, m.db.QuoteWord(table[1]), table[2], + operator, joinStr, m.db.GetCore().QuoteWord(table[1]), table[2], ) } else if len(table) == 2 { model.tables += fmt.Sprintf( diff --git a/database/gdb/gdb_model_select.go b/database/gdb/gdb_model_select.go index bafd477bc..3536174dc 100644 --- a/database/gdb/gdb_model_select.go +++ b/database/gdb/gdb_model_select.go @@ -8,13 +8,14 @@ package gdb import ( "fmt" + "reflect" + "github.com/gogf/gf/container/gset" "github.com/gogf/gf/container/gvar" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/internal/json" "github.com/gogf/gf/text/gstr" "github.com/gogf/gf/util/gconv" - "reflect" ) // Select is alias of Model.All. @@ -66,7 +67,7 @@ func (m *Model) getFieldsFiltered() string { if m.fieldsEx == "" { // No filtering. if !gstr.Contains(m.fields, ".") && !gstr.Contains(m.fields, " ") { - return m.db.QuoteString(m.fields) + return m.db.GetCore().QuoteString(m.fields) } return m.fields } @@ -105,7 +106,7 @@ func (m *Model) getFieldsFiltered() string { if len(newFields) > 0 { newFields += "," } - newFields += m.db.QuoteWord(k) + newFields += m.db.GetCore().QuoteWord(k) } return newFields } @@ -367,7 +368,7 @@ func (m *Model) Min(column string) (float64, error) { if len(column) == 0 { return 0, nil } - value, err := m.Fields(fmt.Sprintf(`MIN(%s)`, m.db.QuoteWord(column))).Value() + value, err := m.Fields(fmt.Sprintf(`MIN(%s)`, m.db.GetCore().QuoteWord(column))).Value() if err != nil { return 0, err } @@ -379,7 +380,7 @@ func (m *Model) Max(column string) (float64, error) { if len(column) == 0 { return 0, nil } - value, err := m.Fields(fmt.Sprintf(`MAX(%s)`, m.db.QuoteWord(column))).Value() + value, err := m.Fields(fmt.Sprintf(`MAX(%s)`, m.db.GetCore().QuoteWord(column))).Value() if err != nil { return 0, err } @@ -391,7 +392,7 @@ func (m *Model) Avg(column string) (float64, error) { if len(column) == 0 { return 0, nil } - value, err := m.Fields(fmt.Sprintf(`AVG(%s)`, m.db.QuoteWord(column))).Value() + value, err := m.Fields(fmt.Sprintf(`AVG(%s)`, m.db.GetCore().QuoteWord(column))).Value() if err != nil { return 0, err } @@ -403,7 +404,7 @@ func (m *Model) Sum(column string) (float64, error) { if len(column) == 0 { return 0, nil } - value, err := m.Fields(fmt.Sprintf(`SUM(%s)`, m.db.QuoteWord(column))).Value() + value, err := m.Fields(fmt.Sprintf(`SUM(%s)`, m.db.GetCore().QuoteWord(column))).Value() if err != nil { return 0, err } @@ -474,7 +475,7 @@ func (m *Model) FindScan(pointer interface{}, where ...interface{}) error { // doGetAllBySql does the select statement on the database. func (m *Model) doGetAllBySql(sql string, args ...interface{}) (result Result, err error) { cacheKey := "" - cacheObj := m.db.GetCache().Ctx(m.db.GetCtx()) + cacheObj := m.db.GetCache().Ctx(m.GetCtx()) // Retrieve from cache. if m.cacheEnabled && m.tx == nil { cacheKey = m.cacheName @@ -496,7 +497,7 @@ func (m *Model) doGetAllBySql(sql string, args ...interface{}) (result Result, e } } } - result, err = m.db.DoGetAll(m.getLink(false), sql, m.mergeArguments(args)...) + result, err = m.db.GetCore().DoGetAll(m.GetCtx(), m.getLink(false), sql, m.mergeArguments(args)...) // Cache the result. if cacheKey != "" && err == nil { if m.cacheDuration < 0 { diff --git a/database/gdb/gdb_model_time.go b/database/gdb/gdb_model_time.go index e167fbb35..e4047e4a1 100644 --- a/database/gdb/gdb_model_time.go +++ b/database/gdb/gdb_model_time.go @@ -140,7 +140,7 @@ func (m *Model) getConditionForSoftDeleting() string { } // Only one table. if fieldName := m.getSoftFieldNameDeleted(); fieldName != "" { - return fmt.Sprintf(`%s IS NULL`, m.db.QuoteWord(fieldName)) + return fmt.Sprintf(`%s IS NULL`, m.db.GetCore().QuoteWord(fieldName)) } return "" } @@ -163,12 +163,12 @@ func (m *Model) getConditionOfTableStringForSoftDeleting(s string) string { return "" } if len(array1) >= 3 { - return fmt.Sprintf(`%s.%s IS NULL`, m.db.QuoteWord(array1[2]), m.db.QuoteWord(field)) + return fmt.Sprintf(`%s.%s IS NULL`, m.db.GetCore().QuoteWord(array1[2]), m.db.GetCore().QuoteWord(field)) } if len(array1) >= 2 { - return fmt.Sprintf(`%s.%s IS NULL`, m.db.QuoteWord(array1[1]), m.db.QuoteWord(field)) + return fmt.Sprintf(`%s.%s IS NULL`, m.db.GetCore().QuoteWord(array1[1]), m.db.GetCore().QuoteWord(field)) } - return fmt.Sprintf(`%s.%s IS NULL`, m.db.QuoteWord(table), m.db.QuoteWord(field)) + return fmt.Sprintf(`%s.%s IS NULL`, m.db.GetCore().QuoteWord(table), m.db.GetCore().QuoteWord(field)) } // getPrimaryTableName parses and returns the primary table name. diff --git a/database/gdb/gdb_model_update.go b/database/gdb/gdb_model_update.go index ad65e723f..8252789e3 100644 --- a/database/gdb/gdb_model_update.go +++ b/database/gdb/gdb_model_update.go @@ -9,12 +9,13 @@ package gdb import ( "database/sql" "fmt" + "reflect" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/os/gtime" "github.com/gogf/gf/text/gstr" "github.com/gogf/gf/util/gconv" "github.com/gogf/gf/util/gutil" - "reflect" ) // Update does "UPDATE ... " statement for the model. @@ -81,7 +82,8 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro if !gstr.ContainsI(conditionStr, " WHERE ") { return nil, gerror.New("there should be WHERE condition statement for UPDATE operation") } - return m.db.DoUpdate( + return m.db.GetCore().DoUpdate( + m.GetCtx(), m.getLink(true), m.tables, newData, diff --git a/database/gdb/gdb_model_utility.go b/database/gdb/gdb_model_utility.go index 944946bb6..4dcd80a66 100644 --- a/database/gdb/gdb_model_utility.go +++ b/database/gdb/gdb_model_utility.go @@ -8,6 +8,8 @@ package gdb import ( "fmt" + "time" + "github.com/gogf/gf/container/gset" "github.com/gogf/gf/internal/empty" "github.com/gogf/gf/os/gtime" @@ -15,7 +17,6 @@ import ( "github.com/gogf/gf/text/gstr" "github.com/gogf/gf/util/gconv" "github.com/gogf/gf/util/gutil" - "time" ) // TableFields retrieves and returns the fields information of specified table of current @@ -29,12 +30,12 @@ func (m *Model) TableFields(table string, schema ...string) (fields map[string]* if m.tx != nil { link = m.tx.tx } else { - link, err = m.db.GetSlave(schema...) + link, err = m.db.GetCore().GetSlave(schema...) if err != nil { return } } - return m.db.TableFields(link, table, schema...) + return m.db.TableFields(m.GetCtx(), link, table, schema...) } // getModel creates and returns a cloned model of current model if `safe` is true, or else it returns @@ -111,7 +112,7 @@ func (m *Model) filterDataForInsertOrUpdate(data interface{}) (interface{}, erro // Note that, it does not filter list item, which is also type of map, for "omit empty" feature. func (m *Model) doMappingAndFilterForInsertOrUpdateDataMap(data Map, allowOmitEmpty bool) (Map, error) { var err error - data, err = m.db.mappingAndFilterData(m.schema, m.tables, data, m.filter) + data, err = m.db.GetCore().mappingAndFilterData(m.schema, m.tables, data, m.filter) if err != nil { return nil, err } @@ -187,13 +188,13 @@ func (m *Model) getLink(master bool) Link { } switch linkType { case linkTypeMaster: - link, err := m.db.GetMaster(m.schema) + link, err := m.db.GetCore().GetMaster(m.schema) if err != nil { panic(err) } return link case linkTypeSlave: - link, err := m.db.GetSlave(m.schema) + link, err := m.db.GetCore().GetSlave(m.schema) if err != nil { panic(err) } diff --git a/database/gdb/gdb_model_with.go b/database/gdb/gdb_model_with.go index bd6d37883..1fbd891d3 100644 --- a/database/gdb/gdb_model_with.go +++ b/database/gdb/gdb_model_with.go @@ -8,12 +8,13 @@ package gdb import ( "fmt" + "reflect" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/structs" "github.com/gogf/gf/internal/utils" "github.com/gogf/gf/text/gregex" "github.com/gogf/gf/text/gstr" - "reflect" ) // With creates and returns an ORM model based on meta data of given object. @@ -38,7 +39,7 @@ func (m *Model) With(objects ...interface{}) *Model { model := m.getModel() for _, object := range objects { if m.tables == "" { - m.tables = m.db.QuotePrefixTableName(getTableNameFromOrmTag(object)) + m.tables = m.db.GetCore().QuotePrefixTableName(getTableNameFromOrmTag(object)) return model } model.withArray = append(model.withArray, object) diff --git a/database/gdb/gdb_statement.go b/database/gdb/gdb_statement.go index e856ba0b0..cc74e858c 100644 --- a/database/gdb/gdb_statement.go +++ b/database/gdb/gdb_statement.go @@ -9,6 +9,7 @@ package gdb import ( "context" "database/sql" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/os/gtime" ) @@ -72,9 +73,10 @@ func (s *Stmt) doStmtCommit(stmtType string, ctx context.Context, args ...interf Group: s.core.db.GetGroup(), } ) + // Tracing and logging. s.core.addSqlToTracing(ctx, sqlObj) if s.core.db.GetDebug() { - s.core.writeSqlToLogger(sqlObj) + s.core.writeSqlToLogger(ctx, sqlObj) } return result, err } diff --git a/database/gdb/gdb_transaction.go b/database/gdb/gdb_transaction.go index 187e9e9ea..abc46a7da 100644 --- a/database/gdb/gdb_transaction.go +++ b/database/gdb/gdb_transaction.go @@ -12,50 +12,184 @@ import ( "fmt" "reflect" + "github.com/gogf/gf/container/gtype" "github.com/gogf/gf/os/gtime" "github.com/gogf/gf/util/gconv" + "github.com/gogf/gf/util/guid" "github.com/gogf/gf/text/gregex" ) // TX is the struct for transaction management. type TX struct { - db DB // db is the current gdb database manager. - tx *sql.Tx // tx is the raw and underlying transaction manager. - master *sql.DB // master is the raw and underlying database manager. - transactionCount int // transactionCount marks the times that Begins. + db DB // db is the current gdb database manager. + tx *sql.Tx // tx is the raw and underlying transaction manager. + ctx context.Context // ctx is the context for this transaction only. + master *sql.DB // master is the raw and underlying database manager. + transactionId string // transactionId is an unique id generated by this object for this transaction. + transactionCount int // transactionCount marks the times that Begins. } const ( - transactionPointerPrefix = "transaction" - contextTransactionKey = "TransactionObject" + transactionPointerPrefix = "transaction" + contextTransactionKeyPrefix = "TransactionObjectForGroup_" + transactionIdForLoggerCtx = "TransactionId" ) +var ( + transactionIdGenerator = gtype.NewUint64() +) + +// Begin starts and returns the transaction object. +// You should call Commit or Rollback functions of the transaction object +// if you no longer use the transaction. Commit or Rollback functions will also +// close the transaction automatically. +func (c *Core) Begin() (tx *TX, err error) { + return c.doBeginCtx(c.GetCtx()) +} + +func (c *Core) doBeginCtx(ctx context.Context) (*TX, error) { + if master, err := c.db.Master(); err != nil { + return nil, err + } else { + var ( + tx *TX + sqlStr = "BEGIN" + mTime1 = gtime.TimestampMilli() + rawTx, err = master.Begin() + mTime2 = gtime.TimestampMilli() + sqlObj = &Sql{ + Sql: sqlStr, + Type: "DB.Begin", + Args: nil, + Format: sqlStr, + Error: err, + Start: mTime1, + End: mTime2, + Group: c.db.GetGroup(), + } + ) + if err == nil { + tx = &TX{ + db: c.db, + tx: rawTx, + ctx: context.WithValue(ctx, transactionIdForLoggerCtx, transactionIdGenerator.Add(1)), + master: master, + transactionId: guid.S(), + } + ctx = tx.ctx + } + // Tracing and logging. + c.addSqlToTracing(ctx, sqlObj) + if c.db.GetDebug() { + c.writeSqlToLogger(ctx, sqlObj) + } + return tx, err + } +} + +// Transaction wraps the transaction logic using function `f`. +// It rollbacks the transaction and returns the error from function `f` if +// it returns non-nil error. It commits the transaction and returns nil if +// function `f` returns nil. +// +// Note that, you should not Commit or Rollback the transaction in function `f` +// as it is automatically handled by this function. +func (c *Core) Transaction(ctx context.Context, f func(ctx context.Context, tx *TX) error) (err error) { + var ( + tx *TX + ) + if ctx == nil { + ctx = c.GetCtx() + } + // Check transaction object from context. + tx = TXFromCtx(ctx, c.db.GetGroup()) + if tx != nil { + return tx.Transaction(ctx, f) + } + tx, err = c.doBeginCtx(ctx) + if err != nil { + return err + } + // Inject transaction object into context. + tx.ctx = WithTX(tx.ctx, tx) + defer func() { + if err == nil { + if e := recover(); e != nil { + err = fmt.Errorf("%v", e) + } + } + if err != nil { + if e := tx.Rollback(); e != nil { + err = e + } + } else { + if e := tx.Commit(); e != nil { + err = e + } + } + }() + err = f(tx.ctx, tx) + return +} + // WithTX injects given transaction object into context and returns a new context. func WithTX(ctx context.Context, tx *TX) context.Context { - return context.WithValue(ctx, contextTransactionKey, tx) + if tx == nil { + return ctx + } + // Check repeat injection from given. + group := tx.db.GetGroup() + if tx := TXFromCtx(ctx, group); tx != nil && tx.db.GetGroup() == group { + return ctx + } + dbCtx := tx.db.GetCtx() + if tx := TXFromCtx(dbCtx, group); tx != nil && tx.db.GetGroup() == group { + return dbCtx + } + // Inject transaction object and id into context. + ctx = context.WithValue(ctx, transactionKeyForContext(group), tx) + return ctx } // TXFromCtx retrieves and returns transaction object from context. // It is usually used in nested transaction feature, and it returns nil if it is not set previously. -func TXFromCtx(ctx context.Context) *TX { +func TXFromCtx(ctx context.Context, group string) *TX { if ctx == nil { return nil } - v := ctx.Value(contextTransactionKey) + v := ctx.Value(transactionKeyForContext(group)) if v != nil { - return v.(*TX) + tx := v.(*TX) + tx.ctx = ctx + return tx } return nil } +// transactionKeyForContext forms and returns a string for storing transaction object of certain database group into context. +func transactionKeyForContext(group string) string { + return contextTransactionKeyPrefix + group +} + +// transactionKeyForNestedPoint forms and returns the transaction key at current save point. +func (tx *TX) transactionKeyForNestedPoint() string { + return tx.db.GetCore().QuoteWord(transactionPointerPrefix + gconv.String(tx.transactionCount)) +} + +// Ctx sets the context for current transaction. +func (tx *TX) Ctx(ctx context.Context) *TX { + tx.ctx = ctx + return tx +} + // Commit commits current transaction. // Note that it releases previous saved transaction point if it's in a nested transaction procedure, // or else it commits the hole transaction. func (tx *TX) Commit() error { if tx.transactionCount > 0 { tx.transactionCount-- - _, err := tx.Exec("RELEASE SAVEPOINT " + tx.transactionKey()) + _, err := tx.Exec("RELEASE SAVEPOINT " + tx.transactionKeyForNestedPoint()) return err } var ( @@ -74,9 +208,9 @@ func (tx *TX) Commit() error { Group: tx.db.GetGroup(), } ) - tx.db.addSqlToTracing(tx.db.GetCtx(), sqlObj) + tx.db.GetCore().addSqlToTracing(tx.ctx, sqlObj) if tx.db.GetDebug() { - tx.db.writeSqlToLogger(sqlObj) + tx.db.GetCore().writeSqlToLogger(tx.ctx, sqlObj) } return err } @@ -87,7 +221,7 @@ func (tx *TX) Commit() error { func (tx *TX) Rollback() error { if tx.transactionCount > 0 { tx.transactionCount-- - _, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.transactionKey()) + _, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.transactionKeyForNestedPoint()) return err } var ( @@ -106,16 +240,16 @@ func (tx *TX) Rollback() error { Group: tx.db.GetGroup(), } ) - tx.db.addSqlToTracing(tx.db.GetCtx(), sqlObj) + tx.db.GetCore().addSqlToTracing(tx.ctx, sqlObj) if tx.db.GetDebug() { - tx.db.writeSqlToLogger(sqlObj) + tx.db.GetCore().writeSqlToLogger(tx.ctx, sqlObj) } return err } // Begin starts a nested transaction procedure. func (tx *TX) Begin() error { - _, err := tx.Exec("SAVEPOINT " + tx.transactionKey()) + _, err := tx.Exec("SAVEPOINT " + tx.transactionKeyForNestedPoint()) if err != nil { return err } @@ -126,22 +260,17 @@ func (tx *TX) Begin() error { // SavePoint performs `SAVEPOINT xxx` SQL statement that saves transaction at current point. // The parameter `point` specifies the point name that will be saved to server. func (tx *TX) SavePoint(point string) error { - _, err := tx.Exec("SAVEPOINT " + tx.db.QuoteWord(point)) + _, err := tx.Exec("SAVEPOINT " + tx.db.GetCore().QuoteWord(point)) return err } // RollbackTo performs `ROLLBACK TO SAVEPOINT xxx` SQL statement that rollbacks to specified saved transaction. // The parameter `point` specifies the point name that was saved previously. func (tx *TX) RollbackTo(point string) error { - _, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.db.QuoteWord(point)) + _, err := tx.Exec("ROLLBACK TO SAVEPOINT " + tx.db.GetCore().QuoteWord(point)) return err } -// transactionKey forms and returns the transaction key at current save point. -func (tx *TX) transactionKey() string { - return tx.db.QuoteWord(transactionPointerPrefix + gconv.String(tx.transactionCount)) -} - // Transaction wraps the transaction logic using function `f`. // It rollbacks the transaction and returns the error from function `f` if // it returns non-nil error. It commits the transaction and returns nil if @@ -150,10 +279,13 @@ func (tx *TX) transactionKey() string { // Note that, you should not Commit or Rollback the transaction in function `f` // as it is automatically handled by this function. func (tx *TX) Transaction(ctx context.Context, f func(ctx context.Context, tx *TX) error) (err error) { + if ctx != nil { + tx.ctx = ctx + } // Check transaction object from context. - if TXFromCtx(ctx) == nil { + if TXFromCtx(tx.ctx, tx.db.GetGroup()) == nil { // Inject transaction object into context. - ctx = WithTX(ctx, tx) + tx.ctx = WithTX(tx.ctx, tx) } err = tx.Begin() if err != nil { @@ -175,20 +307,20 @@ func (tx *TX) Transaction(ctx context.Context, f func(ctx context.Context, tx *T } } }() - err = f(ctx, tx) + err = f(tx.ctx, tx) return } // Query does query operation on transaction. // See Core.Query. func (tx *TX) Query(sql string, args ...interface{}) (rows *sql.Rows, err error) { - return tx.db.DoQuery(tx.tx, sql, args...) + return tx.db.GetCore().DoQuery(tx.ctx, tx.tx, sql, args...) } // Exec does none query operation on transaction. // See Core.Exec. func (tx *TX) Exec(sql string, args ...interface{}) (sql.Result, error) { - return tx.db.DoExec(tx.tx, sql, args...) + return tx.db.GetCore().DoExec(tx.ctx, tx.tx, sql, args...) } // Prepare creates a prepared statement for later queries or executions. @@ -197,7 +329,7 @@ func (tx *TX) Exec(sql string, args ...interface{}) (sql.Result, error) { // The caller must call the statement's Close method // when the statement is no longer needed. func (tx *TX) Prepare(sql string) (*Stmt, error) { - return tx.db.DoPrepare(tx.tx, sql) + return tx.db.GetCore().DoPrepare(tx.ctx, tx.tx, sql) } // GetAll queries and returns data records from database. @@ -207,7 +339,7 @@ func (tx *TX) GetAll(sql string, args ...interface{}) (Result, error) { return nil, err } defer rows.Close() - return tx.db.convertRowsToResult(rows) + return tx.db.GetCore().convertRowsToResult(rows) } // GetOne queries and returns one record from database. @@ -248,8 +380,8 @@ func (tx *TX) GetStructs(objPointerSlice interface{}, sql string, args ...interf // If parameter `pointer` is type of struct pointer, it calls GetStruct internally for // the conversion. If parameter `pointer` is type of slice, it calls GetStructs internally // for conversion. -func (tx *TX) GetScan(objPointer interface{}, sql string, args ...interface{}) error { - t := reflect.TypeOf(objPointer) +func (tx *TX) GetScan(pointer interface{}, sql string, args ...interface{}) error { + t := reflect.TypeOf(pointer) k := t.Kind() if k != reflect.Ptr { return fmt.Errorf("params should be type of pointer, but got: %v", k) @@ -257,9 +389,9 @@ func (tx *TX) GetScan(objPointer interface{}, sql string, args ...interface{}) e k = t.Elem().Kind() switch k { case reflect.Array, reflect.Slice: - return tx.db.GetStructs(objPointer, sql, args...) + return tx.GetStructs(pointer, sql, args...) case reflect.Struct: - return tx.db.GetStruct(objPointer, sql, args...) + return tx.GetStruct(pointer, sql, args...) default: return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k) } @@ -302,9 +434,9 @@ func (tx *TX) GetCount(sql string, args ...interface{}) (int, error) { // The parameter `batch` specifies the batch operation count when given data is slice. func (tx *TX) Insert(table string, data interface{}, batch ...int) (sql.Result, error) { if len(batch) > 0 { - return tx.Model(table).Data(data).Batch(batch[0]).Insert() + return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Insert() } - return tx.Model(table).Data(data).Insert() + return tx.Model(table).Ctx(tx.ctx).Data(data).Insert() } // InsertIgnore does "INSERT IGNORE INTO ..." statement for the table. @@ -318,17 +450,17 @@ func (tx *TX) Insert(table string, data interface{}, batch ...int) (sql.Result, // The parameter `batch` specifies the batch operation count when given data is slice. func (tx *TX) InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) { if len(batch) > 0 { - return tx.Model(table).Data(data).Batch(batch[0]).InsertIgnore() + return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).InsertIgnore() } - return tx.Model(table).Data(data).InsertIgnore() + return tx.Model(table).Ctx(tx.ctx).Data(data).InsertIgnore() } // InsertAndGetId performs action Insert and returns the last insert id that automatically generated. func (tx *TX) InsertAndGetId(table string, data interface{}, batch ...int) (int64, error) { if len(batch) > 0 { - return tx.Model(table).Data(data).Batch(batch[0]).InsertAndGetId() + return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).InsertAndGetId() } - return tx.Model(table).Data(data).InsertAndGetId() + return tx.Model(table).Ctx(tx.ctx).Data(data).InsertAndGetId() } // Replace does "REPLACE INTO ..." statement for the table. @@ -345,9 +477,9 @@ func (tx *TX) InsertAndGetId(table string, data interface{}, batch ...int) (int6 // `batch` specifies the batch operation count. func (tx *TX) Replace(table string, data interface{}, batch ...int) (sql.Result, error) { if len(batch) > 0 { - return tx.Model(table).Data(data).Batch(batch[0]).Replace() + return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Replace() } - return tx.Model(table).Data(data).Replace() + return tx.Model(table).Ctx(tx.ctx).Data(data).Replace() } // Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table. @@ -363,45 +495,45 @@ func (tx *TX) Replace(table string, data interface{}, batch ...int) (sql.Result, // `batch` specifies the batch operation count. func (tx *TX) Save(table string, data interface{}, batch ...int) (sql.Result, error) { if len(batch) > 0 { - return tx.Model(table).Data(data).Batch(batch[0]).Save() + return tx.Model(table).Ctx(tx.ctx).Data(data).Batch(batch[0]).Save() } - return tx.Model(table).Data(data).Save() + return tx.Model(table).Ctx(tx.ctx).Data(data).Save() } // BatchInsert batch inserts data. // The parameter `list` must be type of slice of map or struct. func (tx *TX) BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) { if len(batch) > 0 { - return tx.Model(table).Data(list).Batch(batch[0]).Insert() + return tx.Model(table).Ctx(tx.ctx).Data(list).Batch(batch[0]).Insert() } - return tx.Model(table).Data(list).Insert() + return tx.Model(table).Ctx(tx.ctx).Data(list).Insert() } // BatchInsertIgnore batch inserts data with ignore option. // The parameter `list` must be type of slice of map or struct. func (tx *TX) BatchInsertIgnore(table string, list interface{}, batch ...int) (sql.Result, error) { if len(batch) > 0 { - return tx.Model(table).Data(list).Batch(batch[0]).InsertIgnore() + return tx.Model(table).Ctx(tx.ctx).Data(list).Batch(batch[0]).InsertIgnore() } - return tx.Model(table).Data(list).InsertIgnore() + return tx.Model(table).Ctx(tx.ctx).Data(list).InsertIgnore() } // BatchReplace batch replaces data. // The parameter `list` must be type of slice of map or struct. func (tx *TX) BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) { if len(batch) > 0 { - return tx.Model(table).Data(list).Batch(batch[0]).Replace() + return tx.Model(table).Ctx(tx.ctx).Data(list).Batch(batch[0]).Replace() } - return tx.Model(table).Data(list).Replace() + return tx.Model(table).Ctx(tx.ctx).Data(list).Replace() } // BatchSave batch replaces data. // The parameter `list` must be type of slice of map or struct. func (tx *TX) BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) { if len(batch) > 0 { - return tx.Model(table).Data(list).Batch(batch[0]).Save() + return tx.Model(table).Ctx(tx.ctx).Data(list).Batch(batch[0]).Save() } - return tx.Model(table).Data(list).Save() + return tx.Model(table).Ctx(tx.ctx).Data(list).Save() } // Update does "UPDATE ... " statement for the table. @@ -419,7 +551,7 @@ func (tx *TX) BatchSave(table string, list interface{}, batch ...int) (sql.Resul // "age IN(?,?)", 18, 50 // User{ Id : 1, UserName : "john"} func (tx *TX) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { - return tx.Model(table).Data(data).Where(condition, args...).Update() + return tx.Model(table).Ctx(tx.ctx).Data(data).Where(condition, args...).Update() } // Delete does "DELETE FROM ... " statement for the table. @@ -434,5 +566,5 @@ func (tx *TX) Update(table string, data interface{}, condition interface{}, args // "age IN(?,?)", 18, 50 // User{ Id : 1, UserName : "john"} func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) { - return tx.Model(table).Where(condition, args...).Delete() + return tx.Model(table).Ctx(tx.ctx).Where(condition, args...).Delete() } diff --git a/database/gdb/gdb_z_driver_test.go b/database/gdb/gdb_z_driver_test.go index 80686d203..e166856d9 100644 --- a/database/gdb/gdb_z_driver_test.go +++ b/database/gdb/gdb_z_driver_test.go @@ -7,11 +7,13 @@ package gdb_test import ( + "context" + "testing" + "github.com/gogf/gf/container/gtype" "github.com/gogf/gf/database/gdb" "github.com/gogf/gf/frame/g" "github.com/gogf/gf/test/gtest" - "testing" ) // MyDriver is a custom database driver, which is used for testing only. @@ -41,9 +43,9 @@ func (d *MyDriver) New(core *gdb.Core, node *gdb.ConfigNode) (gdb.DB, error) { // HandleSqlBeforeCommit handles the sql before posts it to database. // It here overwrites the same method of gdb.DriverMysql and makes some custom changes. -func (d *MyDriver) HandleSqlBeforeCommit(link gdb.Link, sql string, args []interface{}) (string, []interface{}) { +func (d *MyDriver) HandleSqlBeforeCommit(ctx context.Context, link gdb.Link, sql string, args []interface{}) (string, []interface{}) { latestSqlString.Set(sql) - return d.DriverMysql.HandleSqlBeforeCommit(link, sql, args) + return d.DriverMysql.HandleSqlBeforeCommit(ctx, link, sql, args) } func init() { diff --git a/database/gdb/gdb_z_mysql_model_test.go b/database/gdb/gdb_z_mysql_model_test.go index 78db30e4f..b3087a4ea 100644 --- a/database/gdb/gdb_z_mysql_model_test.go +++ b/database/gdb/gdb_z_mysql_model_test.go @@ -2910,13 +2910,13 @@ func Test_Model_HasTable(t *testing.T) { defer dropTable(table) gtest.C(t, func(t *gtest.T) { - result, err := db.HasTable(table) + result, err := db.GetCore().HasTable(table) t.Assert(result, true) t.AssertNil(err) }) gtest.C(t, func(t *gtest.T) { - result, err := db.HasTable("table12321") + result, err := db.GetCore().HasTable("table12321") t.Assert(result, false) t.AssertNil(err) }) diff --git a/database/gdb/gdb_z_mysql_transaction_test.go b/database/gdb/gdb_z_mysql_transaction_test.go index 1fa33ea23..b22b2a765 100644 --- a/database/gdb/gdb_z_mysql_transaction_test.go +++ b/database/gdb/gdb_z_mysql_transaction_test.go @@ -833,7 +833,7 @@ func Test_Transaction_Nested_Begin_Rollback_Commit(t *testing.T) { func Test_Transaction_Nested_TX_Transaction_UseTX(t *testing.T) { table := createTable() defer dropTable(table) - + db.SetDebug(true) gtest.C(t, func(t *gtest.T) { var ( err error @@ -897,7 +897,7 @@ func Test_Transaction_Nested_TX_Transaction_UseTX(t *testing.T) { func Test_Transaction_Nested_TX_Transaction_UseDB(t *testing.T) { table := createTable() defer dropTable(table) - + db.SetDebug(true) gtest.C(t, func(t *gtest.T) { var ( err error @@ -910,7 +910,7 @@ func Test_Transaction_Nested_TX_Transaction_UseDB(t *testing.T) { err = db.Transaction(ctx, func(ctx context.Context, tx *gdb.TX) error { err = db.Transaction(ctx, func(ctx context.Context, tx *gdb.TX) error { err = db.Transaction(ctx, func(ctx context.Context, tx *gdb.TX) error { - _, err = db.Model(table).Data(g.Map{ + _, err = db.Model(table).Ctx(ctx).Data(g.Map{ "id": 1, "passport": "USER_1", "password": "PASS_1", @@ -933,9 +933,10 @@ func Test_Transaction_Nested_TX_Transaction_UseDB(t *testing.T) { return err }) t.AssertNil(err) + // rollback err = db.Transaction(ctx, func(ctx context.Context, tx *gdb.TX) error { - _, err = tx.Model(table).Data(g.Map{ + _, err = tx.Model(table).Ctx(ctx).Data(g.Map{ "id": 2, "passport": "USER_2", "password": "PASS_2", @@ -943,6 +944,7 @@ func Test_Transaction_Nested_TX_Transaction_UseDB(t *testing.T) { "create_time": gtime.Now().String(), }).Insert() t.AssertNil(err) + // panic makes this transaction rollback. panic("error") return err }) diff --git a/os/glog/glog_logger_config.go b/os/glog/glog_logger_config.go index dd619d20e..f28e6862a 100644 --- a/os/glog/glog_logger_config.go +++ b/os/glog/glog_logger_config.go @@ -9,14 +9,15 @@ package glog import ( "errors" "fmt" + "io" + "strings" + "time" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/os/gfile" "github.com/gogf/gf/util/gconv" "github.com/gogf/gf/util/gutil" - "io" - "strings" - "time" ) // Config is the configuration object for logger. @@ -165,6 +166,24 @@ func (l *Logger) SetCtxKeys(keys ...interface{}) { l.config.CtxKeys = keys } +// AppendCtxKeys appends extra keys to logger. +// It ignores the key if it is already appended to the logger previously. +func (l *Logger) AppendCtxKeys(keys ...interface{}) { + var isExist bool + for _, key := range keys { + isExist = false + for _, ctxKey := range l.config.CtxKeys { + if ctxKey == key { + isExist = true + break + } + } + if !isExist { + l.config.CtxKeys = append(l.config.CtxKeys, key) + } + } +} + // GetCtxKeys retrieves and returns the context keys for logging. func (l *Logger) GetCtxKeys() []interface{} { return l.config.CtxKeys