From 0baa52afee5d8ff4a94d13a0b7dd77e4413a0800 Mon Sep 17 00:00:00 2001 From: jianchenma Date: Tue, 26 Jan 2021 11:09:50 +0800 Subject: [PATCH] improve statement features for package gdb --- database/gdb/gdb.go | 3 + database/gdb/gdb_core.go | 31 ++++- database/gdb/gdb_statement.go | 125 ++++++++----------- database/gdb/gdb_z_mysql_method_test.go | 8 +- database/gdb/gdb_z_mysql_transaction_test.go | 43 +++---- 5 files changed, 109 insertions(+), 101 deletions(-) diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index aebb0f5fa..1e3665e2d 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -247,6 +247,9 @@ const ( defaultMaxIdleConnCount = 10 // Max idle connection count in pool. defaultMaxOpenConnCount = 100 // Max open connection count in pool. defaultMaxConnLifeTime = 30 * time.Second // Max life time for per connection in pool in seconds. + ctxTimeoutTypeExec = iota + ctxTimeoutTypeQuery + ctxTimeoutTypePrepare ) var ( diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 1d28ff4eb..797750602 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -61,6 +61,32 @@ func (c *Core) GetCtx() context.Context { return context.Background() } +// GetCtxTimeout returns the context and cancel function for specified timeout type. +func (c *Core) GetCtxTimeout(timeoutType int, ctx context.Context) (context.Context, context.CancelFunc) { + if ctx == nil { + ctx = c.DB.GetCtx() + } else { + ctx = context.WithValue(ctx, "WrappedByGetCtxTimeout", nil) + } + switch timeoutType { + case ctxTimeoutTypeExec: + if c.DB.GetConfig().ExecTimeout > 0 { + return context.WithTimeout(ctx, c.DB.GetConfig().ExecTimeout) + } + case ctxTimeoutTypeQuery: + if c.DB.GetConfig().QueryTimeout > 0 { + return context.WithTimeout(ctx, c.DB.GetConfig().QueryTimeout) + } + case ctxTimeoutTypePrepare: + if c.DB.GetConfig().PrepareTimeout > 0 { + return context.WithTimeout(ctx, c.DB.GetConfig().PrepareTimeout) + } + default: + panic(gerror.Newf("invalid context timeout type: %d", timeoutType)) + } + return ctx, func() {} +} + // Master creates and returns a connection from master node if master-slave configured. // It returns the default connection if master-slave not configured. func (c *Core) Master() (*sql.DB, error) { @@ -241,9 +267,8 @@ func (c *Core) Prepare(sql string, execOnMaster ...bool) (*Stmt, error) { func (c *Core) DoPrepare(link Link, sql string) (*Stmt, error) { ctx := c.DB.GetCtx() if c.GetConfig().PrepareTimeout > 0 { - var cancelFunc context.CancelFunc - ctx, cancelFunc = context.WithTimeout(ctx, c.GetConfig().PrepareTimeout) - defer cancelFunc() + // DO NOT USE cancel function in prepare statement. + ctx, _ = context.WithTimeout(ctx, c.GetConfig().PrepareTimeout) } var ( mTime1 = gtime.TimestampMilli() diff --git a/database/gdb/gdb_statement.go b/database/gdb/gdb_statement.go index 33e74cbd9..d861141fd 100644 --- a/database/gdb/gdb_statement.go +++ b/database/gdb/gdb_statement.go @@ -9,6 +9,7 @@ package gdb import ( "context" "database/sql" + "github.com/gogf/gf/errors/gerror" "github.com/gogf/gf/os/gtime" ) @@ -27,29 +28,47 @@ type Stmt struct { sql string } -// ExecContext executes a prepared statement with the given arguments and -// returns a Result summarizing the effect of the statement. -func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { - var cancelFunc context.CancelFunc - ctx, cancelFunc = context.WithCancel(ctx) - defer cancelFunc() - if s.core.DB.GetConfig().ExecTimeout > 0 { - var cancelFuncForTimeout context.CancelFunc - ctx, cancelFuncForTimeout = context.WithTimeout(ctx, s.core.DB.GetConfig().ExecTimeout) +const ( + stmtTypeExecContext = "Statement.ExecContext" + stmtTypeQueryContext = "Statement.QueryContext" + stmtTypeQueryRowContext = "Statement.QueryRowContext" +) + +// doStmtCommit commits statement according to given `stmtType`. +func (s *Stmt) doStmtCommit(stmtType string, ctx context.Context, args ...interface{}) (result interface{}, err error) { + var ( + cancelFuncForTimeout context.CancelFunc + timestampMilli1 = gtime.TimestampMilli() + ) + switch stmtType { + case stmtTypeExecContext: + ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeExec, ctx) defer cancelFuncForTimeout() + result, err = s.Stmt.ExecContext(ctx, args...) + + case stmtTypeQueryContext: + ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeQuery, ctx) + defer cancelFuncForTimeout() + result, err = s.Stmt.QueryContext(ctx, args...) + + case stmtTypeQueryRowContext: + ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeQuery, ctx) + defer cancelFuncForTimeout() + result = s.Stmt.QueryRowContext(ctx, args...) + + default: + panic(gerror.Newf(`invalid stmtType: %s`, stmtType)) } var ( - mTime1 = gtime.TimestampMilli() - result, err = s.Stmt.ExecContext(ctx, args...) - mTime2 = gtime.TimestampMilli() - sqlObj = &Sql{ + timestampMilli2 = gtime.TimestampMilli() + sqlObj = &Sql{ Sql: s.sql, - Type: "Statement.ExecContext", + Type: stmtType, Args: args, Format: FormatSqlWithArgs(s.sql, args), Error: err, - Start: mTime1, - End: mTime2, + Start: timestampMilli1, + End: timestampMilli2, Group: s.core.DB.GetGroup(), } ) @@ -60,37 +79,24 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result return result, err } +// ExecContext executes a prepared statement with the given arguments and +// returns a Result summarizing the effect of the statement. +func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { + result, err := s.doStmtCommit(stmtTypeExecContext, ctx, args...) + if result != nil { + return result.(sql.Result), err + } + return nil, err +} + // QueryContext executes a prepared query statement with the given arguments // and returns the query results as a *Rows. func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) { - var cancelFunc context.CancelFunc - ctx, cancelFunc = context.WithCancel(ctx) - defer cancelFunc() - if s.core.DB.GetConfig().QueryTimeout > 0 { - var cancelFuncForTimeout context.CancelFunc - ctx, cancelFuncForTimeout = context.WithTimeout(ctx, s.core.DB.GetConfig().QueryTimeout) - defer cancelFuncForTimeout() + result, err := s.doStmtCommit(stmtTypeQueryContext, ctx, args...) + if result != nil { + return result.(*sql.Rows), err } - var ( - mTime1 = gtime.TimestampMilli() - rows, err = s.Stmt.QueryContext(ctx, args...) - mTime2 = gtime.TimestampMilli() - sqlObj = &Sql{ - Sql: s.sql, - Type: "Statement.QueryContext", - Args: args, - Format: FormatSqlWithArgs(s.sql, args), - Error: err, - Start: mTime1, - End: mTime2, - Group: s.core.DB.GetGroup(), - } - ) - s.core.addSqlToTracing(ctx, sqlObj) - if s.core.DB.GetDebug() { - s.core.writeSqlToLogger(sqlObj) - } - return rows, err + return nil, err } // QueryRowContext executes a prepared query statement with the given arguments. @@ -100,40 +106,17 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows // Otherwise, the *Row's Scan scans the first selected row and discards // the rest. func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row { - var cancelFunc context.CancelFunc - ctx, cancelFunc = context.WithCancel(ctx) - defer cancelFunc() - if s.core.DB.GetConfig().QueryTimeout > 0 { - var cancelFuncForTimeout context.CancelFunc - ctx, cancelFuncForTimeout = context.WithTimeout(ctx, s.core.DB.GetConfig().QueryTimeout) - defer cancelFuncForTimeout() + result, _ := s.doStmtCommit(stmtTypeQueryRowContext, ctx, args...) + if result != nil { + return result.(*sql.Row) } - var ( - mTime1 = gtime.TimestampMilli() - row = s.Stmt.QueryRowContext(ctx, args...) - mTime2 = gtime.TimestampMilli() - sqlObj = &Sql{ - Sql: s.sql, - Type: "Statement.QueryRowContext", - Args: args, - Format: FormatSqlWithArgs(s.sql, args), - Error: nil, - Start: mTime1, - End: mTime2, - Group: s.core.DB.GetGroup(), - } - ) - s.core.addSqlToTracing(ctx, sqlObj) - if s.core.DB.GetDebug() { - s.core.writeSqlToLogger(sqlObj) - } - return row + return nil } // Exec executes a prepared statement with the given arguments and // returns a Result summarizing the effect of the statement. func (s *Stmt) Exec(args ...interface{}) (sql.Result, error) { - return s.ExecContext(context.Background(), args) + return s.ExecContext(context.Background(), args...) } // Query executes a prepared query statement with the given arguments diff --git a/database/gdb/gdb_z_mysql_method_test.go b/database/gdb/gdb_z_mysql_method_test.go index 4f45a7f53..295b0d111 100644 --- a/database/gdb/gdb_z_mysql_method_test.go +++ b/database/gdb/gdb_z_mysql_method_test.go @@ -63,7 +63,7 @@ func Test_DB_Exec(t *testing.T) { func Test_DB_Prepare(t *testing.T) { gtest.C(t, func(t *gtest.T) { - st, err := db.Prepare("SELECT 1") + st, err := db.Prepare("SELECT 100") t.Assert(err, nil) rows, err := st.Query() @@ -71,10 +71,10 @@ func Test_DB_Prepare(t *testing.T) { array, err := rows.Columns() t.Assert(err, nil) - t.Assert(array[0], "1") + t.Assert(array[0], "100") - //err = rows.Close() - //t.Assert(err, nil) + err = rows.Close() + t.Assert(err, nil) }) } diff --git a/database/gdb/gdb_z_mysql_transaction_test.go b/database/gdb/gdb_z_mysql_transaction_test.go index 8bd97cdee..be10ab716 100644 --- a/database/gdb/gdb_z_mysql_transaction_test.go +++ b/database/gdb/gdb_z_mysql_transaction_test.go @@ -88,29 +88,26 @@ func Test_TX_Rollback(t *testing.T) { } func Test_TX_Prepare(t *testing.T) { - tx, err := db.Begin() - if err != nil { - gtest.Error(err) - } - st, err := tx.Prepare("SELECT 100") - if err != nil { - gtest.Error(err) - } - rows, err := st.Query() - if err != nil { - gtest.Error(err) - } - array, err := rows.Columns() - if err != nil { - gtest.Error(err) - } - gtest.Assert(array[0], "100") - if err := rows.Close(); err != nil { - gtest.Error(err) - } - if err := tx.Commit(); err != nil { - gtest.Error(err) - } + gtest.C(t, func(t *gtest.T) { + tx, err := db.Begin() + t.Assert(err, nil) + + st, err := tx.Prepare("SELECT 100") + t.Assert(err, nil) + + rows, err := st.Query() + t.Assert(err, nil) + + array, err := rows.Columns() + t.Assert(err, nil) + t.Assert(array[0], "100") + + rows.Close() + t.Assert(err, nil) + + tx.Commit() + t.Assert(err, nil) + }) } func Test_TX_Insert(t *testing.T) {