diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index a2b6b36a2..033a9666d 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -48,9 +48,12 @@ func (c *Core) Ctx(ctx context.Context) DB { } // GetCtx returns the context for current DB. -// Note that it might be nil. +// It returns `context.Background()` is there's no context previously set. func (c *Core) GetCtx() context.Context { - return c.ctx + if c.ctx != nil { + return c.ctx + } + return context.Background() } // Master creates and returns a connection from master node if master-slave configured. @@ -78,15 +81,11 @@ func (c *Core) Query(sql string, args ...interface{}) (rows *sql.Rows, err error // DoQuery commits the sql string and its arguments to underlying driver // through given link object and returns the execution result. func (c *Core) DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) { - ctx := c.DB.GetCtx() - if ctx == nil { - ctx = context.Background() - } sql, args = formatSql(sql, args) sql, args = c.DB.HandleSqlBeforeCommit(link, sql, args) if c.DB.GetDebug() { mTime1 := gtime.TimestampMilli() - rows, err = link.QueryContext(ctx, sql, args...) + rows, err = link.QueryContext(c.DB.GetCtx(), sql, args...) mTime2 := gtime.TimestampMilli() s := &Sql{ Sql: sql, @@ -99,7 +98,7 @@ func (c *Core) DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Ro } c.writeSqlToLogger(s) } else { - rows, err = link.QueryContext(ctx, sql, args...) + rows, err = link.QueryContext(c.DB.GetCtx(), sql, args...) } if err == nil { return rows, nil @@ -122,16 +121,12 @@ func (c *Core) Exec(sql string, args ...interface{}) (result sql.Result, err err // DoExec commits the sql string and its arguments to underlying driver // through given link object and returns the execution result. func (c *Core) DoExec(link Link, sql string, args ...interface{}) (result sql.Result, err error) { - ctx := c.DB.GetCtx() - if ctx == nil { - ctx = context.Background() - } sql, args = formatSql(sql, args) sql, args = c.DB.HandleSqlBeforeCommit(link, sql, args) if c.DB.GetDebug() { mTime1 := gtime.TimestampMilli() if !c.DB.GetDryRun() { - result, err = link.ExecContext(ctx, sql, args...) + result, err = link.ExecContext(c.DB.GetCtx(), sql, args...) } else { result = new(SqlResult) } @@ -148,7 +143,7 @@ func (c *Core) DoExec(link Link, sql string, args ...interface{}) (result sql.Re c.writeSqlToLogger(s) } else { if !c.DB.GetDryRun() { - result, err = link.ExecContext(ctx, sql, args...) + result, err = link.ExecContext(c.DB.GetCtx(), sql, args...) } else { result = new(SqlResult) } @@ -183,11 +178,7 @@ func (c *Core) Prepare(sql string, execOnMaster ...bool) (*sql.Stmt, error) { // doPrepare calls prepare function on given link object and returns the statement object. func (c *Core) DoPrepare(link Link, sql string) (*sql.Stmt, error) { - ctx := c.DB.GetCtx() - if ctx == nil { - ctx = context.Background() - } - return link.PrepareContext(ctx, sql) + return link.PrepareContext(c.DB.GetCtx(), sql) } // GetAll queries and returns data records from database. diff --git a/database/gdb/gdb_z_mysql_model_test.go b/database/gdb/gdb_z_mysql_model_test.go index 6e24164bd..c52a13919 100644 --- a/database/gdb/gdb_z_mysql_model_test.go +++ b/database/gdb/gdb_z_mysql_model_test.go @@ -2974,23 +2974,48 @@ func Test_Model_Issue1002(t *testing.T) { t.Assert(v.Int(), 1) }) // where + time.Time arguments, UTC. - t1, _ := time.Parse("2006-01-02 15:04:05", "2020-10-27 19:03:32") - t2, _ := time.Parse("2006-01-02 15:04:05", "2020-10-27 19:03:34") gtest.C(t, func(t *gtest.T) { - v, err := db.Table(table).Fields("id").Where("create_time>? and create_time? and create_time? and create_time? and create_time? and create_time? and create_time? and create_time? and create_time? and create_time if any exception occurs ans passes the exception as an error. func TryCatch(try func(), catch ...func(exception error)) { defer func() { - if e := recover(); e != nil && len(catch) > 0 { - catch[0](fmt.Errorf(`%v`, e)) + if exception := recover(); exception != nil && len(catch) > 0 { + if err, ok := exception.(error); ok { + catch[0](err) + } else { + catch[0](fmt.Errorf(`%v`, exception)) + } } }() try()