From b00de2c617e398b0be869e836f47be40e407dae4 Mon Sep 17 00:00:00 2001 From: John Guo Date: Mon, 27 Dec 2021 20:51:26 +0800 Subject: [PATCH] add gdb.DB.DoFilter; improve function gdb.DB.DoCommit for package gdb --- database/gdb/gdb.go | 27 ++++- database/gdb/gdb_core_underlying.go | 162 +++++++++++++++++----------- database/gdb/gdb_driver_mssql.go | 6 +- database/gdb/gdb_driver_mysql.go | 6 +- database/gdb/gdb_driver_oracle.go | 6 +- database/gdb/gdb_driver_pgsql.go | 8 +- database/gdb/gdb_driver_sqlite.go | 6 +- database/gdb/gdb_result.go | 3 + database/gdb/gdb_statement.go | 96 ++++++----------- database/gdb/gdb_z_driver_test.go | 6 +- 10 files changed, 177 insertions(+), 149 deletions(-) diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index c40c40802..687b6a115 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -100,7 +100,8 @@ type DB interface { DoDelete(ctx context.Context, link Link, table string, condition string, args ...interface{}) (result sql.Result, err error) // See Core.DoDelete. DoQuery(ctx context.Context, link Link, sql string, args ...interface{}) (result Result, err error) // See Core.DoQuery. DoExec(ctx context.Context, link Link, sql string, args ...interface{}) (result sql.Result, err error) // See Core.DoExec. - DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) // See Core.DoCommit. + DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) // See Core.DoFilter. + DoCommit(ctx context.Context, in DoCommitInput) (out *DoCommitOutput, err error) // See Core.DoCommit. DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, error) // See Core.DoPrepare. // =========================================================================== @@ -183,6 +184,22 @@ type Core struct { config *ConfigNode // Current config node. } +// DoCommitInput is the input parameters for function DoCommit. +type DoCommitInput struct { + Stmt *sql.Stmt + Link Link + Sql string + Args []interface{} + Type string +} + +// DoCommitOutput is the output parameters for function DoCommit. +type DoCommitOutput struct { + Row *sql.Row // Row is the result of Stmt.QueryRowContext. + Rows *sql.Rows // Rows is the result of query statement. + Result sql.Result // Result is the result of exec statement. +} + // Driver is the interface for integrating sql drivers into package gdb. type Driver interface { // New creates and returns a database object for specified database server. @@ -278,6 +295,14 @@ const ( dbRoleSlave = `slave` ) +const ( + DoCommitTypeExecContext = "ExecContext" + DoCommitTypeQueryContext = "QueryContext" + DoCommitTypeStmtExecContext = "Statement.ExecContext" + DoCommitTypeStmtQueryContext = "Statement.QueryContext" + DoCommitTypeStmtQueryRowContext = "Statement.QueryRowContext" +) + var ( // instances is the management map for instances. instances = gmap.NewStrAnyMap(true) diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index 15035b4cc..cdf3c937c 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -12,6 +12,8 @@ import ( "database/sql" "github.com/gogf/gf/v2/container/gvar" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/internal/intlog" "github.com/gogf/gf/v2/os/gtime" ) @@ -45,40 +47,27 @@ func (c *Core) DoQuery(ctx context.Context, link Link, sql string, args ...inter ctx, _ = context.WithTimeout(ctx, c.GetConfig().QueryTimeout) } - // Link execution. + // Sql filtering. sql, args = formatSql(sql, args) - sql, args, err = c.db.DoCommit(ctx, link, sql, args) + sql, args, err = c.db.DoFilter(ctx, link, sql, args) if err != nil { return nil, err } - - mTime1 := gtime.TimestampMilli() - rows, err := link.QueryContext(ctx, sql, args...) - mTime2 := gtime.TimestampMilli() - if err == nil { - result, err = c.convertRowsToResult(ctx, rows) + // Link execution. + var out *DoCommitOutput + out, err = c.db.DoCommit(ctx, DoCommitInput{ + Link: link, + Sql: sql, + Args: args, + Stmt: nil, + Type: DoCommitTypeQueryContext, + }) + if err != nil { + return nil, err } - sqlObj := &Sql{ - Sql: sql, - Type: sqlTypeQueryContext, - Args: args, - Format: FormatSqlWithArgs(sql, args), - Error: err, - Start: mTime1, - End: mTime2, - Group: c.db.GetGroup(), - IsTransaction: link.IsTransaction(), - RowsAffected: int64(result.Len()), - } - // Tracing and logging. - c.addSqlToTracing(ctx, sqlObj) - if c.db.GetDebug() { - c.writeSqlToLogger(ctx, sqlObj) - } - if err == nil { - return result, nil - } else { - err = formatError(err, sql, args...) + if out != nil { + result, err = c.RowsToResult(ctx, out.Rows) + return result, err } return nil, err } @@ -114,49 +103,96 @@ func (c *Core) DoExec(ctx context.Context, link Link, sql string, args ...interf defer cancelFunc() } - // Link execution. + // Sql filtering. sql, args = formatSql(sql, args) - sql, args, err = c.db.DoCommit(ctx, link, sql, args) + sql, args, err = c.db.DoFilter(ctx, link, sql, args) if err != nil { return nil, err } + // Link execution. + var out *DoCommitOutput + out, err = c.db.DoCommit(ctx, DoCommitInput{ + Link: link, + Sql: sql, + Args: args, + Stmt: nil, + Type: DoCommitTypeExecContext, + }) + if out != nil { + return out.Result, err + } + return nil, err +} - mTime1 := gtime.TimestampMilli() - if !c.db.GetDryRun() { - result, err = link.ExecContext(ctx, sql, args...) - } else { - result = new(SqlResult) - } - mTime2 := gtime.TimestampMilli() - var rowsAffected int64 - if err == nil { - rowsAffected, err = result.RowsAffected() - } - sqlObj := &Sql{ - Sql: sql, - Type: sqlTypeExecContext, - Args: args, - Format: FormatSqlWithArgs(sql, args), - Error: err, - Start: mTime1, - End: mTime2, - Group: c.db.GetGroup(), - IsTransaction: link.IsTransaction(), - RowsAffected: rowsAffected, +// DoFilter is a hook function, which filters the sql and its arguments before it's committed to underlying driver. +// The parameter `link` specifies the current database connection operation object. You can modify the sql +// string `sql` and its arguments `args` as you wish before they're committed to driver. +func (c *Core) DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { + return sql, args, nil +} + +// DoCommit commits current sql and arguments to underlying sql driver. +func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (*DoCommitOutput, error) { + var ( + err error + cancelFuncForTimeout context.CancelFunc + out = &DoCommitOutput{} + timestampMilli1 = gtime.TimestampMilli() + ) + switch in.Type { + case DoCommitTypeExecContext: + if c.db.GetDryRun() { + out.Result = new(SqlResult) + } else { + out.Result, err = in.Link.ExecContext(ctx, in.Sql, in.Args...) + } + + case DoCommitTypeQueryContext: + out.Rows, err = in.Link.QueryContext(ctx, in.Sql, in.Args...) + + case DoCommitTypeStmtExecContext: + ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeExec, ctx) + defer cancelFuncForTimeout() + if c.db.GetDryRun() { + out.Result = new(SqlResult) + } else { + out.Result, err = in.Stmt.ExecContext(ctx, in.Args...) + } + + case DoCommitTypeStmtQueryContext: + ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeQuery, ctx) + defer cancelFuncForTimeout() + out.Rows, err = in.Stmt.QueryContext(ctx, in.Args...) + + case DoCommitTypeStmtQueryRowContext: + ctx, cancelFuncForTimeout = c.GetCtxTimeout(ctxTimeoutTypeQuery, ctx) + defer cancelFuncForTimeout() + out.Row = in.Stmt.QueryRowContext(ctx, in.Args...) + + default: + panic(gerror.NewCodef(gcode.CodeInvalidParameter, `invalid DoCommitType "%s"`, in.Type)) } + + var ( + timestampMilli2 = gtime.TimestampMilli() + sqlObj = &Sql{ + Sql: in.Sql, + Type: in.Type, + Args: in.Args, + Format: FormatSqlWithArgs(in.Sql, in.Args), + Error: err, + Start: timestampMilli1, + End: timestampMilli2, + Group: c.db.GetGroup(), + IsTransaction: in.Link.IsTransaction(), + } + ) // Tracing and logging. c.addSqlToTracing(ctx, sqlObj) if c.db.GetDebug() { c.writeSqlToLogger(ctx, sqlObj) } - return result, formatError(err, sql, args...) -} - -// DoCommit is a hook function, which deals with the sql string before it's committed to underlying driver. -// The parameter `link` specifies the current database connection operation object. You can modify the sql -// string `sql` and its arguments `args` as you wish before they're committed to driver. -func (c *Core) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { - return sql, args, nil + return out, formatError(err, in.Sql, in.Args...) } // Prepare creates a prepared statement for later queries or executions. @@ -239,8 +275,8 @@ func (c *Core) DoPrepare(ctx context.Context, link Link, sql string) (*Stmt, err }, err } -// convertRowsToResult converts underlying data record type sql.Rows to Result type. -func (c *Core) convertRowsToResult(ctx context.Context, rows *sql.Rows) (Result, error) { +// RowsToResult converts underlying data record type sql.Rows to Result type. +func (c *Core) RowsToResult(ctx context.Context, rows *sql.Rows) (Result, error) { if rows == nil { return nil, nil } diff --git a/database/gdb/gdb_driver_mssql.go b/database/gdb/gdb_driver_mssql.go index 85adf7094..59e4ead37 100644 --- a/database/gdb/gdb_driver_mssql.go +++ b/database/gdb/gdb_driver_mssql.go @@ -83,10 +83,10 @@ func (d *DriverMssql) GetChars() (charLeft string, charRight string) { return "\"", "\"" } -// DoCommit deals with the sql string before commits it to underlying sql driver. -func (d *DriverMssql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { +// DoFilter deals with the sql string before commits it to underlying sql driver. +func (d *DriverMssql) DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { defer func() { - newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs) + newSql, newArgs, err = d.Core.DoFilter(ctx, link, newSql, newArgs) }() var index int // Convert placeholder char '?' to string "@px". diff --git a/database/gdb/gdb_driver_mysql.go b/database/gdb/gdb_driver_mysql.go index eafa9210d..9dc455105 100644 --- a/database/gdb/gdb_driver_mysql.go +++ b/database/gdb/gdb_driver_mysql.go @@ -87,9 +87,9 @@ func (d *DriverMysql) GetChars() (charLeft string, charRight string) { return "`", "`" } -// DoCommit handles the sql before posts it to database. -func (d *DriverMysql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { - return d.Core.DoCommit(ctx, link, sql, args) +// DoFilter handles the sql before posts it to database. +func (d *DriverMysql) DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { + return d.Core.DoFilter(ctx, link, sql, args) } // Tables retrieves and returns the tables of current schema. diff --git a/database/gdb/gdb_driver_oracle.go b/database/gdb/gdb_driver_oracle.go index 4ee244aa0..e7567604d 100644 --- a/database/gdb/gdb_driver_oracle.go +++ b/database/gdb/gdb_driver_oracle.go @@ -86,10 +86,10 @@ func (d *DriverOracle) GetChars() (charLeft string, charRight string) { return "\"", "\"" } -// DoCommit deals with the sql string before commits it to underlying sql driver. -func (d *DriverOracle) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { +// DoFilter deals with the sql string before commits it to underlying sql driver. +func (d *DriverOracle) DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { defer func() { - newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs) + newSql, newArgs, err = d.Core.DoFilter(ctx, link, newSql, newArgs) }() var index int diff --git a/database/gdb/gdb_driver_pgsql.go b/database/gdb/gdb_driver_pgsql.go index b1ad84bd7..a581ee0d2 100644 --- a/database/gdb/gdb_driver_pgsql.go +++ b/database/gdb/gdb_driver_pgsql.go @@ -37,7 +37,7 @@ func (d *DriverPgsql) New(core *Core, node *ConfigNode) (DB, error) { }, nil } -// Open creates and returns a underlying sql.DB object for pgsql. +// Open creates and returns an underlying sql.DB object for pgsql. func (d *DriverPgsql) Open(config *ConfigNode) (db *sql.DB, err error) { var ( source string @@ -85,10 +85,10 @@ func (d *DriverPgsql) GetChars() (charLeft string, charRight string) { return "\"", "\"" } -// DoCommit deals with the sql string before commits it to underlying sql driver. -func (d *DriverPgsql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { +// DoFilter deals with the sql string before commits it to underlying sql driver. +func (d *DriverPgsql) DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { defer func() { - newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs) + newSql, newArgs, err = d.Core.DoFilter(ctx, link, newSql, newArgs) }() var index int diff --git a/database/gdb/gdb_driver_sqlite.go b/database/gdb/gdb_driver_sqlite.go index 7f5ed072e..dfe3acd21 100644 --- a/database/gdb/gdb_driver_sqlite.go +++ b/database/gdb/gdb_driver_sqlite.go @@ -73,9 +73,9 @@ func (d *DriverSqlite) GetChars() (charLeft string, charRight string) { return "`", "`" } -// DoCommit deals with the sql string before commits it to underlying sql driver. -func (d *DriverSqlite) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { - return d.Core.DoCommit(ctx, link, sql, args) +// DoFilter deals with the sql string before commits it to underlying sql driver. +func (d *DriverSqlite) DoFilter(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { + return d.Core.DoFilter(ctx, link, sql, args) } // Tables retrieves and returns the tables of current schema. diff --git a/database/gdb/gdb_result.go b/database/gdb/gdb_result.go index 3744bdf6a..985212e9d 100644 --- a/database/gdb/gdb_result.go +++ b/database/gdb/gdb_result.go @@ -44,6 +44,9 @@ func (r *SqlResult) MustGetInsertId() int64 { // driver may support this. // Also, See sql.Result. func (r *SqlResult) RowsAffected() (int64, error) { + if r.result == nil { + return 0, nil + } if r.affected > 0 { return r.affected, nil } diff --git a/database/gdb/gdb_statement.go b/database/gdb/gdb_statement.go index 3db6dd43a..e773cc0d2 100644 --- a/database/gdb/gdb_statement.go +++ b/database/gdb/gdb_statement.go @@ -9,10 +9,6 @@ package gdb import ( "context" "database/sql" - - "github.com/gogf/gf/v2/errors/gcode" - "github.com/gogf/gf/v2/errors/gerror" - "github.com/gogf/gf/v2/os/gtime" ) // Stmt is a prepared statement. @@ -31,65 +27,18 @@ type Stmt struct { sql string } -const ( - stmtTypeExecContext = "Statement.ExecContext" - stmtTypeQueryContext = "Statement.QueryContext" - stmtTypeQueryRowContext = "Statement.QueryRowContext" -) - -// doStmtCommit commits statement according to given `stmtType`. -func (s *Stmt) doStmtCommit(ctx context.Context, stmtType string, args ...interface{}) (result interface{}, err error) { - var ( - cancelFuncForTimeout context.CancelFunc - timestampMilli1 = gtime.TimestampMilli() - ) - switch stmtType { - case stmtTypeExecContext: - ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeExec, ctx) - defer cancelFuncForTimeout() - result, err = s.Stmt.ExecContext(ctx, args...) - - case stmtTypeQueryContext: - ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeQuery, ctx) - defer cancelFuncForTimeout() - result, err = s.Stmt.QueryContext(ctx, args...) - - case stmtTypeQueryRowContext: - ctx, cancelFuncForTimeout = s.core.GetCtxTimeout(ctxTimeoutTypeQuery, ctx) - defer cancelFuncForTimeout() - result = s.Stmt.QueryRowContext(ctx, args...) - - default: - panic(gerror.NewCodef(gcode.CodeInvalidParameter, `invalid stmtType: %s`, stmtType)) - } - var ( - timestampMilli2 = gtime.TimestampMilli() - sqlObj = &Sql{ - Sql: s.sql, - Type: stmtType, - Args: args, - Format: FormatSqlWithArgs(s.sql, args), - Error: err, - Start: timestampMilli1, - End: timestampMilli2, - Group: s.core.db.GetGroup(), - IsTransaction: s.link.IsTransaction(), - } - ) - // Tracing and logging. - s.core.addSqlToTracing(ctx, sqlObj) - if s.core.db.GetDebug() { - s.core.writeSqlToLogger(ctx, sqlObj) - } - return result, err -} - // ExecContext executes a prepared statement with the given arguments and // returns a Result summarizing the effect of the statement. func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { - result, err := s.doStmtCommit(ctx, stmtTypeExecContext, args...) - if result != nil { - return result.(sql.Result), err + out, err := s.core.db.DoCommit(ctx, DoCommitInput{ + Stmt: s.Stmt, + Link: s.link, + Sql: s.sql, + Args: args, + Type: DoCommitTypeStmtExecContext, + }) + if out != nil { + return out.Result, err } return nil, err } @@ -97,9 +46,15 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result // QueryContext executes a prepared query statement with the given arguments // and returns the query results as a *Rows. func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) { - result, err := s.doStmtCommit(ctx, stmtTypeQueryContext, args...) - if result != nil { - return result.(*sql.Rows), err + out, err := s.core.db.DoCommit(ctx, DoCommitInput{ + Stmt: s.Stmt, + Link: s.link, + Sql: s.sql, + Args: args, + Type: DoCommitTypeStmtQueryContext, + }) + if out != nil { + return out.Rows, err } return nil, err } @@ -111,9 +66,18 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*sql.Rows // Otherwise, the *Row's Scan scans the first selected row and discards // the rest. func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *sql.Row { - result, _ := s.doStmtCommit(ctx, stmtTypeQueryRowContext, args...) - if result != nil { - return result.(*sql.Row) + out, err := s.core.db.DoCommit(ctx, DoCommitInput{ + Stmt: s.Stmt, + Link: s.link, + Sql: s.sql, + Args: args, + Type: DoCommitTypeStmtQueryRowContext, + }) + if err != nil { + panic(err) + } + if out != nil { + return out.Row } return nil } diff --git a/database/gdb/gdb_z_driver_test.go b/database/gdb/gdb_z_driver_test.go index 6840d58b4..62fd88193 100644 --- a/database/gdb/gdb_z_driver_test.go +++ b/database/gdb/gdb_z_driver_test.go @@ -41,11 +41,11 @@ func (d *MyDriver) New(core *gdb.Core, node *gdb.ConfigNode) (gdb.DB, error) { }, nil } -// DoCommit handles the sql before posts it to database. +// DoFilter handles the sql before posts it to database. // It here overwrites the same method of gdb.DriverMysql and makes some custom changes. -func (d *MyDriver) DoCommit(ctx context.Context, link gdb.Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { +func (d *MyDriver) DoFilter(ctx context.Context, link gdb.Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) { latestSqlString.Set(sql) - return d.DriverMysql.DoCommit(ctx, link, sql, args) + return d.DriverMysql.DoFilter(ctx, link, sql, args) } func init() {