diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 41846facd..93bea6c13 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -213,9 +213,9 @@ type TableField struct { // Link is a common database function wrapper interface. type Link interface { - Query(sql string, args ...interface{}) (*sql.Rows, error) - Exec(sql string, args ...interface{}) (sql.Result, error) - Prepare(sql string) (*sql.Stmt, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) } // Counter is the type for update count. diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 14cab14f1..a2b6b36a2 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -78,11 +78,15 @@ func (c *Core) Query(sql string, args ...interface{}) (rows *sql.Rows, err error // DoQuery commits the sql string and its arguments to underlying driver // through given link object and returns the execution result. func (c *Core) DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) { + ctx := c.DB.GetCtx() + if ctx == nil { + ctx = context.Background() + } sql, args = formatSql(sql, args) sql, args = c.DB.HandleSqlBeforeCommit(link, sql, args) if c.DB.GetDebug() { mTime1 := gtime.TimestampMilli() - rows, err = link.Query(sql, args...) + rows, err = link.QueryContext(ctx, sql, args...) mTime2 := gtime.TimestampMilli() s := &Sql{ Sql: sql, @@ -95,7 +99,7 @@ func (c *Core) DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Ro } c.writeSqlToLogger(s) } else { - rows, err = link.Query(sql, args...) + rows, err = link.QueryContext(ctx, sql, args...) } if err == nil { return rows, nil @@ -118,12 +122,16 @@ func (c *Core) Exec(sql string, args ...interface{}) (result sql.Result, err err // DoExec commits the sql string and its arguments to underlying driver // through given link object and returns the execution result. func (c *Core) DoExec(link Link, sql string, args ...interface{}) (result sql.Result, err error) { + ctx := c.DB.GetCtx() + if ctx == nil { + ctx = context.Background() + } sql, args = formatSql(sql, args) sql, args = c.DB.HandleSqlBeforeCommit(link, sql, args) if c.DB.GetDebug() { mTime1 := gtime.TimestampMilli() if !c.DB.GetDryRun() { - result, err = link.Exec(sql, args...) + result, err = link.ExecContext(ctx, sql, args...) } else { result = new(SqlResult) } @@ -140,7 +148,7 @@ func (c *Core) DoExec(link Link, sql string, args ...interface{}) (result sql.Re c.writeSqlToLogger(s) } else { if !c.DB.GetDryRun() { - result, err = link.Exec(sql, args...) + result, err = link.ExecContext(ctx, sql, args...) } else { result = new(SqlResult) } @@ -157,8 +165,10 @@ func (c *Core) DoExec(link Link, sql string, args ...interface{}) (result sql.Re // The parameter specifies whether executing the sql on master node, // or else it executes the sql on slave node if master-slave configured. func (c *Core) Prepare(sql string, execOnMaster ...bool) (*sql.Stmt, error) { - err := (error)(nil) - link := (Link)(nil) + var ( + err error + link Link + ) if len(execOnMaster) > 0 && execOnMaster[0] { if link, err = c.DB.Master(); err != nil { return nil, err @@ -173,7 +183,11 @@ func (c *Core) Prepare(sql string, execOnMaster ...bool) (*sql.Stmt, error) { // doPrepare calls prepare function on given link object and returns the statement object. func (c *Core) DoPrepare(link Link, sql string) (*sql.Stmt, error) { - return link.Prepare(sql) + ctx := c.DB.GetCtx() + if ctx == nil { + ctx = context.Background() + } + return link.PrepareContext(ctx, sql) } // GetAll queries and returns data records from database. diff --git a/database/gdb/gdb_z_mysql_method_test.go b/database/gdb/gdb_z_mysql_method_test.go index 2b26baa40..461cd2f98 100644 --- a/database/gdb/gdb_z_mysql_method_test.go +++ b/database/gdb/gdb_z_mysql_method_test.go @@ -1437,8 +1437,7 @@ func Test_DB_UpdateCounter(t *testing.T) { Value: 1, } updateData := g.Map{ - "views": gdbCounter, - "updated_time": gtime.Now().Unix(), + "views": gdbCounter, } result, err := db.Update(tableName, updateData, "id", 1) t.Assert(err, nil)