mirror of
https://gitee.com/johng/gf
synced 2026-06-06 16:21:40 +08:00
improve statement features for package gdb
This commit is contained in:
@ -247,6 +247,9 @@ const (
|
||||
defaultMaxIdleConnCount = 10 // Max idle connection count in pool.
|
||||
defaultMaxOpenConnCount = 100 // Max open connection count in pool.
|
||||
defaultMaxConnLifeTime = 30 * time.Second // Max life time for per connection in pool in seconds.
|
||||
ctxTimeoutTypeExec = iota
|
||||
ctxTimeoutTypeQuery
|
||||
ctxTimeoutTypePrepare
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@ -61,6 +61,32 @@ func (c *Core) GetCtx() context.Context {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// GetCtxTimeout returns the context and cancel function for specified timeout type.
|
||||
func (c *Core) GetCtxTimeout(timeoutType int, ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
ctx = c.DB.GetCtx()
|
||||
} else {
|
||||
ctx = context.WithValue(ctx, "WrappedByGetCtxTimeout", nil)
|
||||
}
|
||||
switch timeoutType {
|
||||
case ctxTimeoutTypeExec:
|
||||
if c.DB.GetConfig().ExecTimeout > 0 {
|
||||
return context.WithTimeout(ctx, c.DB.GetConfig().ExecTimeout)
|
||||
}
|
||||
case ctxTimeoutTypeQuery:
|
||||
if c.DB.GetConfig().QueryTimeout > 0 {
|
||||
return context.WithTimeout(ctx, c.DB.GetConfig().QueryTimeout)
|
||||
}
|
||||
case ctxTimeoutTypePrepare:
|
||||
if c.DB.GetConfig().PrepareTimeout > 0 {
|
||||
return context.WithTimeout(ctx, c.DB.GetConfig().PrepareTimeout)
|
||||
}
|
||||
default:
|
||||
panic(gerror.Newf("invalid context timeout type: %d", timeoutType))
|
||||
}
|
||||
return ctx, func() {}
|
||||
}
|
||||
|
||||
// Master creates and returns a connection from master node if master-slave configured.
|
||||
// It returns the default connection if master-slave not configured.
|
||||
func (c *Core) Master() (*sql.DB, error) {
|
||||
@ -241,9 +267,8 @@ func (c *Core) Prepare(sql string, execOnMaster ...bool) (*Stmt, error) {
|
||||
func (c *Core) DoPrepare(link Link, sql string) (*Stmt, error) {
|
||||
ctx := c.DB.GetCtx()
|
||||
if c.GetConfig().PrepareTimeout > 0 {
|
||||
var cancelFunc context.CancelFunc
|
||||
ctx, cancelFunc = context.WithTimeout(ctx, c.GetConfig().PrepareTimeout)
|
||||
defer cancelFunc()
|
||||
// DO NOT USE cancel function in prepare statement.
|
||||
ctx, _ = context.WithTimeout(ctx, c.GetConfig().PrepareTimeout)
|
||||
}
|
||||
var (
|
||||
mTime1 = gtime.TimestampMilli()
|
||||
|
||||
@ -9,6 +9,7 @@ package gdb
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"github.com/gogf/gf/errors/gerror"
|
||||
"github.com/gogf/gf/os/gtime"
|
||||
)
|
||||
|
||||
@ -27,29 +28,47 @@ type Stmt struct {
|
||||
sql string
|
||||
}
|
||||
|
||||
// ExecContext executes a prepared statement with the given arguments and
|
||||
// returns a Result summarizing the effect of the statement.
|
||||
func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
|
||||
var cancelFunc context.CancelFunc
|
||||
ctx, cancelFunc = context.WithCancel(ctx)
|
||||
defer cancelFunc()
|
||||
if s.core.DB.GetConfig().ExecTimeout > 0 {
|
||||
var cancelFuncForTimeout context.CancelFunc
|
||||
ctx, cancelFuncForTimeout = context.WithTimeout(ctx, s.core.DB.GetConfig().ExecTimeout)
|
||||
const (
|
||||
stmtTypeExecContext = "Statement.ExecContext"
|
||||
stmtTypeQueryContext = "Statement.QueryContext"
|
||||
stmtTypeQueryRowContext = "Statement.QueryRowContext"
|
||||
)
|
||||
|
||||
// doStmtCommit commits statement according to given `stmtType`.
|
||||
func (s *Stmt) doStmtCommit(stmtType string, ctx context.Context, args ...interface{}) (result interface{}, err error) {
|
||||
var (
|
||||
cancelFuncForTimeout context.CancelFunc
|
||||
timestampMilli1 = gtime.TimestampMilli()
|
||||
)
|
||||
switch stmtType {
|
||||
case stmtTypeExecContext:
|
||||
ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeExec, ctx)
|
||||
defer cancelFuncForTimeout()
|
||||
result, err = s.Stmt.ExecContext(ctx, args...)
|
||||
|
||||
case stmtTypeQueryContext:
|
||||
ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeQuery, ctx)
|
||||
defer cancelFuncForTimeout()
|
||||
result, err = s.Stmt.QueryContext(ctx, args...)
|
||||
|
||||
case stmtTypeQueryRowContext:
|
||||
ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeQuery, ctx)
|
||||
defer cancelFuncForTimeout()
|
||||
result = s.Stmt.QueryRowContext(ctx, args...)
|
||||
|
||||
default:
|
||||
panic(gerror.Newf(`invalid stmtType: %s`, stmtType))
|
||||
}
|
||||
var (
|
||||
mTime1 = gtime.TimestampMilli()
|
||||
result, err = s.Stmt.ExecContext(ctx, args...)
|
||||
mTime2 = gtime.TimestampMilli()
|
||||
sqlObj = &Sql{
|
||||
timestampMilli2 = gtime.TimestampMilli()
|
||||
sqlObj = &Sql{
|
||||
Sql: s.sql,
|
||||
Type: "Statement.ExecContext",
|
||||
Type: stmtType,
|
||||
Args: args,
|
||||
Format: FormatSqlWithArgs(s.sql, args),
|
||||
Error: err,
|
||||
Start: mTime1,
|
||||
End: mTime2,
|
||||
Start: timestampMilli1,
|
||||
End: timestampMilli2,
|
||||
Group: s.core.DB.GetGroup(),
|
||||
}
|
||||
)
|
||||
@ -60,37 +79,24 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result
|
||||
return result, err
|
||||
}
|
||||
|
||||
// ExecContext executes a prepared statement with the given arguments and
|
||||
// returns a Result summarizing the effect of the statement.
|
||||
func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
|
||||
result, err := s.doStmtCommit(stmtTypeExecContext, ctx, args...)
|
||||
if result != nil {
|
||||
return result.(sql.Result), err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// QueryContext executes a prepared query statement with the given arguments
|
||||
// and returns the query results as a *Rows.
|
||||
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
|
||||
var cancelFunc context.CancelFunc
|
||||
ctx, cancelFunc = context.WithCancel(ctx)
|
||||
defer cancelFunc()
|
||||
if s.core.DB.GetConfig().QueryTimeout > 0 {
|
||||
var cancelFuncForTimeout context.CancelFunc
|
||||
ctx, cancelFuncForTimeout = context.WithTimeout(ctx, s.core.DB.GetConfig().QueryTimeout)
|
||||
defer cancelFuncForTimeout()
|
||||
result, err := s.doStmtCommit(stmtTypeQueryContext, ctx, args...)
|
||||
if result != nil {
|
||||
return result.(*sql.Rows), err
|
||||
}
|
||||
var (
|
||||
mTime1 = gtime.TimestampMilli()
|
||||
rows, err = s.Stmt.QueryContext(ctx, args...)
|
||||
mTime2 = gtime.TimestampMilli()
|
||||
sqlObj = &Sql{
|
||||
Sql: s.sql,
|
||||
Type: "Statement.QueryContext",
|
||||
Args: args,
|
||||
Format: FormatSqlWithArgs(s.sql, args),
|
||||
Error: err,
|
||||
Start: mTime1,
|
||||
End: mTime2,
|
||||
Group: s.core.DB.GetGroup(),
|
||||
}
|
||||
)
|
||||
s.core.addSqlToTracing(ctx, sqlObj)
|
||||
if s.core.DB.GetDebug() {
|
||||
s.core.writeSqlToLogger(sqlObj)
|
||||
}
|
||||
return rows, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// QueryRowContext executes a prepared query statement with the given arguments.
|
||||
@ -100,40 +106,17 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows
|
||||
// Otherwise, the *Row's Scan scans the first selected row and discards
|
||||
// the rest.
|
||||
func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row {
|
||||
var cancelFunc context.CancelFunc
|
||||
ctx, cancelFunc = context.WithCancel(ctx)
|
||||
defer cancelFunc()
|
||||
if s.core.DB.GetConfig().QueryTimeout > 0 {
|
||||
var cancelFuncForTimeout context.CancelFunc
|
||||
ctx, cancelFuncForTimeout = context.WithTimeout(ctx, s.core.DB.GetConfig().QueryTimeout)
|
||||
defer cancelFuncForTimeout()
|
||||
result, _ := s.doStmtCommit(stmtTypeQueryRowContext, ctx, args...)
|
||||
if result != nil {
|
||||
return result.(*sql.Row)
|
||||
}
|
||||
var (
|
||||
mTime1 = gtime.TimestampMilli()
|
||||
row = s.Stmt.QueryRowContext(ctx, args...)
|
||||
mTime2 = gtime.TimestampMilli()
|
||||
sqlObj = &Sql{
|
||||
Sql: s.sql,
|
||||
Type: "Statement.QueryRowContext",
|
||||
Args: args,
|
||||
Format: FormatSqlWithArgs(s.sql, args),
|
||||
Error: nil,
|
||||
Start: mTime1,
|
||||
End: mTime2,
|
||||
Group: s.core.DB.GetGroup(),
|
||||
}
|
||||
)
|
||||
s.core.addSqlToTracing(ctx, sqlObj)
|
||||
if s.core.DB.GetDebug() {
|
||||
s.core.writeSqlToLogger(sqlObj)
|
||||
}
|
||||
return row
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exec executes a prepared statement with the given arguments and
|
||||
// returns a Result summarizing the effect of the statement.
|
||||
func (s *Stmt) Exec(args ...interface{}) (sql.Result, error) {
|
||||
return s.ExecContext(context.Background(), args)
|
||||
return s.ExecContext(context.Background(), args...)
|
||||
}
|
||||
|
||||
// Query executes a prepared query statement with the given arguments
|
||||
|
||||
@ -63,7 +63,7 @@ func Test_DB_Exec(t *testing.T) {
|
||||
|
||||
func Test_DB_Prepare(t *testing.T) {
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
st, err := db.Prepare("SELECT 1")
|
||||
st, err := db.Prepare("SELECT 100")
|
||||
t.Assert(err, nil)
|
||||
|
||||
rows, err := st.Query()
|
||||
@ -71,10 +71,10 @@ func Test_DB_Prepare(t *testing.T) {
|
||||
|
||||
array, err := rows.Columns()
|
||||
t.Assert(err, nil)
|
||||
t.Assert(array[0], "1")
|
||||
t.Assert(array[0], "100")
|
||||
|
||||
//err = rows.Close()
|
||||
//t.Assert(err, nil)
|
||||
err = rows.Close()
|
||||
t.Assert(err, nil)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -88,29 +88,26 @@ func Test_TX_Rollback(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_TX_Prepare(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Error(err)
|
||||
}
|
||||
st, err := tx.Prepare("SELECT 100")
|
||||
if err != nil {
|
||||
gtest.Error(err)
|
||||
}
|
||||
rows, err := st.Query()
|
||||
if err != nil {
|
||||
gtest.Error(err)
|
||||
}
|
||||
array, err := rows.Columns()
|
||||
if err != nil {
|
||||
gtest.Error(err)
|
||||
}
|
||||
gtest.Assert(array[0], "100")
|
||||
if err := rows.Close(); err != nil {
|
||||
gtest.Error(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Error(err)
|
||||
}
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
tx, err := db.Begin()
|
||||
t.Assert(err, nil)
|
||||
|
||||
st, err := tx.Prepare("SELECT 100")
|
||||
t.Assert(err, nil)
|
||||
|
||||
rows, err := st.Query()
|
||||
t.Assert(err, nil)
|
||||
|
||||
array, err := rows.Columns()
|
||||
t.Assert(err, nil)
|
||||
t.Assert(array[0], "100")
|
||||
|
||||
rows.Close()
|
||||
t.Assert(err, nil)
|
||||
|
||||
tx.Commit()
|
||||
t.Assert(err, nil)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_TX_Insert(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user