mirror of
https://gitee.com/johng/gf
synced 2026-06-06 16:21:40 +08:00
improve package gdb
This commit is contained in:
@ -8,112 +8,72 @@ package driver
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/gogf/gf/database/gdb"
|
||||
"github.com/gogf/gf/internal/intlog"
|
||||
"github.com/gogf/gf/text/gstr"
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/os/gtime"
|
||||
)
|
||||
|
||||
// MyDriver is a custom database driver, which is used for testing only.
|
||||
// For simplifying the unit testing case purpose, MyDriver struct inherits the mysql driver
|
||||
// gdb.DriverMysql and overwrites its function HandleSqlBeforeCommit.
|
||||
// So if there's any sql execution, it goes through MyDriver.HandleSqlBeforeCommit firstly and
|
||||
// then gdb.DriverMysql.HandleSqlBeforeCommit.
|
||||
// You can call it sql "HOOK" or "HiJack" as your will.
|
||||
type MyDriver struct {
|
||||
*gdb.Core
|
||||
*gdb.DriverMysql
|
||||
}
|
||||
|
||||
// Open creates and returns a underlying sql.DB object for mysql.
|
||||
func (d *MyDriver) Open(config *gdb.ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if config.LinkInfo != "" {
|
||||
source = config.LinkInfo
|
||||
} else {
|
||||
source = fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%s)/%s?charset=%s&multiStatements=true&parseTime=true&loc=Local",
|
||||
config.User, config.Pass, config.Host, config.Port, config.Name, config.Charset,
|
||||
)
|
||||
}
|
||||
intlog.Printf("Open: %s", source)
|
||||
if db, err := sql.Open("mysql", source); err == nil {
|
||||
return db, nil
|
||||
} else {
|
||||
return nil, err
|
||||
var (
|
||||
// customDriverName is my driver name, which is used for registering.
|
||||
customDriverName = "MyDriver"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// It here registers my custom driver in package initialization function "init".
|
||||
// You can later use this type in the database configuration.
|
||||
if err := gdb.Register(customDriverName, &MyDriver{}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// getChars returns the security char for this type of database.
|
||||
func (d *MyDriver) GetChars() (charLeft string, charRight string) {
|
||||
return "`", "`"
|
||||
// New creates and returns a database object for mysql.
|
||||
// It implements the interface of gdb.Driver for extra database driver installation.
|
||||
func (d *MyDriver) New(core *gdb.Core, node *gdb.ConfigNode) (gdb.DB, error) {
|
||||
return &MyDriver{
|
||||
&gdb.DriverMysql{
|
||||
Core: core,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleSqlBeforeExec handles the sql before posts it to database.
|
||||
func (d *MyDriver) HandleSqlBeforeExec(sql string) string {
|
||||
return sql
|
||||
}
|
||||
|
||||
// Tables retrieves and returns the tables of current schema.
|
||||
func (d *MyDriver) Tables(schema ...string) (tables []string, err error) {
|
||||
var result gdb.Result
|
||||
link, err := d.DB.GetSlave(schema...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err = d.DB.DoGetAll(link, `SHOW TABLES`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, m := range result {
|
||||
for _, v := range m {
|
||||
tables = append(tables, v.String())
|
||||
}
|
||||
// DoQuery commits the sql string and its arguments to underlying driver
|
||||
// through given link object and returns the execution result.
|
||||
func (d *MyDriver) DoQuery(link gdb.Link, sql string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
tsMilli := gtime.TimestampMilli()
|
||||
rows, err = d.DriverMysql.DoQuery(link, sql, args...)
|
||||
if _, err := d.DriverMysql.InsertIgnore("monitor", g.Map{
|
||||
"sql": gdb.FormatSqlWithArgs(sql, args),
|
||||
"cost": gtime.TimestampMilli() - tsMilli,
|
||||
"time": gtime.Now(),
|
||||
"error": err.Error(),
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// gdb.TableFields retrieves and returns the fields information of specified table of current schema.
|
||||
//
|
||||
// Note that it returns a map containing the field name and its corresponding fields.
|
||||
// As a map is unsorted, the gdb.TableField struct has a "Index" field marks its sequence in the fields.
|
||||
//
|
||||
// It's using cache feature to enhance the performance, which is never expired util the process restarts.
|
||||
func (d *MyDriver) TableFields(table string, schema ...string) (fields map[string]*gdb.TableField, err error) {
|
||||
table = gstr.Trim(table)
|
||||
if gstr.Contains(table, " ") {
|
||||
panic("function gdb.TableFields supports only single table operations")
|
||||
}
|
||||
checkSchema := d.DB.GetSchema()
|
||||
if len(schema) > 0 && schema[0] != "" {
|
||||
checkSchema = schema[0]
|
||||
}
|
||||
v := d.DB.GetCache().GetOrSetFunc(
|
||||
fmt.Sprintf(`mysql_table_fields_%s_%s`, table, checkSchema),
|
||||
func() interface{} {
|
||||
var result gdb.Result
|
||||
var link *sql.DB
|
||||
link, err = d.DB.GetSlave(checkSchema)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
result, err = d.DB.DoGetAll(
|
||||
link,
|
||||
fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.DB.QuoteWord(table)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
fields = make(map[string]*gdb.TableField)
|
||||
for i, m := range result {
|
||||
fields[m["Field"].String()] = &gdb.TableField{
|
||||
Index: i,
|
||||
Name: m["Field"].String(),
|
||||
Type: m["Type"].String(),
|
||||
Null: m["Null"].Bool(),
|
||||
Key: m["Key"].String(),
|
||||
Default: m["Default"].Val(),
|
||||
Extra: m["Extra"].String(),
|
||||
Comment: m["Comment"].String(),
|
||||
}
|
||||
}
|
||||
return fields
|
||||
}, 0)
|
||||
if err == nil {
|
||||
fields = v.(map[string]*gdb.TableField)
|
||||
// DoExec commits the query string and its arguments to underlying driver
|
||||
// through given link object and returns the execution result.
|
||||
func (d *MyDriver) DoExec(link gdb.Link, sql string, args ...interface{}) (result sql.Result, err error) {
|
||||
tsMilli := gtime.TimestampMilli()
|
||||
result, err = d.DriverMysql.DoExec(link, sql, args...)
|
||||
if _, err := d.DriverMysql.InsertIgnore("monitor", g.Map{
|
||||
"sql": gdb.FormatSqlWithArgs(sql, args),
|
||||
"cost": gtime.TimestampMilli() - tsMilli,
|
||||
"time": gtime.Now(),
|
||||
"error": err.Error(),
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -30,29 +30,29 @@ type DB interface {
|
||||
Open(config *ConfigNode) (*sql.DB, error)
|
||||
|
||||
// Query APIs.
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
Query(sql string, args ...interface{}) (*sql.Rows, error)
|
||||
Exec(sql string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(sql string, execOnMaster ...bool) (*sql.Stmt, error)
|
||||
|
||||
// Internal APIs for CURD, which can be overwrote for custom CURD implements.
|
||||
DoQuery(link Link, query string, args ...interface{}) (rows *sql.Rows, err error)
|
||||
DoGetAll(link Link, query string, args ...interface{}) (result Result, err error)
|
||||
DoExec(link Link, query string, args ...interface{}) (result sql.Result, err error)
|
||||
DoPrepare(link Link, query string) (*sql.Stmt, error)
|
||||
DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Rows, err error)
|
||||
DoGetAll(link Link, sql string, args ...interface{}) (result Result, err error)
|
||||
DoExec(link Link, sql string, args ...interface{}) (result sql.Result, err error)
|
||||
DoPrepare(link Link, sql string) (*sql.Stmt, error)
|
||||
DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error)
|
||||
DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error)
|
||||
DoUpdate(link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error)
|
||||
DoDelete(link Link, table string, condition string, args ...interface{}) (result sql.Result, err error)
|
||||
|
||||
// Query APIs for convenience purpose.
|
||||
GetAll(query string, args ...interface{}) (Result, error)
|
||||
GetOne(query string, args ...interface{}) (Record, error)
|
||||
GetValue(query string, args ...interface{}) (Value, error)
|
||||
GetArray(query string, args ...interface{}) ([]Value, error)
|
||||
GetCount(query string, args ...interface{}) (int, error)
|
||||
GetStruct(objPointer interface{}, query string, args ...interface{}) error
|
||||
GetStructs(objPointerSlice interface{}, query string, args ...interface{}) error
|
||||
GetScan(objPointer interface{}, query string, args ...interface{}) error
|
||||
GetAll(sql string, args ...interface{}) (Result, error)
|
||||
GetOne(sql string, args ...interface{}) (Record, error)
|
||||
GetValue(sql string, args ...interface{}) (Value, error)
|
||||
GetArray(sql string, args ...interface{}) ([]Value, error)
|
||||
GetCount(sql string, args ...interface{}) (int, error)
|
||||
GetStruct(objPointer interface{}, sql string, args ...interface{}) error
|
||||
GetStructs(objPointerSlice interface{}, sql string, args ...interface{}) error
|
||||
GetScan(objPointer interface{}, sql string, args ...interface{}) error
|
||||
|
||||
// Master/Slave specification support.
|
||||
Master() (*sql.DB, error)
|
||||
@ -107,9 +107,9 @@ type DB interface {
|
||||
|
||||
// HandleSqlBeforeCommit is a hook function, which deals with the sql string before
|
||||
// it's committed to underlying driver. The parameter <link> specifies the current
|
||||
// database connection operation object. You can modify the sql string <query> and its
|
||||
// database connection operation object. You can modify the sql string <sql> and its
|
||||
// arguments <args> as you wish before they're committed to driver.
|
||||
HandleSqlBeforeCommit(link Link, query string, args []interface{}) (string, []interface{})
|
||||
HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{})
|
||||
|
||||
// Internal methods.
|
||||
filterFields(schema, table string, data map[string]interface{}) map[string]interface{}
|
||||
@ -161,7 +161,7 @@ type TableField struct {
|
||||
|
||||
// Link is a common database function wrapper interface.
|
||||
type Link interface {
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
Query(sql string, args ...interface{}) (*sql.Rows, error)
|
||||
Exec(sql string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(sql string) (*sql.Stmt, error)
|
||||
}
|
||||
@ -193,6 +193,8 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoRows is alias of sql.ErrNoRows.
|
||||
ErrNoRows = sql.ErrNoRows
|
||||
// instances is the management map for instances.
|
||||
instances = gmap.NewStrAnyMap(true)
|
||||
// driverMap manages all custom registered driver.
|
||||
|
||||
@ -45,75 +45,75 @@ func (c *Core) Slave() (*sql.DB, error) {
|
||||
|
||||
// Query commits one query SQL to underlying driver and returns the execution result.
|
||||
// It is most commonly used for data querying.
|
||||
func (c *Core) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
func (c *Core) Query(sql string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
link, err := c.DB.Slave()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.DB.DoQuery(link, query, args...)
|
||||
return c.DB.DoQuery(link, sql, args...)
|
||||
}
|
||||
|
||||
// doQuery commits the query string and its arguments to underlying driver
|
||||
// DoQuery commits the sql string and its arguments to underlying driver
|
||||
// through given link object and returns the execution result.
|
||||
func (c *Core) DoQuery(link Link, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
query, args = formatQuery(query, args)
|
||||
query, args = c.DB.HandleSqlBeforeCommit(link, query, args)
|
||||
func (c *Core) DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
sql, args = formatSql(sql, args)
|
||||
sql, args = c.DB.HandleSqlBeforeCommit(link, sql, args)
|
||||
if c.DB.GetDebug() {
|
||||
mTime1 := gtime.TimestampMilli()
|
||||
rows, err = link.Query(query, args...)
|
||||
rows, err = link.Query(sql, args...)
|
||||
mTime2 := gtime.TimestampMilli()
|
||||
s := &Sql{
|
||||
Sql: query,
|
||||
Sql: sql,
|
||||
Args: args,
|
||||
Format: bindArgsToQuery(query, args),
|
||||
Format: FormatSqlWithArgs(sql, args),
|
||||
Error: err,
|
||||
Start: mTime1,
|
||||
End: mTime2,
|
||||
}
|
||||
c.writeSqlToLogger(s)
|
||||
} else {
|
||||
rows, err = link.Query(query, args...)
|
||||
rows, err = link.Query(sql, args...)
|
||||
}
|
||||
if err == nil {
|
||||
return rows, nil
|
||||
} else {
|
||||
err = formatError(err, query, args...)
|
||||
err = formatError(err, sql, args...)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Exec commits one query SQL to underlying driver and returns the execution result.
|
||||
// It is most commonly used for data inserting and updating.
|
||||
func (c *Core) Exec(query string, args ...interface{}) (result sql.Result, err error) {
|
||||
func (c *Core) Exec(sql string, args ...interface{}) (result sql.Result, err error) {
|
||||
link, err := c.DB.Master()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.DB.DoExec(link, query, args...)
|
||||
return c.DB.DoExec(link, sql, args...)
|
||||
}
|
||||
|
||||
// doExec commits the query string and its arguments to underlying driver
|
||||
// DoExec commits the sql string and its arguments to underlying driver
|
||||
// through given link object and returns the execution result.
|
||||
func (c *Core) DoExec(link Link, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
query, args = formatQuery(query, args)
|
||||
query, args = c.DB.HandleSqlBeforeCommit(link, query, args)
|
||||
func (c *Core) DoExec(link Link, sql string, args ...interface{}) (result sql.Result, err error) {
|
||||
sql, args = formatSql(sql, args)
|
||||
sql, args = c.DB.HandleSqlBeforeCommit(link, sql, args)
|
||||
if c.DB.GetDebug() {
|
||||
mTime1 := gtime.TimestampMilli()
|
||||
result, err = link.Exec(query, args...)
|
||||
result, err = link.Exec(sql, args...)
|
||||
mTime2 := gtime.TimestampMilli()
|
||||
s := &Sql{
|
||||
Sql: query,
|
||||
Sql: sql,
|
||||
Args: args,
|
||||
Format: bindArgsToQuery(query, args),
|
||||
Format: FormatSqlWithArgs(sql, args),
|
||||
Error: err,
|
||||
Start: mTime1,
|
||||
End: mTime2,
|
||||
}
|
||||
c.writeSqlToLogger(s)
|
||||
} else {
|
||||
result, err = link.Exec(query, args...)
|
||||
result, err = link.Exec(sql, args...)
|
||||
}
|
||||
return result, formatError(err, query, args...)
|
||||
return result, formatError(err, sql, args...)
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement for later queries or executions.
|
||||
@ -124,7 +124,7 @@ func (c *Core) DoExec(link Link, query string, args ...interface{}) (result sql.
|
||||
//
|
||||
// The parameter <execOnMaster> specifies whether executing the sql on master node,
|
||||
// or else it executes the sql on slave node if master-slave configured.
|
||||
func (c *Core) Prepare(query string, execOnMaster ...bool) (*sql.Stmt, error) {
|
||||
func (c *Core) Prepare(sql string, execOnMaster ...bool) (*sql.Stmt, error) {
|
||||
err := (error)(nil)
|
||||
link := (Link)(nil)
|
||||
if len(execOnMaster) > 0 && execOnMaster[0] {
|
||||
@ -136,28 +136,28 @@ func (c *Core) Prepare(query string, execOnMaster ...bool) (*sql.Stmt, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return c.DB.DoPrepare(link, query)
|
||||
return c.DB.DoPrepare(link, sql)
|
||||
}
|
||||
|
||||
// doPrepare calls prepare function on given link object and returns the statement object.
|
||||
func (c *Core) DoPrepare(link Link, query string) (*sql.Stmt, error) {
|
||||
return link.Prepare(query)
|
||||
func (c *Core) DoPrepare(link Link, sql string) (*sql.Stmt, error) {
|
||||
return link.Prepare(sql)
|
||||
}
|
||||
|
||||
// GetAll queries and returns data records from database.
|
||||
func (c *Core) GetAll(query string, args ...interface{}) (Result, error) {
|
||||
return c.DB.DoGetAll(nil, query, args...)
|
||||
func (c *Core) GetAll(sql string, args ...interface{}) (Result, error) {
|
||||
return c.DB.DoGetAll(nil, sql, args...)
|
||||
}
|
||||
|
||||
// doGetAll queries and returns data records from database.
|
||||
func (c *Core) DoGetAll(link Link, query string, args ...interface{}) (result Result, err error) {
|
||||
func (c *Core) DoGetAll(link Link, sql string, args ...interface{}) (result Result, err error) {
|
||||
if link == nil {
|
||||
link, err = c.DB.Slave()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := c.DB.DoQuery(link, query, args...)
|
||||
rows, err := c.DB.DoQuery(link, sql, args...)
|
||||
if err != nil || rows == nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -166,8 +166,8 @@ func (c *Core) DoGetAll(link Link, query string, args ...interface{}) (result Re
|
||||
}
|
||||
|
||||
// GetOne queries and returns one record from database.
|
||||
func (c *Core) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
list, err := c.DB.GetAll(query, args...)
|
||||
func (c *Core) GetOne(sql string, args ...interface{}) (Record, error) {
|
||||
list, err := c.DB.GetAll(sql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -179,8 +179,8 @@ func (c *Core) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
|
||||
// GetArray queries and returns data values as slice from database.
|
||||
// Note that if there're multiple columns in the result, it returns just one column values randomly.
|
||||
func (c *Core) GetArray(query string, args ...interface{}) ([]Value, error) {
|
||||
all, err := c.DB.DoGetAll(nil, query, args...)
|
||||
func (c *Core) GetArray(sql string, args ...interface{}) ([]Value, error) {
|
||||
all, err := c.DB.DoGetAll(nil, sql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -189,26 +189,26 @@ func (c *Core) GetArray(query string, args ...interface{}) ([]Value, error) {
|
||||
|
||||
// GetStruct queries one record from database and converts it to given struct.
|
||||
// The parameter <pointer> should be a pointer to struct.
|
||||
func (c *Core) GetStruct(pointer interface{}, query string, args ...interface{}) error {
|
||||
one, err := c.DB.GetOne(query, args...)
|
||||
func (c *Core) GetStruct(pointer interface{}, sql string, args ...interface{}) error {
|
||||
one, err := c.DB.GetOne(sql, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(one) == 0 {
|
||||
return sql.ErrNoRows
|
||||
return ErrNoRows
|
||||
}
|
||||
return one.Struct(pointer)
|
||||
}
|
||||
|
||||
// GetStructs queries records from database and converts them to given struct.
|
||||
// The parameter <pointer> should be type of struct slice: []struct/[]*struct.
|
||||
func (c *Core) GetStructs(pointer interface{}, query string, args ...interface{}) error {
|
||||
all, err := c.DB.GetAll(query, args...)
|
||||
func (c *Core) GetStructs(pointer interface{}, sql string, args ...interface{}) error {
|
||||
all, err := c.DB.GetAll(sql, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(all) == 0 {
|
||||
return sql.ErrNoRows
|
||||
return ErrNoRows
|
||||
}
|
||||
return all.Structs(pointer)
|
||||
}
|
||||
@ -219,7 +219,7 @@ func (c *Core) GetStructs(pointer interface{}, query string, args ...interface{}
|
||||
// If parameter <pointer> is type of struct pointer, it calls GetStruct internally for
|
||||
// the conversion. If parameter <pointer> is type of slice, it calls GetStructs internally
|
||||
// for conversion.
|
||||
func (c *Core) GetScan(pointer interface{}, query string, args ...interface{}) error {
|
||||
func (c *Core) GetScan(pointer interface{}, sql string, args ...interface{}) error {
|
||||
t := reflect.TypeOf(pointer)
|
||||
k := t.Kind()
|
||||
if k != reflect.Ptr {
|
||||
@ -228,9 +228,9 @@ func (c *Core) GetScan(pointer interface{}, query string, args ...interface{}) e
|
||||
k = t.Elem().Kind()
|
||||
switch k {
|
||||
case reflect.Array, reflect.Slice:
|
||||
return c.DB.GetStructs(pointer, query, args...)
|
||||
return c.DB.GetStructs(pointer, sql, args...)
|
||||
case reflect.Struct:
|
||||
return c.DB.GetStruct(pointer, query, args...)
|
||||
return c.DB.GetStruct(pointer, sql, args...)
|
||||
}
|
||||
return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k)
|
||||
}
|
||||
@ -238,8 +238,8 @@ func (c *Core) GetScan(pointer interface{}, query string, args ...interface{}) e
|
||||
// GetValue queries and returns the field value from database.
|
||||
// The sql should queries only one field from database, or else it returns only one
|
||||
// field of the result.
|
||||
func (c *Core) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
one, err := c.DB.GetOne(query, args...)
|
||||
func (c *Core) GetValue(sql string, args ...interface{}) (Value, error) {
|
||||
one, err := c.DB.GetOne(sql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -250,13 +250,13 @@ func (c *Core) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
}
|
||||
|
||||
// GetCount queries and returns the count from database.
|
||||
func (c *Core) GetCount(query string, args ...interface{}) (int, error) {
|
||||
func (c *Core) GetCount(sql string, args ...interface{}) (int, error) {
|
||||
// If the query fields do not contains function "COUNT",
|
||||
// it replaces the query string and adds the "COUNT" function to the fields.
|
||||
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
|
||||
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
|
||||
// it replaces the sql string and adds the "COUNT" function to the fields.
|
||||
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, sql) {
|
||||
sql, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, sql)
|
||||
}
|
||||
value, err := c.DB.GetValue(query, args...)
|
||||
value, err := c.DB.GetValue(sql, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@ -60,10 +60,10 @@ func (d *DriverMssql) GetChars() (charLeft string, charRight string) {
|
||||
}
|
||||
|
||||
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
|
||||
func (d *DriverMssql) HandleSqlBeforeCommit(link Link, query string, args []interface{}) (string, []interface{}) {
|
||||
func (d *DriverMssql) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) {
|
||||
var index int
|
||||
// Convert place holder char '?' to string "@px".
|
||||
str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string {
|
||||
str, _ := gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
|
||||
index++
|
||||
return fmt.Sprintf("@p%d", index)
|
||||
})
|
||||
|
||||
@ -64,10 +64,10 @@ func (d *DriverOracle) GetChars() (charLeft string, charRight string) {
|
||||
}
|
||||
|
||||
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
|
||||
func (d *DriverOracle) HandleSqlBeforeCommit(link Link, query string, args []interface{}) (string, []interface{}) {
|
||||
func (d *DriverOracle) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) {
|
||||
var index int
|
||||
// Convert place holder char '?' to string ":x".
|
||||
str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string {
|
||||
str, _ := gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
|
||||
index++
|
||||
return fmt.Sprintf(":%d", index)
|
||||
})
|
||||
|
||||
@ -8,7 +8,6 @@ package gdb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gogf/gf/internal/empty"
|
||||
@ -222,11 +221,11 @@ func GetPrimaryKeyCondition(primary string, where ...interface{}) (newWhereCondi
|
||||
return where
|
||||
}
|
||||
|
||||
// formatQuery formats the query string and its arguments before executing.
|
||||
// formatSql formats the sql string and its arguments before executing.
|
||||
// The internal handleArguments function might be called twice during the SQL procedure,
|
||||
// but do not worry about it, it's safe and efficient.
|
||||
func formatQuery(query string, args []interface{}) (newQuery string, newArgs []interface{}) {
|
||||
return handleArguments(query, args)
|
||||
func formatSql(sql string, args []interface{}) (newQuery string, newArgs []interface{}) {
|
||||
return handleArguments(sql, args)
|
||||
}
|
||||
|
||||
// formatWhere formats where statement and its arguments.
|
||||
@ -384,8 +383,8 @@ func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key
|
||||
|
||||
// handleArguments is a nice function which handles the query and its arguments before committing to
|
||||
// underlying driver.
|
||||
func handleArguments(query string, args []interface{}) (newQuery string, newArgs []interface{}) {
|
||||
newQuery = query
|
||||
func handleArguments(sql string, args []interface{}) (newSql string, newArgs []interface{}) {
|
||||
newSql = sql
|
||||
// Handles the slice arguments.
|
||||
if len(args) > 0 {
|
||||
for index, arg := range args {
|
||||
@ -409,12 +408,12 @@ func handleArguments(query string, args []interface{}) (newQuery string, newArgs
|
||||
// It the '?' holder count equals the length of the slice,
|
||||
// it does not implement the arguments splitting logic.
|
||||
// Eg: db.Query("SELECT ?+?", g.Slice{1, 2})
|
||||
if len(args) == 1 && gstr.Count(newQuery, "?") == rv.Len() {
|
||||
if len(args) == 1 && gstr.Count(newSql, "?") == rv.Len() {
|
||||
break
|
||||
}
|
||||
// counter is used to finding the inserting position for the '?' holder.
|
||||
counter := 0
|
||||
newQuery, _ = gregex.ReplaceStringFunc(`\?`, newQuery, func(s string) string {
|
||||
newSql, _ = gregex.ReplaceStringFunc(`\?`, newSql, func(s string) string {
|
||||
counter++
|
||||
if counter == index+1 {
|
||||
return "?" + strings.Repeat(",?", rv.Len()-1)
|
||||
@ -450,19 +449,19 @@ func handleArguments(query string, args []interface{}) (newQuery string, newArgs
|
||||
}
|
||||
|
||||
// formatError customizes and returns the SQL error.
|
||||
func formatError(err error, query string, args ...interface{}) error {
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return errors.New(fmt.Sprintf("%s, %s\n", err.Error(), bindArgsToQuery(query, args)))
|
||||
func formatError(err error, sql string, args ...interface{}) error {
|
||||
if err != nil && err != ErrNoRows {
|
||||
return errors.New(fmt.Sprintf("%s, %s\n", err.Error(), FormatSqlWithArgs(sql, args)))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// bindArgsToQuery binds the arguments to the query string and returns a complete
|
||||
// FormatSqlWithArgs binds the arguments to the sql string and returns a complete
|
||||
// sql string, just for debugging.
|
||||
func bindArgsToQuery(query string, args []interface{}) string {
|
||||
func FormatSqlWithArgs(sql string, args []interface{}) string {
|
||||
index := -1
|
||||
newQuery, _ := gregex.ReplaceStringFunc(
|
||||
`(\?|:\d+|\$\d+|@p\d+)`, query, func(s string) string {
|
||||
`(\?|:\d+|\$\d+|@p\d+)`, sql, func(s string) string {
|
||||
index++
|
||||
if len(args) > index {
|
||||
if args[index] == nil {
|
||||
|
||||
@ -145,19 +145,19 @@ func (m *Model) ForPage(page, limit int) *Model {
|
||||
}
|
||||
|
||||
// getAll does the query from database.
|
||||
func (m *Model) getAll(query string, args ...interface{}) (result Result, err error) {
|
||||
func (m *Model) getAll(sql string, args ...interface{}) (result Result, err error) {
|
||||
cacheKey := ""
|
||||
// Retrieve from cache.
|
||||
if m.cacheEnabled {
|
||||
cacheKey = m.cacheName
|
||||
if len(cacheKey) == 0 {
|
||||
cacheKey = query + "/" + gconv.String(args)
|
||||
cacheKey = sql + "/" + gconv.String(args)
|
||||
}
|
||||
if v := m.db.GetCache().Get(cacheKey); v != nil {
|
||||
return v.(Result), nil
|
||||
}
|
||||
}
|
||||
result, err = m.db.DoGetAll(m.getLink(false), query, m.mergeArguments(args)...)
|
||||
result, err = m.db.DoGetAll(m.getLink(false), sql, m.mergeArguments(args)...)
|
||||
// Cache the result.
|
||||
if len(cacheKey) > 0 && err == nil {
|
||||
if m.cacheDuration < 0 {
|
||||
|
||||
@ -33,14 +33,14 @@ func (tx *TX) Rollback() error {
|
||||
|
||||
// Query does query operation on transaction.
|
||||
// See Core.Query.
|
||||
func (tx *TX) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
return tx.db.DoQuery(tx.tx, query, args...)
|
||||
func (tx *TX) Query(sql string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
return tx.db.DoQuery(tx.tx, sql, args...)
|
||||
}
|
||||
|
||||
// Exec does none query operation on transaction.
|
||||
// See Core.Exec.
|
||||
func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return tx.db.DoExec(tx.tx, query, args...)
|
||||
func (tx *TX) Exec(sql string, args ...interface{}) (sql.Result, error) {
|
||||
return tx.db.DoExec(tx.tx, sql, args...)
|
||||
}
|
||||
|
||||
// Prepare creates a prepared statement for later queries or executions.
|
||||
@ -48,13 +48,13 @@ func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
// returned statement.
|
||||
// The caller must call the statement's Close method
|
||||
// when the statement is no longer needed.
|
||||
func (tx *TX) Prepare(query string) (*sql.Stmt, error) {
|
||||
return tx.db.DoPrepare(tx.tx, query)
|
||||
func (tx *TX) Prepare(sql string) (*sql.Stmt, error) {
|
||||
return tx.db.DoPrepare(tx.tx, sql)
|
||||
}
|
||||
|
||||
// GetAll queries and returns data records from database.
|
||||
func (tx *TX) GetAll(query string, args ...interface{}) (Result, error) {
|
||||
rows, err := tx.Query(query, args...)
|
||||
func (tx *TX) GetAll(sql string, args ...interface{}) (Result, error) {
|
||||
rows, err := tx.Query(sql, args...)
|
||||
if err != nil || rows == nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -63,8 +63,8 @@ func (tx *TX) GetAll(query string, args ...interface{}) (Result, error) {
|
||||
}
|
||||
|
||||
// GetOne queries and returns one record from database.
|
||||
func (tx *TX) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
list, err := tx.GetAll(query, args...)
|
||||
func (tx *TX) GetOne(sql string, args ...interface{}) (Record, error) {
|
||||
list, err := tx.GetAll(sql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -76,8 +76,8 @@ func (tx *TX) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
|
||||
// GetStruct queries one record from database and converts it to given struct.
|
||||
// The parameter <pointer> should be a pointer to struct.
|
||||
func (tx *TX) GetStruct(obj interface{}, query string, args ...interface{}) error {
|
||||
one, err := tx.GetOne(query, args...)
|
||||
func (tx *TX) GetStruct(obj interface{}, sql string, args ...interface{}) error {
|
||||
one, err := tx.GetOne(sql, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -86,8 +86,8 @@ func (tx *TX) GetStruct(obj interface{}, query string, args ...interface{}) erro
|
||||
|
||||
// GetStructs queries records from database and converts them to given struct.
|
||||
// The parameter <pointer> should be type of struct slice: []struct/[]*struct.
|
||||
func (tx *TX) GetStructs(objPointerSlice interface{}, query string, args ...interface{}) error {
|
||||
all, err := tx.GetAll(query, args...)
|
||||
func (tx *TX) GetStructs(objPointerSlice interface{}, sql string, args ...interface{}) error {
|
||||
all, err := tx.GetAll(sql, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -100,7 +100,7 @@ func (tx *TX) GetStructs(objPointerSlice interface{}, query string, args ...inte
|
||||
// If parameter <pointer> is type of struct pointer, it calls GetStruct internally for
|
||||
// the conversion. If parameter <pointer> is type of slice, it calls GetStructs internally
|
||||
// for conversion.
|
||||
func (tx *TX) GetScan(objPointer interface{}, query string, args ...interface{}) error {
|
||||
func (tx *TX) GetScan(objPointer interface{}, sql string, args ...interface{}) error {
|
||||
t := reflect.TypeOf(objPointer)
|
||||
k := t.Kind()
|
||||
if k != reflect.Ptr {
|
||||
@ -109,9 +109,9 @@ func (tx *TX) GetScan(objPointer interface{}, query string, args ...interface{})
|
||||
k = t.Elem().Kind()
|
||||
switch k {
|
||||
case reflect.Array, reflect.Slice:
|
||||
return tx.db.GetStructs(objPointer, query, args...)
|
||||
return tx.db.GetStructs(objPointer, sql, args...)
|
||||
case reflect.Struct:
|
||||
return tx.db.GetStruct(objPointer, query, args...)
|
||||
return tx.db.GetStruct(objPointer, sql, args...)
|
||||
default:
|
||||
return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k)
|
||||
}
|
||||
@ -121,8 +121,8 @@ func (tx *TX) GetScan(objPointer interface{}, query string, args ...interface{})
|
||||
// GetValue queries and returns the field value from database.
|
||||
// The sql should queries only one field from database, or else it returns only one
|
||||
// field of the result.
|
||||
func (tx *TX) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
one, err := tx.GetOne(query, args...)
|
||||
func (tx *TX) GetValue(sql string, args ...interface{}) (Value, error) {
|
||||
one, err := tx.GetOne(sql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -133,11 +133,11 @@ func (tx *TX) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
}
|
||||
|
||||
// GetCount queries and returns the count from database.
|
||||
func (tx *TX) GetCount(query string, args ...interface{}) (int, error) {
|
||||
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
|
||||
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
|
||||
func (tx *TX) GetCount(sql string, args ...interface{}) (int, error) {
|
||||
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, sql) {
|
||||
sql, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, sql)
|
||||
}
|
||||
value, err := tx.GetValue(query, args...)
|
||||
value, err := tx.GetValue(sql, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@ -11,29 +11,29 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_Func_bindArgsToQuery(t *testing.T) {
|
||||
func Test_Func_FormatSqlWithArgs(t *testing.T) {
|
||||
// mysql
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
var s string
|
||||
s = bindArgsToQuery("select * from table where id>=? and sex=?", []interface{}{100, 1})
|
||||
s = FormatSqlWithArgs("select * from table where id>=? and sex=?", []interface{}{100, 1})
|
||||
t.Assert(s, "select * from table where id>=100 and sex=1")
|
||||
})
|
||||
// mssql
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
var s string
|
||||
s = bindArgsToQuery("select * from table where id>=@p1 and sex=@p2", []interface{}{100, 1})
|
||||
s = FormatSqlWithArgs("select * from table where id>=@p1 and sex=@p2", []interface{}{100, 1})
|
||||
t.Assert(s, "select * from table where id>=100 and sex=1")
|
||||
})
|
||||
// pgsql
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
var s string
|
||||
s = bindArgsToQuery("select * from table where id>=$1 and sex=$2", []interface{}{100, 1})
|
||||
s = FormatSqlWithArgs("select * from table where id>=$1 and sex=$2", []interface{}{100, 1})
|
||||
t.Assert(s, "select * from table where id>=100 and sex=1")
|
||||
})
|
||||
// oracle
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
var s string
|
||||
s = bindArgsToQuery("select * from table where id>=:1 and sex=:2", []interface{}{100, 1})
|
||||
s = FormatSqlWithArgs("select * from table where id>=:1 and sex=:2", []interface{}{100, 1})
|
||||
t.Assert(s, "select * from table where id>=100 and sex=1")
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user