improve package gdb

This commit is contained in:
John
2020-03-22 23:26:15 +08:00
parent 75dc1d82c1
commit 63e5a60344
9 changed files with 167 additions and 206 deletions

View File

@ -8,112 +8,72 @@ package driver
import (
"database/sql"
"fmt"
"github.com/gogf/gf/database/gdb"
"github.com/gogf/gf/internal/intlog"
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/os/gtime"
)
// MyDriver is a custom database driver, which is used for testing only.
// For simplifying the unit testing case purpose, MyDriver struct inherits the mysql driver
// gdb.DriverMysql and overwrites its function HandleSqlBeforeCommit.
// So if there's any sql execution, it goes through MyDriver.HandleSqlBeforeCommit firstly and
// then gdb.DriverMysql.HandleSqlBeforeCommit.
// You can call it sql "HOOK" or "HiJack" as your will.
type MyDriver struct {
*gdb.Core
*gdb.DriverMysql
}
// Open creates and returns a underlying sql.DB object for mysql.
func (d *MyDriver) Open(config *gdb.ConfigNode) (*sql.DB, error) {
var source string
if config.LinkInfo != "" {
source = config.LinkInfo
} else {
source = fmt.Sprintf(
"%s:%s@tcp(%s:%s)/%s?charset=%s&multiStatements=true&parseTime=true&loc=Local",
config.User, config.Pass, config.Host, config.Port, config.Name, config.Charset,
)
}
intlog.Printf("Open: %s", source)
if db, err := sql.Open("mysql", source); err == nil {
return db, nil
} else {
return nil, err
var (
// customDriverName is my driver name, which is used for registering.
customDriverName = "MyDriver"
)
func init() {
// It here registers my custom driver in package initialization function "init".
// You can later use this type in the database configuration.
if err := gdb.Register(customDriverName, &MyDriver{}); err != nil {
panic(err)
}
}
// getChars returns the security char for this type of database.
func (d *MyDriver) GetChars() (charLeft string, charRight string) {
return "`", "`"
// New creates and returns a database object for mysql.
// It implements the interface of gdb.Driver for extra database driver installation.
func (d *MyDriver) New(core *gdb.Core, node *gdb.ConfigNode) (gdb.DB, error) {
return &MyDriver{
&gdb.DriverMysql{
Core: core,
},
}, nil
}
// handleSqlBeforeExec handles the sql before posts it to database.
func (d *MyDriver) HandleSqlBeforeExec(sql string) string {
return sql
}
// Tables retrieves and returns the tables of current schema.
func (d *MyDriver) Tables(schema ...string) (tables []string, err error) {
var result gdb.Result
link, err := d.DB.GetSlave(schema...)
if err != nil {
return nil, err
}
result, err = d.DB.DoGetAll(link, `SHOW TABLES`)
if err != nil {
return
}
for _, m := range result {
for _, v := range m {
tables = append(tables, v.String())
}
// DoQuery commits the sql string and its arguments to underlying driver
// through given link object and returns the execution result.
func (d *MyDriver) DoQuery(link gdb.Link, sql string, args ...interface{}) (rows *sql.Rows, err error) {
tsMilli := gtime.TimestampMilli()
rows, err = d.DriverMysql.DoQuery(link, sql, args...)
if _, err := d.DriverMysql.InsertIgnore("monitor", g.Map{
"sql": gdb.FormatSqlWithArgs(sql, args),
"cost": gtime.TimestampMilli() - tsMilli,
"time": gtime.Now(),
"error": err.Error(),
}); err != nil {
panic(err)
}
return
}
// gdb.TableFields retrieves and returns the fields information of specified table of current schema.
//
// Note that it returns a map containing the field name and its corresponding fields.
// As a map is unsorted, the gdb.TableField struct has a "Index" field marks its sequence in the fields.
//
// It's using cache feature to enhance the performance, which is never expired util the process restarts.
func (d *MyDriver) TableFields(table string, schema ...string) (fields map[string]*gdb.TableField, err error) {
table = gstr.Trim(table)
if gstr.Contains(table, " ") {
panic("function gdb.TableFields supports only single table operations")
}
checkSchema := d.DB.GetSchema()
if len(schema) > 0 && schema[0] != "" {
checkSchema = schema[0]
}
v := d.DB.GetCache().GetOrSetFunc(
fmt.Sprintf(`mysql_table_fields_%s_%s`, table, checkSchema),
func() interface{} {
var result gdb.Result
var link *sql.DB
link, err = d.DB.GetSlave(checkSchema)
if err != nil {
return nil
}
result, err = d.DB.DoGetAll(
link,
fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.DB.QuoteWord(table)),
)
if err != nil {
return nil
}
fields = make(map[string]*gdb.TableField)
for i, m := range result {
fields[m["Field"].String()] = &gdb.TableField{
Index: i,
Name: m["Field"].String(),
Type: m["Type"].String(),
Null: m["Null"].Bool(),
Key: m["Key"].String(),
Default: m["Default"].Val(),
Extra: m["Extra"].String(),
Comment: m["Comment"].String(),
}
}
return fields
}, 0)
if err == nil {
fields = v.(map[string]*gdb.TableField)
// DoExec commits the query string and its arguments to underlying driver
// through given link object and returns the execution result.
func (d *MyDriver) DoExec(link gdb.Link, sql string, args ...interface{}) (result sql.Result, err error) {
tsMilli := gtime.TimestampMilli()
result, err = d.DriverMysql.DoExec(link, sql, args...)
if _, err := d.DriverMysql.InsertIgnore("monitor", g.Map{
"sql": gdb.FormatSqlWithArgs(sql, args),
"cost": gtime.TimestampMilli() - tsMilli,
"time": gtime.Now(),
"error": err.Error(),
}); err != nil {
panic(err)
}
return
}

View File

@ -30,29 +30,29 @@ type DB interface {
Open(config *ConfigNode) (*sql.DB, error)
// Query APIs.
Query(query string, args ...interface{}) (*sql.Rows, error)
Query(sql string, args ...interface{}) (*sql.Rows, error)
Exec(sql string, args ...interface{}) (sql.Result, error)
Prepare(sql string, execOnMaster ...bool) (*sql.Stmt, error)
// Internal APIs for CURD, which can be overwrote for custom CURD implements.
DoQuery(link Link, query string, args ...interface{}) (rows *sql.Rows, err error)
DoGetAll(link Link, query string, args ...interface{}) (result Result, err error)
DoExec(link Link, query string, args ...interface{}) (result sql.Result, err error)
DoPrepare(link Link, query string) (*sql.Stmt, error)
DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Rows, err error)
DoGetAll(link Link, sql string, args ...interface{}) (result Result, err error)
DoExec(link Link, sql string, args ...interface{}) (result sql.Result, err error)
DoPrepare(link Link, sql string) (*sql.Stmt, error)
DoInsert(link Link, table string, data interface{}, option int, batch ...int) (result sql.Result, err error)
DoBatchInsert(link Link, table string, list interface{}, option int, batch ...int) (result sql.Result, err error)
DoUpdate(link Link, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error)
DoDelete(link Link, table string, condition string, args ...interface{}) (result sql.Result, err error)
// Query APIs for convenience purpose.
GetAll(query string, args ...interface{}) (Result, error)
GetOne(query string, args ...interface{}) (Record, error)
GetValue(query string, args ...interface{}) (Value, error)
GetArray(query string, args ...interface{}) ([]Value, error)
GetCount(query string, args ...interface{}) (int, error)
GetStruct(objPointer interface{}, query string, args ...interface{}) error
GetStructs(objPointerSlice interface{}, query string, args ...interface{}) error
GetScan(objPointer interface{}, query string, args ...interface{}) error
GetAll(sql string, args ...interface{}) (Result, error)
GetOne(sql string, args ...interface{}) (Record, error)
GetValue(sql string, args ...interface{}) (Value, error)
GetArray(sql string, args ...interface{}) ([]Value, error)
GetCount(sql string, args ...interface{}) (int, error)
GetStruct(objPointer interface{}, sql string, args ...interface{}) error
GetStructs(objPointerSlice interface{}, sql string, args ...interface{}) error
GetScan(objPointer interface{}, sql string, args ...interface{}) error
// Master/Slave specification support.
Master() (*sql.DB, error)
@ -107,9 +107,9 @@ type DB interface {
// HandleSqlBeforeCommit is a hook function, which deals with the sql string before
// it's committed to underlying driver. The parameter <link> specifies the current
// database connection operation object. You can modify the sql string <query> and its
// database connection operation object. You can modify the sql string <sql> and its
// arguments <args> as you wish before they're committed to driver.
HandleSqlBeforeCommit(link Link, query string, args []interface{}) (string, []interface{})
HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{})
// Internal methods.
filterFields(schema, table string, data map[string]interface{}) map[string]interface{}
@ -161,7 +161,7 @@ type TableField struct {
// Link is a common database function wrapper interface.
type Link interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
Query(sql string, args ...interface{}) (*sql.Rows, error)
Exec(sql string, args ...interface{}) (sql.Result, error)
Prepare(sql string) (*sql.Stmt, error)
}
@ -193,6 +193,8 @@ const (
)
var (
// ErrNoRows is alias of sql.ErrNoRows.
ErrNoRows = sql.ErrNoRows
// instances is the management map for instances.
instances = gmap.NewStrAnyMap(true)
// driverMap manages all custom registered driver.

View File

@ -45,75 +45,75 @@ func (c *Core) Slave() (*sql.DB, error) {
// Query commits one query SQL to underlying driver and returns the execution result.
// It is most commonly used for data querying.
func (c *Core) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
func (c *Core) Query(sql string, args ...interface{}) (rows *sql.Rows, err error) {
link, err := c.DB.Slave()
if err != nil {
return nil, err
}
return c.DB.DoQuery(link, query, args...)
return c.DB.DoQuery(link, sql, args...)
}
// doQuery commits the query string and its arguments to underlying driver
// DoQuery commits the sql string and its arguments to underlying driver
// through given link object and returns the execution result.
func (c *Core) DoQuery(link Link, query string, args ...interface{}) (rows *sql.Rows, err error) {
query, args = formatQuery(query, args)
query, args = c.DB.HandleSqlBeforeCommit(link, query, args)
func (c *Core) DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) {
sql, args = formatSql(sql, args)
sql, args = c.DB.HandleSqlBeforeCommit(link, sql, args)
if c.DB.GetDebug() {
mTime1 := gtime.TimestampMilli()
rows, err = link.Query(query, args...)
rows, err = link.Query(sql, args...)
mTime2 := gtime.TimestampMilli()
s := &Sql{
Sql: query,
Sql: sql,
Args: args,
Format: bindArgsToQuery(query, args),
Format: FormatSqlWithArgs(sql, args),
Error: err,
Start: mTime1,
End: mTime2,
}
c.writeSqlToLogger(s)
} else {
rows, err = link.Query(query, args...)
rows, err = link.Query(sql, args...)
}
if err == nil {
return rows, nil
} else {
err = formatError(err, query, args...)
err = formatError(err, sql, args...)
}
return nil, err
}
// Exec commits one query SQL to underlying driver and returns the execution result.
// It is most commonly used for data inserting and updating.
func (c *Core) Exec(query string, args ...interface{}) (result sql.Result, err error) {
func (c *Core) Exec(sql string, args ...interface{}) (result sql.Result, err error) {
link, err := c.DB.Master()
if err != nil {
return nil, err
}
return c.DB.DoExec(link, query, args...)
return c.DB.DoExec(link, sql, args...)
}
// doExec commits the query string and its arguments to underlying driver
// DoExec commits the sql string and its arguments to underlying driver
// through given link object and returns the execution result.
func (c *Core) DoExec(link Link, query string, args ...interface{}) (result sql.Result, err error) {
query, args = formatQuery(query, args)
query, args = c.DB.HandleSqlBeforeCommit(link, query, args)
func (c *Core) DoExec(link Link, sql string, args ...interface{}) (result sql.Result, err error) {
sql, args = formatSql(sql, args)
sql, args = c.DB.HandleSqlBeforeCommit(link, sql, args)
if c.DB.GetDebug() {
mTime1 := gtime.TimestampMilli()
result, err = link.Exec(query, args...)
result, err = link.Exec(sql, args...)
mTime2 := gtime.TimestampMilli()
s := &Sql{
Sql: query,
Sql: sql,
Args: args,
Format: bindArgsToQuery(query, args),
Format: FormatSqlWithArgs(sql, args),
Error: err,
Start: mTime1,
End: mTime2,
}
c.writeSqlToLogger(s)
} else {
result, err = link.Exec(query, args...)
result, err = link.Exec(sql, args...)
}
return result, formatError(err, query, args...)
return result, formatError(err, sql, args...)
}
// Prepare creates a prepared statement for later queries or executions.
@ -124,7 +124,7 @@ func (c *Core) DoExec(link Link, query string, args ...interface{}) (result sql.
//
// The parameter <execOnMaster> specifies whether executing the sql on master node,
// or else it executes the sql on slave node if master-slave configured.
func (c *Core) Prepare(query string, execOnMaster ...bool) (*sql.Stmt, error) {
func (c *Core) Prepare(sql string, execOnMaster ...bool) (*sql.Stmt, error) {
err := (error)(nil)
link := (Link)(nil)
if len(execOnMaster) > 0 && execOnMaster[0] {
@ -136,28 +136,28 @@ func (c *Core) Prepare(query string, execOnMaster ...bool) (*sql.Stmt, error) {
return nil, err
}
}
return c.DB.DoPrepare(link, query)
return c.DB.DoPrepare(link, sql)
}
// doPrepare calls prepare function on given link object and returns the statement object.
func (c *Core) DoPrepare(link Link, query string) (*sql.Stmt, error) {
return link.Prepare(query)
func (c *Core) DoPrepare(link Link, sql string) (*sql.Stmt, error) {
return link.Prepare(sql)
}
// GetAll queries and returns data records from database.
func (c *Core) GetAll(query string, args ...interface{}) (Result, error) {
return c.DB.DoGetAll(nil, query, args...)
func (c *Core) GetAll(sql string, args ...interface{}) (Result, error) {
return c.DB.DoGetAll(nil, sql, args...)
}
// doGetAll queries and returns data records from database.
func (c *Core) DoGetAll(link Link, query string, args ...interface{}) (result Result, err error) {
func (c *Core) DoGetAll(link Link, sql string, args ...interface{}) (result Result, err error) {
if link == nil {
link, err = c.DB.Slave()
if err != nil {
return nil, err
}
}
rows, err := c.DB.DoQuery(link, query, args...)
rows, err := c.DB.DoQuery(link, sql, args...)
if err != nil || rows == nil {
return nil, err
}
@ -166,8 +166,8 @@ func (c *Core) DoGetAll(link Link, query string, args ...interface{}) (result Re
}
// GetOne queries and returns one record from database.
func (c *Core) GetOne(query string, args ...interface{}) (Record, error) {
list, err := c.DB.GetAll(query, args...)
func (c *Core) GetOne(sql string, args ...interface{}) (Record, error) {
list, err := c.DB.GetAll(sql, args...)
if err != nil {
return nil, err
}
@ -179,8 +179,8 @@ func (c *Core) GetOne(query string, args ...interface{}) (Record, error) {
// GetArray queries and returns data values as slice from database.
// Note that if there're multiple columns in the result, it returns just one column values randomly.
func (c *Core) GetArray(query string, args ...interface{}) ([]Value, error) {
all, err := c.DB.DoGetAll(nil, query, args...)
func (c *Core) GetArray(sql string, args ...interface{}) ([]Value, error) {
all, err := c.DB.DoGetAll(nil, sql, args...)
if err != nil {
return nil, err
}
@ -189,26 +189,26 @@ func (c *Core) GetArray(query string, args ...interface{}) ([]Value, error) {
// GetStruct queries one record from database and converts it to given struct.
// The parameter <pointer> should be a pointer to struct.
func (c *Core) GetStruct(pointer interface{}, query string, args ...interface{}) error {
one, err := c.DB.GetOne(query, args...)
func (c *Core) GetStruct(pointer interface{}, sql string, args ...interface{}) error {
one, err := c.DB.GetOne(sql, args...)
if err != nil {
return err
}
if len(one) == 0 {
return sql.ErrNoRows
return ErrNoRows
}
return one.Struct(pointer)
}
// GetStructs queries records from database and converts them to given struct.
// The parameter <pointer> should be type of struct slice: []struct/[]*struct.
func (c *Core) GetStructs(pointer interface{}, query string, args ...interface{}) error {
all, err := c.DB.GetAll(query, args...)
func (c *Core) GetStructs(pointer interface{}, sql string, args ...interface{}) error {
all, err := c.DB.GetAll(sql, args...)
if err != nil {
return err
}
if len(all) == 0 {
return sql.ErrNoRows
return ErrNoRows
}
return all.Structs(pointer)
}
@ -219,7 +219,7 @@ func (c *Core) GetStructs(pointer interface{}, query string, args ...interface{}
// If parameter <pointer> is type of struct pointer, it calls GetStruct internally for
// the conversion. If parameter <pointer> is type of slice, it calls GetStructs internally
// for conversion.
func (c *Core) GetScan(pointer interface{}, query string, args ...interface{}) error {
func (c *Core) GetScan(pointer interface{}, sql string, args ...interface{}) error {
t := reflect.TypeOf(pointer)
k := t.Kind()
if k != reflect.Ptr {
@ -228,9 +228,9 @@ func (c *Core) GetScan(pointer interface{}, query string, args ...interface{}) e
k = t.Elem().Kind()
switch k {
case reflect.Array, reflect.Slice:
return c.DB.GetStructs(pointer, query, args...)
return c.DB.GetStructs(pointer, sql, args...)
case reflect.Struct:
return c.DB.GetStruct(pointer, query, args...)
return c.DB.GetStruct(pointer, sql, args...)
}
return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k)
}
@ -238,8 +238,8 @@ func (c *Core) GetScan(pointer interface{}, query string, args ...interface{}) e
// GetValue queries and returns the field value from database.
// The sql should queries only one field from database, or else it returns only one
// field of the result.
func (c *Core) GetValue(query string, args ...interface{}) (Value, error) {
one, err := c.DB.GetOne(query, args...)
func (c *Core) GetValue(sql string, args ...interface{}) (Value, error) {
one, err := c.DB.GetOne(sql, args...)
if err != nil {
return nil, err
}
@ -250,13 +250,13 @@ func (c *Core) GetValue(query string, args ...interface{}) (Value, error) {
}
// GetCount queries and returns the count from database.
func (c *Core) GetCount(query string, args ...interface{}) (int, error) {
func (c *Core) GetCount(sql string, args ...interface{}) (int, error) {
// If the query fields do not contains function "COUNT",
// it replaces the query string and adds the "COUNT" function to the fields.
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
// it replaces the sql string and adds the "COUNT" function to the fields.
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, sql) {
sql, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, sql)
}
value, err := c.DB.GetValue(query, args...)
value, err := c.DB.GetValue(sql, args...)
if err != nil {
return 0, err
}

View File

@ -60,10 +60,10 @@ func (d *DriverMssql) GetChars() (charLeft string, charRight string) {
}
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverMssql) HandleSqlBeforeCommit(link Link, query string, args []interface{}) (string, []interface{}) {
func (d *DriverMssql) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) {
var index int
// Convert place holder char '?' to string "@px".
str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string {
str, _ := gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
index++
return fmt.Sprintf("@p%d", index)
})

View File

@ -64,10 +64,10 @@ func (d *DriverOracle) GetChars() (charLeft string, charRight string) {
}
// HandleSqlBeforeCommit deals with the sql string before commits it to underlying sql driver.
func (d *DriverOracle) HandleSqlBeforeCommit(link Link, query string, args []interface{}) (string, []interface{}) {
func (d *DriverOracle) HandleSqlBeforeCommit(link Link, sql string, args []interface{}) (string, []interface{}) {
var index int
// Convert place holder char '?' to string ":x".
str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string {
str, _ := gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
index++
return fmt.Sprintf(":%d", index)
})

View File

@ -8,7 +8,6 @@ package gdb
import (
"bytes"
"database/sql"
"errors"
"fmt"
"github.com/gogf/gf/internal/empty"
@ -222,11 +221,11 @@ func GetPrimaryKeyCondition(primary string, where ...interface{}) (newWhereCondi
return where
}
// formatQuery formats the query string and its arguments before executing.
// formatSql formats the sql string and its arguments before executing.
// The internal handleArguments function might be called twice during the SQL procedure,
// but do not worry about it, it's safe and efficient.
func formatQuery(query string, args []interface{}) (newQuery string, newArgs []interface{}) {
return handleArguments(query, args)
func formatSql(sql string, args []interface{}) (newQuery string, newArgs []interface{}) {
return handleArguments(sql, args)
}
// formatWhere formats where statement and its arguments.
@ -384,8 +383,8 @@ func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key
// handleArguments is a nice function which handles the query and its arguments before committing to
// underlying driver.
func handleArguments(query string, args []interface{}) (newQuery string, newArgs []interface{}) {
newQuery = query
func handleArguments(sql string, args []interface{}) (newSql string, newArgs []interface{}) {
newSql = sql
// Handles the slice arguments.
if len(args) > 0 {
for index, arg := range args {
@ -409,12 +408,12 @@ func handleArguments(query string, args []interface{}) (newQuery string, newArgs
// It the '?' holder count equals the length of the slice,
// it does not implement the arguments splitting logic.
// Eg: db.Query("SELECT ?+?", g.Slice{1, 2})
if len(args) == 1 && gstr.Count(newQuery, "?") == rv.Len() {
if len(args) == 1 && gstr.Count(newSql, "?") == rv.Len() {
break
}
// counter is used to finding the inserting position for the '?' holder.
counter := 0
newQuery, _ = gregex.ReplaceStringFunc(`\?`, newQuery, func(s string) string {
newSql, _ = gregex.ReplaceStringFunc(`\?`, newSql, func(s string) string {
counter++
if counter == index+1 {
return "?" + strings.Repeat(",?", rv.Len()-1)
@ -450,19 +449,19 @@ func handleArguments(query string, args []interface{}) (newQuery string, newArgs
}
// formatError customizes and returns the SQL error.
func formatError(err error, query string, args ...interface{}) error {
if err != nil && err != sql.ErrNoRows {
return errors.New(fmt.Sprintf("%s, %s\n", err.Error(), bindArgsToQuery(query, args)))
func formatError(err error, sql string, args ...interface{}) error {
if err != nil && err != ErrNoRows {
return errors.New(fmt.Sprintf("%s, %s\n", err.Error(), FormatSqlWithArgs(sql, args)))
}
return err
}
// bindArgsToQuery binds the arguments to the query string and returns a complete
// FormatSqlWithArgs binds the arguments to the sql string and returns a complete
// sql string, just for debugging.
func bindArgsToQuery(query string, args []interface{}) string {
func FormatSqlWithArgs(sql string, args []interface{}) string {
index := -1
newQuery, _ := gregex.ReplaceStringFunc(
`(\?|:\d+|\$\d+|@p\d+)`, query, func(s string) string {
`(\?|:\d+|\$\d+|@p\d+)`, sql, func(s string) string {
index++
if len(args) > index {
if args[index] == nil {

View File

@ -145,19 +145,19 @@ func (m *Model) ForPage(page, limit int) *Model {
}
// getAll does the query from database.
func (m *Model) getAll(query string, args ...interface{}) (result Result, err error) {
func (m *Model) getAll(sql string, args ...interface{}) (result Result, err error) {
cacheKey := ""
// Retrieve from cache.
if m.cacheEnabled {
cacheKey = m.cacheName
if len(cacheKey) == 0 {
cacheKey = query + "/" + gconv.String(args)
cacheKey = sql + "/" + gconv.String(args)
}
if v := m.db.GetCache().Get(cacheKey); v != nil {
return v.(Result), nil
}
}
result, err = m.db.DoGetAll(m.getLink(false), query, m.mergeArguments(args)...)
result, err = m.db.DoGetAll(m.getLink(false), sql, m.mergeArguments(args)...)
// Cache the result.
if len(cacheKey) > 0 && err == nil {
if m.cacheDuration < 0 {

View File

@ -33,14 +33,14 @@ func (tx *TX) Rollback() error {
// Query does query operation on transaction.
// See Core.Query.
func (tx *TX) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
return tx.db.DoQuery(tx.tx, query, args...)
func (tx *TX) Query(sql string, args ...interface{}) (rows *sql.Rows, err error) {
return tx.db.DoQuery(tx.tx, sql, args...)
}
// Exec does none query operation on transaction.
// See Core.Exec.
func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) {
return tx.db.DoExec(tx.tx, query, args...)
func (tx *TX) Exec(sql string, args ...interface{}) (sql.Result, error) {
return tx.db.DoExec(tx.tx, sql, args...)
}
// Prepare creates a prepared statement for later queries or executions.
@ -48,13 +48,13 @@ func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) {
// returned statement.
// The caller must call the statement's Close method
// when the statement is no longer needed.
func (tx *TX) Prepare(query string) (*sql.Stmt, error) {
return tx.db.DoPrepare(tx.tx, query)
func (tx *TX) Prepare(sql string) (*sql.Stmt, error) {
return tx.db.DoPrepare(tx.tx, sql)
}
// GetAll queries and returns data records from database.
func (tx *TX) GetAll(query string, args ...interface{}) (Result, error) {
rows, err := tx.Query(query, args...)
func (tx *TX) GetAll(sql string, args ...interface{}) (Result, error) {
rows, err := tx.Query(sql, args...)
if err != nil || rows == nil {
return nil, err
}
@ -63,8 +63,8 @@ func (tx *TX) GetAll(query string, args ...interface{}) (Result, error) {
}
// GetOne queries and returns one record from database.
func (tx *TX) GetOne(query string, args ...interface{}) (Record, error) {
list, err := tx.GetAll(query, args...)
func (tx *TX) GetOne(sql string, args ...interface{}) (Record, error) {
list, err := tx.GetAll(sql, args...)
if err != nil {
return nil, err
}
@ -76,8 +76,8 @@ func (tx *TX) GetOne(query string, args ...interface{}) (Record, error) {
// GetStruct queries one record from database and converts it to given struct.
// The parameter <pointer> should be a pointer to struct.
func (tx *TX) GetStruct(obj interface{}, query string, args ...interface{}) error {
one, err := tx.GetOne(query, args...)
func (tx *TX) GetStruct(obj interface{}, sql string, args ...interface{}) error {
one, err := tx.GetOne(sql, args...)
if err != nil {
return err
}
@ -86,8 +86,8 @@ func (tx *TX) GetStruct(obj interface{}, query string, args ...interface{}) erro
// GetStructs queries records from database and converts them to given struct.
// The parameter <pointer> should be type of struct slice: []struct/[]*struct.
func (tx *TX) GetStructs(objPointerSlice interface{}, query string, args ...interface{}) error {
all, err := tx.GetAll(query, args...)
func (tx *TX) GetStructs(objPointerSlice interface{}, sql string, args ...interface{}) error {
all, err := tx.GetAll(sql, args...)
if err != nil {
return err
}
@ -100,7 +100,7 @@ func (tx *TX) GetStructs(objPointerSlice interface{}, query string, args ...inte
// If parameter <pointer> is type of struct pointer, it calls GetStruct internally for
// the conversion. If parameter <pointer> is type of slice, it calls GetStructs internally
// for conversion.
func (tx *TX) GetScan(objPointer interface{}, query string, args ...interface{}) error {
func (tx *TX) GetScan(objPointer interface{}, sql string, args ...interface{}) error {
t := reflect.TypeOf(objPointer)
k := t.Kind()
if k != reflect.Ptr {
@ -109,9 +109,9 @@ func (tx *TX) GetScan(objPointer interface{}, query string, args ...interface{})
k = t.Elem().Kind()
switch k {
case reflect.Array, reflect.Slice:
return tx.db.GetStructs(objPointer, query, args...)
return tx.db.GetStructs(objPointer, sql, args...)
case reflect.Struct:
return tx.db.GetStruct(objPointer, query, args...)
return tx.db.GetStruct(objPointer, sql, args...)
default:
return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k)
}
@ -121,8 +121,8 @@ func (tx *TX) GetScan(objPointer interface{}, query string, args ...interface{})
// GetValue queries and returns the field value from database.
// The sql should queries only one field from database, or else it returns only one
// field of the result.
func (tx *TX) GetValue(query string, args ...interface{}) (Value, error) {
one, err := tx.GetOne(query, args...)
func (tx *TX) GetValue(sql string, args ...interface{}) (Value, error) {
one, err := tx.GetOne(sql, args...)
if err != nil {
return nil, err
}
@ -133,11 +133,11 @@ func (tx *TX) GetValue(query string, args ...interface{}) (Value, error) {
}
// GetCount queries and returns the count from database.
func (tx *TX) GetCount(query string, args ...interface{}) (int, error) {
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
func (tx *TX) GetCount(sql string, args ...interface{}) (int, error) {
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, sql) {
sql, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, sql)
}
value, err := tx.GetValue(query, args...)
value, err := tx.GetValue(sql, args...)
if err != nil {
return 0, err
}

View File

@ -11,29 +11,29 @@ import (
"testing"
)
func Test_Func_bindArgsToQuery(t *testing.T) {
func Test_Func_FormatSqlWithArgs(t *testing.T) {
// mysql
gtest.C(t, func(t *gtest.T) {
var s string
s = bindArgsToQuery("select * from table where id>=? and sex=?", []interface{}{100, 1})
s = FormatSqlWithArgs("select * from table where id>=? and sex=?", []interface{}{100, 1})
t.Assert(s, "select * from table where id>=100 and sex=1")
})
// mssql
gtest.C(t, func(t *gtest.T) {
var s string
s = bindArgsToQuery("select * from table where id>=@p1 and sex=@p2", []interface{}{100, 1})
s = FormatSqlWithArgs("select * from table where id>=@p1 and sex=@p2", []interface{}{100, 1})
t.Assert(s, "select * from table where id>=100 and sex=1")
})
// pgsql
gtest.C(t, func(t *gtest.T) {
var s string
s = bindArgsToQuery("select * from table where id>=$1 and sex=$2", []interface{}{100, 1})
s = FormatSqlWithArgs("select * from table where id>=$1 and sex=$2", []interface{}{100, 1})
t.Assert(s, "select * from table where id>=100 and sex=1")
})
// oracle
gtest.C(t, func(t *gtest.T) {
var s string
s = bindArgsToQuery("select * from table where id>=:1 and sex=:2", []interface{}{100, 1})
s = FormatSqlWithArgs("select * from table where id>=:1 and sex=:2", []interface{}{100, 1})
t.Assert(s, "select * from table where id>=100 and sex=1")
})
}