improve statement features for package gdb

This commit is contained in:
jianchenma
2021-01-26 11:09:50 +08:00
parent 4c6d9f5eff
commit 0baa52afee
5 changed files with 109 additions and 101 deletions

View File

@ -247,6 +247,9 @@ const (
defaultMaxIdleConnCount = 10 // Max idle connection count in pool.
defaultMaxOpenConnCount = 100 // Max open connection count in pool.
defaultMaxConnLifeTime = 30 * time.Second // Max life time for per connection in pool in seconds.
ctxTimeoutTypeExec = iota
ctxTimeoutTypeQuery
ctxTimeoutTypePrepare
)
var (

View File

@ -61,6 +61,32 @@ func (c *Core) GetCtx() context.Context {
return context.Background()
}
// GetCtxTimeout returns the context and cancel function for specified timeout type.
func (c *Core) GetCtxTimeout(timeoutType int, ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil {
ctx = c.DB.GetCtx()
} else {
ctx = context.WithValue(ctx, "WrappedByGetCtxTimeout", nil)
}
switch timeoutType {
case ctxTimeoutTypeExec:
if c.DB.GetConfig().ExecTimeout > 0 {
return context.WithTimeout(ctx, c.DB.GetConfig().ExecTimeout)
}
case ctxTimeoutTypeQuery:
if c.DB.GetConfig().QueryTimeout > 0 {
return context.WithTimeout(ctx, c.DB.GetConfig().QueryTimeout)
}
case ctxTimeoutTypePrepare:
if c.DB.GetConfig().PrepareTimeout > 0 {
return context.WithTimeout(ctx, c.DB.GetConfig().PrepareTimeout)
}
default:
panic(gerror.Newf("invalid context timeout type: %d", timeoutType))
}
return ctx, func() {}
}
// Master creates and returns a connection from master node if master-slave configured.
// It returns the default connection if master-slave not configured.
func (c *Core) Master() (*sql.DB, error) {
@ -241,9 +267,8 @@ func (c *Core) Prepare(sql string, execOnMaster ...bool) (*Stmt, error) {
func (c *Core) DoPrepare(link Link, sql string) (*Stmt, error) {
ctx := c.DB.GetCtx()
if c.GetConfig().PrepareTimeout > 0 {
var cancelFunc context.CancelFunc
ctx, cancelFunc = context.WithTimeout(ctx, c.GetConfig().PrepareTimeout)
defer cancelFunc()
// DO NOT USE cancel function in prepare statement.
ctx, _ = context.WithTimeout(ctx, c.GetConfig().PrepareTimeout)
}
var (
mTime1 = gtime.TimestampMilli()

View File

@ -9,6 +9,7 @@ package gdb
import (
"context"
"database/sql"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/os/gtime"
)
@ -27,29 +28,47 @@ type Stmt struct {
sql string
}
// ExecContext executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
var cancelFunc context.CancelFunc
ctx, cancelFunc = context.WithCancel(ctx)
defer cancelFunc()
if s.core.DB.GetConfig().ExecTimeout > 0 {
var cancelFuncForTimeout context.CancelFunc
ctx, cancelFuncForTimeout = context.WithTimeout(ctx, s.core.DB.GetConfig().ExecTimeout)
const (
stmtTypeExecContext = "Statement.ExecContext"
stmtTypeQueryContext = "Statement.QueryContext"
stmtTypeQueryRowContext = "Statement.QueryRowContext"
)
// doStmtCommit commits statement according to given `stmtType`.
func (s *Stmt) doStmtCommit(stmtType string, ctx context.Context, args ...interface{}) (result interface{}, err error) {
var (
cancelFuncForTimeout context.CancelFunc
timestampMilli1 = gtime.TimestampMilli()
)
switch stmtType {
case stmtTypeExecContext:
ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeExec, ctx)
defer cancelFuncForTimeout()
result, err = s.Stmt.ExecContext(ctx, args...)
case stmtTypeQueryContext:
ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeQuery, ctx)
defer cancelFuncForTimeout()
result, err = s.Stmt.QueryContext(ctx, args...)
case stmtTypeQueryRowContext:
ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeQuery, ctx)
defer cancelFuncForTimeout()
result = s.Stmt.QueryRowContext(ctx, args...)
default:
panic(gerror.Newf(`invalid stmtType: %s`, stmtType))
}
var (
mTime1 = gtime.TimestampMilli()
result, err = s.Stmt.ExecContext(ctx, args...)
mTime2 = gtime.TimestampMilli()
sqlObj = &Sql{
timestampMilli2 = gtime.TimestampMilli()
sqlObj = &Sql{
Sql: s.sql,
Type: "Statement.ExecContext",
Type: stmtType,
Args: args,
Format: FormatSqlWithArgs(s.sql, args),
Error: err,
Start: mTime1,
End: mTime2,
Start: timestampMilli1,
End: timestampMilli2,
Group: s.core.DB.GetGroup(),
}
)
@ -60,37 +79,24 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result
return result, err
}
// ExecContext executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
result, err := s.doStmtCommit(stmtTypeExecContext, ctx, args...)
if result != nil {
return result.(sql.Result), err
}
return nil, err
}
// QueryContext executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
var cancelFunc context.CancelFunc
ctx, cancelFunc = context.WithCancel(ctx)
defer cancelFunc()
if s.core.DB.GetConfig().QueryTimeout > 0 {
var cancelFuncForTimeout context.CancelFunc
ctx, cancelFuncForTimeout = context.WithTimeout(ctx, s.core.DB.GetConfig().QueryTimeout)
defer cancelFuncForTimeout()
result, err := s.doStmtCommit(stmtTypeQueryContext, ctx, args...)
if result != nil {
return result.(*sql.Rows), err
}
var (
mTime1 = gtime.TimestampMilli()
rows, err = s.Stmt.QueryContext(ctx, args...)
mTime2 = gtime.TimestampMilli()
sqlObj = &Sql{
Sql: s.sql,
Type: "Statement.QueryContext",
Args: args,
Format: FormatSqlWithArgs(s.sql, args),
Error: err,
Start: mTime1,
End: mTime2,
Group: s.core.DB.GetGroup(),
}
)
s.core.addSqlToTracing(ctx, sqlObj)
if s.core.DB.GetDebug() {
s.core.writeSqlToLogger(sqlObj)
}
return rows, err
return nil, err
}
// QueryRowContext executes a prepared query statement with the given arguments.
@ -100,40 +106,17 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows
// Otherwise, the *Row's Scan scans the first selected row and discards
// the rest.
func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row {
var cancelFunc context.CancelFunc
ctx, cancelFunc = context.WithCancel(ctx)
defer cancelFunc()
if s.core.DB.GetConfig().QueryTimeout > 0 {
var cancelFuncForTimeout context.CancelFunc
ctx, cancelFuncForTimeout = context.WithTimeout(ctx, s.core.DB.GetConfig().QueryTimeout)
defer cancelFuncForTimeout()
result, _ := s.doStmtCommit(stmtTypeQueryRowContext, ctx, args...)
if result != nil {
return result.(*sql.Row)
}
var (
mTime1 = gtime.TimestampMilli()
row = s.Stmt.QueryRowContext(ctx, args...)
mTime2 = gtime.TimestampMilli()
sqlObj = &Sql{
Sql: s.sql,
Type: "Statement.QueryRowContext",
Args: args,
Format: FormatSqlWithArgs(s.sql, args),
Error: nil,
Start: mTime1,
End: mTime2,
Group: s.core.DB.GetGroup(),
}
)
s.core.addSqlToTracing(ctx, sqlObj)
if s.core.DB.GetDebug() {
s.core.writeSqlToLogger(sqlObj)
}
return row
return nil
}
// Exec executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
func (s *Stmt) Exec(args ...interface{}) (sql.Result, error) {
return s.ExecContext(context.Background(), args)
return s.ExecContext(context.Background(), args...)
}
// Query executes a prepared query statement with the given arguments

View File

@ -63,7 +63,7 @@ func Test_DB_Exec(t *testing.T) {
func Test_DB_Prepare(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
st, err := db.Prepare("SELECT 1")
st, err := db.Prepare("SELECT 100")
t.Assert(err, nil)
rows, err := st.Query()
@ -71,10 +71,10 @@ func Test_DB_Prepare(t *testing.T) {
array, err := rows.Columns()
t.Assert(err, nil)
t.Assert(array[0], "1")
t.Assert(array[0], "100")
//err = rows.Close()
//t.Assert(err, nil)
err = rows.Close()
t.Assert(err, nil)
})
}

View File

@ -88,29 +88,26 @@ func Test_TX_Rollback(t *testing.T) {
}
func Test_TX_Prepare(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Error(err)
}
st, err := tx.Prepare("SELECT 100")
if err != nil {
gtest.Error(err)
}
rows, err := st.Query()
if err != nil {
gtest.Error(err)
}
array, err := rows.Columns()
if err != nil {
gtest.Error(err)
}
gtest.Assert(array[0], "100")
if err := rows.Close(); err != nil {
gtest.Error(err)
}
if err := tx.Commit(); err != nil {
gtest.Error(err)
}
gtest.C(t, func(t *gtest.T) {
tx, err := db.Begin()
t.Assert(err, nil)
st, err := tx.Prepare("SELECT 100")
t.Assert(err, nil)
rows, err := st.Query()
t.Assert(err, nil)
array, err := rows.Columns()
t.Assert(err, nil)
t.Assert(array[0], "100")
rows.Close()
t.Assert(err, nil)
tx.Commit()
t.Assert(err, nil)
})
}
func Test_TX_Insert(t *testing.T) {