diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 735656ac7..e89fe953a 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -522,8 +522,11 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode { // The parameter `master` specifies whether retrieves master node connection if // master-slave nodes are configured. func (c *Core) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error) { + var ( + ctx = c.db.GetCtx() + node *ConfigNode + ) // Load balance. - var node *ConfigNode if c.group != "" { node, err = getConfigNodeByGroup(c.group, master) if err != nil { @@ -549,20 +552,12 @@ func (c *Core) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error } // Cache the underlying connection pool object by node. v := c.links.GetOrSetFuncLock(node.String(), func() interface{} { - intlog.Printf( - c.db.GetCtx(), - `open new connection, master:%#v, config:%#v, node:%#v`, - master, c.config, node, - ) + intlog.Printf(ctx, `open new connection, master:%#v, config:%#v, node:%#v`, master, c.config, node) defer func() { if err != nil { - intlog.Printf(c.db.GetCtx(), `open new connection failed: %v, %#v`, err, node) + intlog.Printf(ctx, `open new connection failed: %v, %#v`, err, node) } else { - intlog.Printf( - c.db.GetCtx(), - `open new connection success, master:%#v, config:%#v, node:%#v`, - master, c.config, node, - ) + intlog.Printf(ctx, `open new connection success, master:%#v, config:%#v, node:%#v`, master, c.config, node) } }() diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index e4ed2312e..728d942d6 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -67,9 +67,9 @@ func (c *Core) GetCtx() context.Context { } // GetCtxTimeout returns the context and cancel function for specified timeout type. -func (c *Core) GetCtxTimeout(timeoutType int, ctx context.Context) (context.Context, context.CancelFunc) { +func (c *Core) GetCtxTimeout(ctx context.Context, timeoutType int) (context.Context, context.CancelFunc) { if ctx == nil { - ctx = c.GetCtx() + ctx = c.db.GetCtx() } else { ctx = context.WithValue(ctx, "WrappedByGetCtxTimeout", nil) } @@ -255,17 +255,18 @@ func (c *Core) GetCount(ctx context.Context, sql string, args ...interface{}) (i // Union does "(SELECT xxx FROM xxx) UNION (SELECT xxx FROM xxx) ..." statement. func (c *Core) Union(unions ...*Model) *Model { - return c.doUnion(unionTypeNormal, unions...) + var ctx = c.db.GetCtx() + return c.doUnion(ctx, unionTypeNormal, unions...) } // UnionAll does "(SELECT xxx FROM xxx) UNION ALL (SELECT xxx FROM xxx) ..." statement. func (c *Core) UnionAll(unions ...*Model) *Model { - return c.doUnion(unionTypeAll, unions...) + var ctx = c.db.GetCtx() + return c.doUnion(ctx, unionTypeAll, unions...) } -func (c *Core) doUnion(unionType int, unions ...*Model) *Model { +func (c *Core) doUnion(ctx context.Context, unionType int, unions ...*Model) *Model { var ( - ctx = c.db.GetCtx() unionTypeStr string composedSqlStr string composedArgs = make([]interface{}, 0) @@ -289,10 +290,11 @@ func (c *Core) doUnion(unionType int, unions ...*Model) *Model { // PingMaster pings the master node to check authentication or keeps the connection alive. func (c *Core) PingMaster() error { + var ctx = c.db.GetCtx() if master, err := c.db.Master(); err != nil { return err } else { - if err = master.PingContext(c.GetCtx()); err != nil { + if err = master.PingContext(ctx); err != nil { err = gerror.WrapCode(gcode.CodeDbOperationError, err, `master.Ping failed`) } return err @@ -301,10 +303,11 @@ func (c *Core) PingMaster() error { // PingSlave pings the slave node to check authentication or keeps the connection alive. func (c *Core) PingSlave() error { + var ctx = c.db.GetCtx() if slave, err := c.db.Slave(); err != nil { return err } else { - if err = slave.PingContext(c.GetCtx()); err != nil { + if err = slave.PingContext(ctx); err != nil { err = gerror.WrapCode(gcode.CodeDbOperationError, err, `slave.Ping failed`) } return err @@ -663,21 +666,22 @@ func (c *Core) writeSqlToLogger(ctx context.Context, sql *Sql) { // HasTable determine whether the table name exists in the database. func (c *Core) HasTable(name string) (bool, error) { - result, err := c.GetCache().GetOrSetFuncLock( - c.GetCtx(), - fmt.Sprintf(`HasTable: %s`, name), - func(ctx context.Context) (interface{}, error) { - tableList, err := c.db.Tables(ctx) - if err != nil { - return false, err + var ( + ctx = c.db.GetCtx() + cacheKey = fmt.Sprintf(`HasTable: %s`, name) + ) + result, err := c.GetCache().GetOrSetFuncLock(ctx, cacheKey, func(ctx context.Context) (interface{}, error) { + tableList, err := c.db.Tables(ctx) + if err != nil { + return false, err + } + for _, table := range tableList { + if table == name { + return true, nil } - for _, table := range tableList { - if table == name { - return true, nil - } - } - return false, nil - }, 0, + } + return false, nil + }, 0, ) if err != nil { return false, err diff --git a/database/gdb/gdb_core_transaction.go b/database/gdb/gdb_core_transaction.go index dc57d38f6..67750f92c 100644 --- a/database/gdb/gdb_core_transaction.go +++ b/database/gdb/gdb_core_transaction.go @@ -71,7 +71,7 @@ func (c *Core) doBeginCtx(ctx context.Context) (*TX, error) { func (c *Core) Transaction(ctx context.Context, f func(ctx context.Context, tx *TX) error) (err error) { var tx *TX if ctx == nil { - ctx = c.GetCtx() + ctx = c.db.GetCtx() } // Check transaction object from context. tx = TXFromCtx(ctx, c.db.GetGroup()) diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index c9b051823..a8bc2ef60 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -234,7 +234,7 @@ func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutp out.RawResult = sqlStmt case SqlTypeStmtExecContext: - ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeExec, ctx) + ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeExec) defer cancelFuncForTimeout() if c.db.GetDryRun() { sqlResult = new(SqlResult) @@ -244,13 +244,13 @@ func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutp out.RawResult = sqlResult case SqlTypeStmtQueryContext: - ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeQuery, ctx) + ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery) defer cancelFuncForTimeout() stmtSqlRows, err = in.Stmt.QueryContext(ctx, in.Args...) out.RawResult = stmtSqlRows case SqlTypeStmtQueryRowContext: - ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeQuery, ctx) + ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctx, ctxTimeoutTypeQuery) defer cancelFuncForTimeout() stmtSqlRow = in.Stmt.QueryRowContext(ctx, in.Args...) out.RawResult = stmtSqlRow diff --git a/database/gdb/gdb_core_utility.go b/database/gdb/gdb_core_utility.go index bbfeb74f1..fab680639 100644 --- a/database/gdb/gdb_core_utility.go +++ b/database/gdb/gdb_core_utility.go @@ -150,11 +150,12 @@ func (c *Core) Tables(schema ...string) (tables []string, err error) { // // It does nothing in default. func (c *Core) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { + var ctx = c.db.GetCtx() // It does nothing if given table is empty, especially in sub-query. if table == "" { return map[string]*TableField{}, nil } - return c.db.TableFields(c.GetCtx(), table, schema...) + return c.db.TableFields(ctx, table, schema...) } // HasField determine whether the field exists in the table. diff --git a/database/gdb/gdb_driver_mysql.go b/database/gdb/gdb_driver_mysql.go index a7e161185..40b8aa933 100644 --- a/database/gdb/gdb_driver_mysql.go +++ b/database/gdb/gdb_driver_mysql.go @@ -38,6 +38,7 @@ func (d *DriverMysql) New(core *Core, node *ConfigNode) (DB, error) { // Note that it converts time.Time argument to local timezone in default. func (d *DriverMysql) Open(config *ConfigNode) (db *sql.DB, err error) { var ( + ctx = d.GetCtx() source string underlyingDriverName = "mysql" ) @@ -56,7 +57,7 @@ func (d *DriverMysql) Open(config *ConfigNode) (db *sql.DB, err error) { source = fmt.Sprintf("%s&loc=%s", source, url.QueryEscape(config.Timezone)) } } - intlog.Printf(d.GetCtx(), "Open: %s", source) + intlog.Printf(ctx, "Open: %s", source) if db, err = sql.Open(underlyingDriverName, source); err != nil { err = gerror.WrapCodef( gcode.CodeDbOperationError, err, diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index 8878d0682..ed54bd65f 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -8,6 +8,7 @@ package gdb import ( "bytes" + "context" "fmt" "reflect" "regexp" @@ -364,7 +365,7 @@ func isKeyValueCanBeOmitEmpty(omitEmpty bool, whereType string, key, value inter } // formatWhereHolder formats where statement and its arguments for `Where` and `Having` statements. -func formatWhereHolder(db DB, in formatWhereHolderInput) (newWhere string, newArgs []interface{}) { +func formatWhereHolder(ctx context.Context, db DB, in formatWhereHolderInput) (newWhere string, newArgs []interface{}) { var ( buffer = bytes.NewBuffer(nil) reflectInfo = reflection.OriginValueAndKind(in.Where) @@ -393,7 +394,7 @@ func formatWhereHolder(db DB, in formatWhereHolderInput) (newWhere string, newAr } case reflect.Struct: - // If the `where` parameter is DO struct, it then adds `OmitNil` option for this condition, + // If the `where` parameter is `DO` struct, it then adds `OmitNil` option for this condition, // which will filter all nil parameters in `where`. if isDoStruct(in.Where) { in.OmitNil = true @@ -523,7 +524,9 @@ func formatWhereHolder(db DB, in formatWhereHolderInput) (newWhere string, newAr whereStr, _ = gregex.ReplaceStringFunc(`(\?)`, whereStr, func(s string) string { index++ if i+len(newArgs) == index { - sqlWithHolder, holderArgs := model.getFormattedSqlAndArgs(model.GetCtx(), queryTypeNormal, false) + sqlWithHolder, holderArgs := model.getFormattedSqlAndArgs( + ctx, queryTypeNormal, false, + ) newArgs = append(newArgs, holderArgs...) // Automatically adding the brackets. return "(" + sqlWithHolder + ")" diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index 51e4bc2ae..6d72f9b21 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -93,6 +93,7 @@ const ( // db.Model("? AS a, ? AS b", subQuery1, subQuery2) func (c *Core) Model(tableNameQueryOrStruct ...interface{}) *Model { var ( + ctx = c.db.GetCtx() tableStr string tableName string extraArgs []interface{} @@ -105,7 +106,7 @@ func (c *Core) Model(tableNameQueryOrStruct ...interface{}) *Model { Where: conditionStr, Args: tableNameQueryOrStruct[1:], } - tableStr, extraArgs = formatWhereHolder(c.db, formatWhereHolderInput{ + tableStr, extraArgs = formatWhereHolder(ctx, c.db, formatWhereHolderInput{ ModelWhereHolder: whereHolder, OmitNil: false, OmitEmpty: false, diff --git a/database/gdb/gdb_model_cache.go b/database/gdb/gdb_model_cache.go index e9f66aa52..7ed513d45 100644 --- a/database/gdb/gdb_model_cache.go +++ b/database/gdb/gdb_model_cache.go @@ -7,6 +7,7 @@ package gdb import ( + "context" "time" "github.com/gogf/gf/v2/internal/intlog" @@ -44,11 +45,9 @@ func (m *Model) Cache(option CacheOption) *Model { // checkAndRemoveCache checks and removes the cache in insert/update/delete statement if // cache feature is enabled. -func (m *Model) checkAndRemoveCache() { +func (m *Model) checkAndRemoveCache(ctx context.Context) { if m.cacheEnabled && m.cacheOption.Duration < 0 && len(m.cacheOption.Name) > 0 { - ctx := m.GetCtx() - _, err := m.db.GetCache().Remove(ctx, m.cacheOption.Name) - if err != nil { + if _, err := m.db.GetCache().Remove(ctx, m.cacheOption.Name); err != nil { intlog.Errorf(ctx, `%+v`, err) } } diff --git a/database/gdb/gdb_model_delete.go b/database/gdb/gdb_model_delete.go index 68bcfdb6f..b6406e5df 100644 --- a/database/gdb/gdb_model_delete.go +++ b/database/gdb/gdb_model_delete.go @@ -20,17 +20,18 @@ import ( // The optional parameter `where` is the same as the parameter of Model.Where function, // see Model.Where. func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) { + var ctx = m.GetCtx() if len(where) > 0 { return m.Where(where[0], where[1:]...).Delete() } defer func() { if err == nil { - m.checkAndRemoveCache() + m.checkAndRemoveCache(ctx) } }() var ( fieldNameDelete = m.getSoftFieldNameDeleted() - conditionWhere, conditionExtra, conditionArgs = m.formatCondition(false, false) + conditionWhere, conditionExtra, conditionArgs = m.formatCondition(ctx, false, false) ) // Soft deleting. if !m.unscoped && fieldNameDelete != "" { @@ -47,7 +48,7 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) { Condition: conditionWhere + conditionExtra, Args: append([]interface{}{gtime.Now().String()}, conditionArgs...), } - return in.Next(m.GetCtx()) + return in.Next(ctx) } conditionStr := conditionWhere + conditionExtra if !gstr.ContainsI(conditionStr, " WHERE ") { @@ -69,5 +70,5 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) { Condition: conditionStr, Args: conditionArgs, } - return in.Next(m.GetCtx()) + return in.Next(ctx) } diff --git a/database/gdb/gdb_model_insert.go b/database/gdb/gdb_model_insert.go index ad2bfed60..12300d5cf 100644 --- a/database/gdb/gdb_model_insert.go +++ b/database/gdb/gdb_model_insert.go @@ -7,6 +7,7 @@ package gdb import ( + "context" "database/sql" "reflect" @@ -38,7 +39,10 @@ func (m *Model) Batch(batch int) *Model { // Data(g.Map{"uid": 10000, "name":"john"}) // Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}). func (m *Model) Data(data ...interface{}) *Model { - model := m.getModel() + var ( + ctx = m.GetCtx() + model = m.getModel() + ) if len(data) > 1 { if s := gconv.String(data[0]); gstr.Contains(s, "?") { model.data = s @@ -83,7 +87,7 @@ func (m *Model) Data(data ...interface{}) *Model { } list := make(List, reflectInfo.OriginValue.Len()) for i := 0; i < reflectInfo.OriginValue.Len(); i++ { - list[i] = m.db.ConvertDataForRecord(m.GetCtx(), reflectInfo.OriginValue.Index(i).Interface()) + list[i] = m.db.ConvertDataForRecord(ctx, reflectInfo.OriginValue.Index(i).Interface()) } model.data = list @@ -100,15 +104,15 @@ func (m *Model) Data(data ...interface{}) *Model { list = make(List, len(array)) ) for i := 0; i < len(array); i++ { - list[i] = m.db.ConvertDataForRecord(m.GetCtx(), array[i]) + list[i] = m.db.ConvertDataForRecord(ctx, array[i]) } model.data = list } else { - model.data = m.db.ConvertDataForRecord(m.GetCtx(), data[0]) + model.data = m.db.ConvertDataForRecord(ctx, data[0]) } case reflect.Map: - model.data = m.db.ConvertDataForRecord(m.GetCtx(), data[0]) + model.data = m.db.ConvertDataForRecord(ctx, data[0]) default: model.data = data[0] @@ -164,18 +168,20 @@ func (m *Model) OnDuplicateEx(onDuplicateEx ...interface{}) *Model { // The optional parameter `data` is the same as the parameter of Model.Data function, // see Model.Data. func (m *Model) Insert(data ...interface{}) (result sql.Result, err error) { + var ctx = m.GetCtx() if len(data) > 0 { return m.Data(data...).Insert() } - return m.doInsertWithOption(InsertOptionDefault) + return m.doInsertWithOption(ctx, InsertOptionDefault) } // InsertAndGetId performs action Insert and returns the last insert id that automatically generated. func (m *Model) InsertAndGetId(data ...interface{}) (lastInsertId int64, err error) { + var ctx = m.GetCtx() if len(data) > 0 { return m.Data(data...).InsertAndGetId() } - result, err := m.doInsertWithOption(InsertOptionDefault) + result, err := m.doInsertWithOption(ctx, InsertOptionDefault) if err != nil { return 0, err } @@ -186,20 +192,22 @@ func (m *Model) InsertAndGetId(data ...interface{}) (lastInsertId int64, err err // The optional parameter `data` is the same as the parameter of Model.Data function, // see Model.Data. func (m *Model) InsertIgnore(data ...interface{}) (result sql.Result, err error) { + var ctx = m.GetCtx() if len(data) > 0 { return m.Data(data...).InsertIgnore() } - return m.doInsertWithOption(InsertOptionIgnore) + return m.doInsertWithOption(ctx, InsertOptionIgnore) } // Replace does "REPLACE INTO ..." statement for the model. // The optional parameter `data` is the same as the parameter of Model.Data function, // see Model.Data. func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) { + var ctx = m.GetCtx() if len(data) > 0 { return m.Data(data...).Replace() } - return m.doInsertWithOption(InsertOptionReplace) + return m.doInsertWithOption(ctx, InsertOptionReplace) } // Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the model. @@ -209,17 +217,18 @@ func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) { // It updates the record if there's primary or unique index in the saving data, // or else it inserts a new record into the table. func (m *Model) Save(data ...interface{}) (result sql.Result, err error) { + var ctx = m.GetCtx() if len(data) > 0 { return m.Data(data...).Save() } - return m.doInsertWithOption(InsertOptionSave) + return m.doInsertWithOption(ctx, InsertOptionSave) } // doInsertWithOption inserts data with option parameter. -func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err error) { +func (m *Model) doInsertWithOption(ctx context.Context, insertOption int) (result sql.Result, err error) { defer func() { if err == nil { - m.checkAndRemoveCache() + m.checkAndRemoveCache(ctx) } }() if m.data == nil { @@ -246,11 +255,11 @@ func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err err case List: list = value for i, v := range list { - list[i] = m.db.ConvertDataForRecord(m.GetCtx(), v) + list[i] = m.db.ConvertDataForRecord(ctx, v) } case Map: - list = List{m.db.ConvertDataForRecord(m.GetCtx(), value)} + list = List{m.db.ConvertDataForRecord(ctx, value)} default: reflectInfo := reflection.OriginValueAndKind(newData) @@ -259,21 +268,21 @@ func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err err case reflect.Slice, reflect.Array: list = make(List, reflectInfo.OriginValue.Len()) for i := 0; i < reflectInfo.OriginValue.Len(); i++ { - list[i] = m.db.ConvertDataForRecord(m.GetCtx(), reflectInfo.OriginValue.Index(i).Interface()) + list[i] = m.db.ConvertDataForRecord(ctx, reflectInfo.OriginValue.Index(i).Interface()) } case reflect.Map: - list = List{m.db.ConvertDataForRecord(m.GetCtx(), value)} + list = List{m.db.ConvertDataForRecord(ctx, value)} case reflect.Struct: if v, ok := value.(iInterfaces); ok { array := v.Interfaces() list = make(List, len(array)) for i := 0; i < len(array); i++ { - list[i] = m.db.ConvertDataForRecord(m.GetCtx(), array[i]) + list[i] = m.db.ConvertDataForRecord(ctx, array[i]) } } else { - list = List{m.db.ConvertDataForRecord(m.GetCtx(), value)} + list = List{m.db.ConvertDataForRecord(ctx, value)} } default: @@ -323,7 +332,7 @@ func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err err Data: list, Option: doInsertOption, } - return in.Next(m.GetCtx()) + return in.Next(ctx) } func (m *Model) formatDoInsertOption(insertOption int, columnNames []string) (option DoInsertOption, err error) { diff --git a/database/gdb/gdb_model_select.go b/database/gdb/gdb_model_select.go index 057b2f146..b2b11b57f 100644 --- a/database/gdb/gdb_model_select.go +++ b/database/gdb/gdb_model_select.go @@ -30,7 +30,8 @@ import ( // The optional parameter `where` is the same as the parameter of Model.Where function, // see Model.Where. func (m *Model) All(where ...interface{}) (Result, error) { - return m.doGetAll(m.GetCtx(), false, where...) + var ctx = m.GetCtx() + return m.doGetAll(ctx, false, where...) } // doGetAll does "SELECT FROM ..." statement for the model. @@ -44,7 +45,7 @@ func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{}) if len(where) > 0 { return m.Where(where[0], where[1:]...).All() } - sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(m.GetCtx(), queryTypeNormal, limit1) + sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(ctx, queryTypeNormal, limit1) return m.doGetAllBySql(ctx, queryTypeNormal, sqlWithHolder, holderArgs...) } @@ -131,10 +132,11 @@ func (m *Model) Chunk(size int, handler ChunkHandler) { // The optional parameter `where` is the same as the parameter of Model.Where function, // see Model.Where. func (m *Model) One(where ...interface{}) (Record, error) { + var ctx = m.GetCtx() if len(where) > 0 { return m.Where(where[0], where[1:]...).One() } - all, err := m.doGetAll(m.GetCtx(), true) + all, err := m.doGetAll(ctx, true) if err != nil { return nil, err } @@ -151,6 +153,7 @@ func (m *Model) One(where ...interface{}) (Record, error) { // and fieldsAndWhere[1:] is treated as where condition fields. // Also see Model.Fields and Model.Where functions. func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) { + var ctx = m.GetCtx() if len(fieldsAndWhere) > 0 { if len(fieldsAndWhere) > 2 { return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1], fieldsAndWhere[2:]...).Value() @@ -163,7 +166,6 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) { var ( all Result err error - ctx = m.GetCtx() ) if all, err = m.doGetAll(ctx, true); err != nil { return nil, err @@ -373,11 +375,11 @@ func (m *Model) ScanList(structSlicePointer interface{}, bindToAttrName string, // The optional parameter `where` is the same as the parameter of Model.Where function, // see Model.Where. func (m *Model) Count(where ...interface{}) (int, error) { + var ctx = m.GetCtx() if len(where) > 0 { return m.Where(where[0], where[1:]...).Count() } var ( - ctx = m.GetCtx() sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeCount, false) all, err = m.doGetAllBySql(ctx, queryTypeCount, sqlWithHolder, holderArgs...) ) @@ -566,7 +568,7 @@ func (m *Model) doGetAllBySql(ctx context.Context, queryType int, sql string, ar if cacheKey != "" && err == nil { if m.cacheOption.Duration < 0 { if _, errCache := cacheObj.Remove(ctx, cacheKey); errCache != nil { - intlog.Errorf(m.GetCtx(), `%+v`, errCache) + intlog.Errorf(ctx, `%+v`, errCache) } } else { // In case of Cache Penetration. @@ -574,7 +576,7 @@ func (m *Model) doGetAllBySql(ctx context.Context, queryType int, sql string, ar result = Result{} } if errCache := cacheObj.Set(ctx, cacheKey, result, m.cacheOption.Duration); errCache != nil { - intlog.Errorf(m.GetCtx(), `%+v`, errCache) + intlog.Errorf(ctx, `%+v`, errCache) } } } @@ -595,7 +597,7 @@ func (m *Model) getFormattedSqlAndArgs(ctx context.Context, queryType int, limit sqlWithHolder = fmt.Sprintf("SELECT %s FROM (%s) AS T", queryFields, m.rawSql) return sqlWithHolder, nil } - conditionWhere, conditionExtra, conditionArgs := m.formatCondition(false, true) + conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, false, true) sqlWithHolder = fmt.Sprintf("SELECT %s FROM %s%s", queryFields, m.tables, conditionWhere+conditionExtra) if len(m.groupBy) > 0 { sqlWithHolder = fmt.Sprintf("SELECT COUNT(1) FROM (%s) count_alias", sqlWithHolder) @@ -603,7 +605,7 @@ func (m *Model) getFormattedSqlAndArgs(ctx context.Context, queryType int, limit return sqlWithHolder, conditionArgs default: - conditionWhere, conditionExtra, conditionArgs := m.formatCondition(limit1, false) + conditionWhere, conditionExtra, conditionArgs := m.formatCondition(ctx, limit1, false) // Raw SQL Model, especially for UNION/UNION ALL featured SQL. if m.rawSql != "" { sqlWithHolder = fmt.Sprintf( @@ -627,7 +629,7 @@ func (m *Model) getFormattedSqlAndArgs(ctx context.Context, queryType int, limit // Note that this function does not change any attribute value of the `m`. // // The parameter `limit1` specifies whether limits querying only one record if m.limit is not set. -func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWhere string, conditionExtra string, conditionArgs []interface{}) { +func (m *Model) formatCondition(ctx context.Context, limit1 bool, isCountStatement bool) (conditionWhere string, conditionExtra string, conditionArgs []interface{}) { autoPrefix := "" if gstr.Contains(m.tables, " JOIN ") { autoPrefix = m.db.GetCore().QuoteWord( @@ -647,7 +649,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh switch holder.Operator { case whereHolderOperatorWhere: if conditionWhere == "" { - newWhere, newArgs := formatWhereHolder(m.db, formatWhereHolderInput{ + newWhere, newArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{ ModelWhereHolder: holder, OmitNil: m.option&optionOmitNilWhere > 0, OmitEmpty: m.option&optionOmitEmptyWhere > 0, @@ -663,7 +665,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh fallthrough case whereHolderOperatorAnd: - newWhere, newArgs := formatWhereHolder(m.db, formatWhereHolderInput{ + newWhere, newArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{ ModelWhereHolder: holder, OmitNil: m.option&optionOmitNilWhere > 0, OmitEmpty: m.option&optionOmitEmptyWhere > 0, @@ -682,7 +684,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh } case whereHolderOperatorOr: - newWhere, newArgs := formatWhereHolder(m.db, formatWhereHolderInput{ + newWhere, newArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{ ModelWhereHolder: holder, OmitNil: m.option&optionOmitNilWhere > 0, OmitEmpty: m.option&optionOmitEmptyWhere > 0, @@ -733,7 +735,7 @@ func (m *Model) formatCondition(limit1 bool, isCountStatement bool) (conditionWh Args: gconv.Interfaces(m.having[1]), Prefix: autoPrefix, } - havingStr, havingArgs := formatWhereHolder(m.db, formatWhereHolderInput{ + havingStr, havingArgs := formatWhereHolder(ctx, m.db, formatWhereHolderInput{ ModelWhereHolder: havingHolder, OmitNil: m.option&optionOmitNilWhere > 0, OmitEmpty: m.option&optionOmitEmptyWhere > 0, diff --git a/database/gdb/gdb_model_update.go b/database/gdb/gdb_model_update.go index f6a39a09f..817f6fa5f 100644 --- a/database/gdb/gdb_model_update.go +++ b/database/gdb/gdb_model_update.go @@ -25,6 +25,7 @@ import ( // and dataAndWhere[1:] is treated as where condition fields. // Also see Model.Data and Model.Where functions. func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err error) { + var ctx = m.GetCtx() if len(dataAndWhere) > 0 { if len(dataAndWhere) > 2 { return m.Data(dataAndWhere[0]).Where(dataAndWhere[1], dataAndWhere[2:]...).Update() @@ -36,7 +37,7 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro } defer func() { if err == nil { - m.checkAndRemoveCache() + m.checkAndRemoveCache(ctx) } }() if m.data == nil { @@ -46,11 +47,11 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro updateData = m.data reflectInfo = reflection.OriginTypeAndKind(updateData) fieldNameUpdate = m.getSoftFieldNameUpdated() - conditionWhere, conditionExtra, conditionArgs = m.formatCondition(false, false) + conditionWhere, conditionExtra, conditionArgs = m.formatCondition(ctx, false, false) ) switch reflectInfo.OriginKind { case reflect.Map, reflect.Struct: - dataMap := m.db.ConvertDataForRecord(m.GetCtx(), m.data) + dataMap := m.db.ConvertDataForRecord(ctx, m.data) // Automatically update the record updating time. if !m.unscoped && fieldNameUpdate != "" { dataMap[fieldNameUpdate] = gtime.Now().String() @@ -89,7 +90,7 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro Condition: conditionStr, Args: m.mergeArguments(conditionArgs), } - return in.Next(m.GetCtx()) + return in.Next(ctx) } // Increment increments a column's value by a given amount.