improve package gdb

This commit is contained in:
John
2020-03-22 23:26:15 +08:00
parent 75dc1d82c1
commit 63e5a60344
9 changed files with 167 additions and 206 deletions

View File

@ -8,112 +8,72 @@ package driver
import (
"database/sql"
"fmt"
"github.com/gogf/gf/database/gdb"
"github.com/gogf/gf/internal/intlog"
"github.com/gogf/gf/text/gstr"
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/os/gtime"
)
// MyDriver is a custom database driver, which is used for testing only.
// For simplifying the unit testing case purpose, MyDriver struct inherits the mysql driver
// gdb.DriverMysql and overwrites its function HandleSqlBeforeCommit.
// So if there's any sql execution, it goes through MyDriver.HandleSqlBeforeCommit firstly and
// then gdb.DriverMysql.HandleSqlBeforeCommit.
// You can call it sql "HOOK" or "HiJack" as your will.
type MyDriver struct {
*gdb.Core
*gdb.DriverMysql
}
// Open creates and returns a underlying sql.DB object for mysql.
func (d *MyDriver) Open(config *gdb.ConfigNode) (*sql.DB, error) {
var source string
if config.LinkInfo != "" {
source = config.LinkInfo
} else {
source = fmt.Sprintf(
"%s:%s@tcp(%s:%s)/%s?charset=%s&multiStatements=true&parseTime=true&loc=Local",
config.User, config.Pass, config.Host, config.Port, config.Name, config.Charset,
)
}
intlog.Printf("Open: %s", source)
if db, err := sql.Open("mysql", source); err == nil {
return db, nil
} else {
return nil, err
var (
// customDriverName is my driver name, which is used for registering.
customDriverName = "MyDriver"
)
func init() {
// It here registers my custom driver in package initialization function "init".
// You can later use this type in the database configuration.
if err := gdb.Register(customDriverName, &MyDriver{}); err != nil {
panic(err)
}
}
// getChars returns the security char for this type of database.
func (d *MyDriver) GetChars() (charLeft string, charRight string) {
return "`", "`"
// New creates and returns a database object for mysql.
// It implements the interface of gdb.Driver for extra database driver installation.
func (d *MyDriver) New(core *gdb.Core, node *gdb.ConfigNode) (gdb.DB, error) {
return &MyDriver{
&gdb.DriverMysql{
Core: core,
},
}, nil
}
// handleSqlBeforeExec handles the sql before posts it to database.
func (d *MyDriver) HandleSqlBeforeExec(sql string) string {
return sql
}
// Tables retrieves and returns the tables of current schema.
func (d *MyDriver) Tables(schema ...string) (tables []string, err error) {
var result gdb.Result
link, err := d.DB.GetSlave(schema...)
if err != nil {
return nil, err
}
result, err = d.DB.DoGetAll(link, `SHOW TABLES`)
if err != nil {
return
}
for _, m := range result {
for _, v := range m {
tables = append(tables, v.String())
}
// DoQuery commits the sql string and its arguments to underlying driver
// through given link object and returns the execution result.
func (d *MyDriver) DoQuery(link gdb.Link, sql string, args ...interface{}) (rows *sql.Rows, err error) {
tsMilli := gtime.TimestampMilli()
rows, err = d.DriverMysql.DoQuery(link, sql, args...)
if _, err := d.DriverMysql.InsertIgnore("monitor", g.Map{
"sql": gdb.FormatSqlWithArgs(sql, args),
"cost": gtime.TimestampMilli() - tsMilli,
"time": gtime.Now(),
"error": err.Error(),
}); err != nil {
panic(err)
}
return
}
// gdb.TableFields retrieves and returns the fields information of specified table of current schema.
//
// Note that it returns a map containing the field name and its corresponding fields.
// As a map is unsorted, the gdb.TableField struct has a "Index" field marks its sequence in the fields.
//
// It's using cache feature to enhance the performance, which is never expired util the process restarts.
func (d *MyDriver) TableFields(table string, schema ...string) (fields map[string]*gdb.TableField, err error) {
table = gstr.Trim(table)
if gstr.Contains(table, " ") {
panic("function gdb.TableFields supports only single table operations")
}
checkSchema := d.DB.GetSchema()
if len(schema) > 0 && schema[0] != "" {
checkSchema = schema[0]
}
v := d.DB.GetCache().GetOrSetFunc(
fmt.Sprintf(`mysql_table_fields_%s_%s`, table, checkSchema),
func() interface{} {
var result gdb.Result
var link *sql.DB
link, err = d.DB.GetSlave(checkSchema)
if err != nil {
return nil
}
result, err = d.DB.DoGetAll(
link,
fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.DB.QuoteWord(table)),
)
if err != nil {
return nil
}
fields = make(map[string]*gdb.TableField)
for i, m := range result {
fields[m["Field"].String()] = &gdb.TableField{
Index: i,
Name: m["Field"].String(),
Type: m["Type"].String(),
Null: m["Null"].Bool(),
Key: m["Key"].String(),
Default: m["Default"].Val(),
Extra: m["Extra"].String(),
Comment: m["Comment"].String(),
}
}
return fields
}, 0)
if err == nil {
fields = v.(map[string]*gdb.TableField)
// DoExec commits the query string and its arguments to underlying driver
// through given link object and returns the execution result.
func (d *MyDriver) DoExec(link gdb.Link, sql string, args ...interface{}) (result sql.Result, err error) {
tsMilli := gtime.TimestampMilli()
result, err = d.DriverMysql.DoExec(link, sql, args...)
if _, err := d.DriverMysql.InsertIgnore("monitor", g.Map{
"sql": gdb.FormatSqlWithArgs(sql, args),
"cost": gtime.TimestampMilli() - tsMilli,
"time": gtime.Now(),
"error": err.Error(),
}); err != nil {
panic(err)
}
return
}