refract gdb package, add complete unit test cases, almost there

This commit is contained in:
John
2018-12-15 15:50:39 +08:00
parent d5e46f2b42
commit e67aa63a50
36 changed files with 1530 additions and 745 deletions

View File

@ -30,27 +30,39 @@ const (
// 数据库操作接口
type DB interface {
// 建立数据库连接方法(开发者一般不需要直接调用)
Open(config *ConfigNode) (*sql.DB, error)
// SQL操作方法
Query(query string, args ...interface{}) (*sql.Rows, error)
Exec(sql string, args ...interface{}) (sql.Result, error)
Prepare(sql string, execOnMaster...bool) (*sql.Stmt, error)
doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error)
doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error)
doPrepare(link dbLink, query string) (*sql.Stmt, error)
doInsert(link dbLink, table string, data Map, option int) (result sql.Result, err error)
doBatchInsert(link dbLink, table string, list List, batch int, option int) (result sql.Result, err error)
doUpdate(link dbLink, table string, data interface{}, condition interface{}, args ...interface{}) (result sql.Result, err error)
doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error)
// 数据库查询
GetAll(query string, args ...interface{}) (Result, error)
GetOne(query string, args ...interface{}) (Record, error)
GetValue(query string, args ...interface{}) (Value, error)
GetCount(query string, args ...interface{}) (int, error)
GetStruct(obj interface{}, query string, args ...interface{}) error
// Ping
// 创建底层数据库master/slave链接对象
Master() (*sql.DB, error)
Slave() (*sql.DB, error)
// Ping
PingMaster() error
PingSlave() error
// 连接属性设置
SetMaxIdleConns(n int)
SetMaxOpenConns(n int)
SetConnMaxLifetime(n int)
// 开启事务操作
Begin() (*Tx, error)
Begin() (*TX, error)
// 数据表插入/更新/保存操作
Insert(table string, data Map) (sql.Result, error)
@ -72,17 +84,26 @@ type DB interface {
// 设置管理
SetDebug(debug bool)
GetQueriedSqls() []*Sql
PrintQueriedSqls()
SetMaxIdleConns(n int)
SetMaxOpenConns(n int)
SetConnMaxLifetime(n int)
// 内部方法接口
open(c *ConfigNode) (*sql.DB, error)
getCache() (*gcache.Cache)
getChars() (charLeft string, charRight string)
getDebug() bool
putSql(s *Sql)
formatCondition(condition interface{}) (where string)
handleSqlBeforeExec(sql string) string
}
// 执行底层数据库操作的核心接口
type dbLink interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
Exec(sql string, args ...interface{}) (sql.Result, error)
Prepare(sql string) (*sql.Stmt, error)
}
// 数据库链接对象
type dbBase struct {
db DB // 数据库对象
@ -228,36 +249,36 @@ func getConfigNodeByPriority(cg ConfigGroup) *ConfigNode {
}
// 获得底层数据库链接对象
func (db *dbBase) getSqlDb(master bool) (sqlDb *sql.DB, err error) {
func (bs *dbBase) getSqlDb(master bool) (sqlDb *sql.DB, err error) {
// 负载均衡
node, err := getConfigNodeByGroup(db.group, master)
node, err := getConfigNodeByGroup(bs.group, master)
if err != nil {
return nil, err
}
// 检查缓存连接池对象
cacheKey := node.String()
if v := db.cache.Get(cacheKey); v != nil {
if v := bs.cache.Get(cacheKey); v != nil {
return v.(*sql.DB), nil
}
v := db.cache.GetOrSetFuncLock(node.String(), func() interface{} {
sqlDb, err = db.db.open(node)
v := bs.cache.GetOrSetFuncLock(node.String(), func() interface{} {
sqlDb, err = bs.db.Open(node)
if err != nil {
return nil
}
if n := db.maxIdleConnCount.Val(); n > 0 {
if n := bs.maxIdleConnCount.Val(); n > 0 {
sqlDb.SetMaxIdleConns(n)
} else if node.MaxIdleConnCount > 0 {
sqlDb.SetMaxIdleConns(node.MaxIdleConnCount)
}
if n := db.maxOpenConnCount.Val(); n > 0 {
if n := bs.maxOpenConnCount.Val(); n > 0 {
sqlDb.SetMaxOpenConns(n)
} else if node.MaxOpenConnCount > 0 {
sqlDb.SetMaxOpenConns(node.MaxOpenConnCount)
}
if n := db.maxConnLifetime.Val(); n > 0 {
if n := bs.maxConnLifetime.Val(); n > 0 {
sqlDb.SetConnMaxLifetime(time.Duration(n) * time.Second)
} else if node.MaxConnLifetime > 0 {
sqlDb.SetConnMaxLifetime(time.Duration(node.MaxConnLifetime) * time.Second)
@ -271,11 +292,11 @@ func (db *dbBase) getSqlDb(master bool) (sqlDb *sql.DB, err error) {
}
// 创建底层数据库master链接对象
func (db *dbBase) Master() (*sql.DB, error) {
return db.getSqlDb(true)
func (bs *dbBase) Master() (*sql.DB, error) {
return bs.getSqlDb(true)
}
// 创建底层数据库slave链接对象
func (db *dbBase) Slave() (*sql.DB, error) {
return db.getSqlDb(false)
func (bs *dbBase) Slave() (*sql.DB, error) {
return bs.getSqlDb(false)
}

View File

@ -14,7 +14,7 @@ import (
"gitee.com/johng/gf/g/os/gcache"
"gitee.com/johng/gf/g/os/gtime"
"gitee.com/johng/gf/g/util/gconv"
"gitee.com/johng/gf/g/util/gstr"
"gitee.com/johng/gf/g/util/gregex"
"reflect"
"strings"
)
@ -24,13 +24,13 @@ const (
)
// 获取已经执行的SQL列表(仅在debug=true时有效)
func (db *dbBase) GetQueriedSqls() []*Sql {
if db.sqls == nil {
func (bs *dbBase) GetQueriedSqls() []*Sql {
if bs.sqls == nil {
return nil
}
sqls := make([]*Sql, 0)
db.sqls.Prev()
db.sqls.RLockIteratorPrev(func(value interface{}) bool {
bs.sqls.Prev()
bs.sqls.RLockIteratorPrev(func(value interface{}) bool {
if value == nil {
return false
}
@ -41,8 +41,8 @@ func (db *dbBase) GetQueriedSqls() []*Sql {
}
// 打印已经执行的SQL列表(仅在debug=true时有效)
func (db *dbBase) PrintQueriedSqls() {
sqls := db.GetQueriedSqls()
func (bs *dbBase) PrintQueriedSqls() {
sqls := bs.GetQueriedSqls()
for k, v := range sqls {
fmt.Println(len(sqls) - k, ":")
fmt.Println(" Sql :", v.Sql)
@ -55,51 +55,106 @@ func (db *dbBase) PrintQueriedSqls() {
}
// 数据库sql查询操作主要执行查询
func (db *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
link, err := db.Slave()
func (bs *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
link, err := bs.db.Slave()
if err != nil {
return nil,err
}
return doQuery(db.db, link, query, args...)
return bs.db.doQuery(link, query, args...)
}
// 数据库sql查询操作主要执行查询
func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) {
query = bs.db.handleSqlBeforeExec(query)
if bs.db.getDebug() {
mTime1 := gtime.Millisecond()
rows, err = link.Query(query, args...)
mTime2 := gtime.Millisecond()
s := &Sql {
Sql : query,
Args : args,
Error : err,
Start : mTime1,
End : mTime2,
}
bs.sqls.Put(s)
printSql(s)
} else {
rows, err = link.Query(query, args ...)
}
if err == nil {
return rows, nil
} else {
err = formatError(err, query, args...)
}
return nil, err
}
// 执行一条sql并返回执行情况主要用于非查询操作
func (db *dbBase) Exec(query string, args ...interface{}) (result sql.Result, err error) {
link, err := db.Master()
func (bs *dbBase) Exec(query string, args ...interface{}) (result sql.Result, err error) {
link, err := bs.db.Master()
if err != nil {
return nil,err
}
return doExec(db.db, link, query, args...)
return bs.db.doExec(link, query, args...)
}
// 执行一条sql并返回执行情况主要用于非查询操作
func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) {
query = bs.db.handleSqlBeforeExec(query)
if bs.db.getDebug() {
mTime1 := gtime.Millisecond()
result, err = link.Exec(query, args ...)
mTime2 := gtime.Millisecond()
s := &Sql{
Sql : query,
Args : args,
Error : err,
Start : mTime1,
End : mTime2,
}
bs.sqls.Put(s)
printSql(s)
} else {
result, err = link.Exec(query, args ...)
}
return result, formatError(err, query, args...)
}
// SQL预处理执行完成后调用返回值sql.Stmt.Exec完成sql操作; 默认执行在Slave上, 通过第二个参数指定执行在Master上
func (db *dbBase) Prepare(query string, execOnMaster...bool) (*sql.Stmt, error) {
func (bs *dbBase) Prepare(query string, execOnMaster...bool) (*sql.Stmt, error) {
err := (error)(nil)
sqldb := (*sql.DB)(nil)
link := (dbLink)(nil)
if len(execOnMaster) > 0 && execOnMaster[0] {
if sqldb, err = db.Master(); err != nil {
if link, err = bs.db.Master(); err != nil {
return nil, err
}
} else {
if sqldb, err = db.Slave(); err != nil {
if link, err = bs.db.Slave(); err != nil {
return nil, err
}
}
return sqldb.Prepare(query)
return bs.db.doPrepare(link, query)
}
// SQL预处理执行完成后调用返回值sql.Stmt.Exec完成sql操作
func (bs *dbBase) doPrepare(link dbLink, query string) (*sql.Stmt, error) {
return link.Prepare(query)
}
// 数据库查询,获取查询结果集,以列表结构返回
func (db *dbBase) GetAll(query string, args ...interface{}) (Result, error) {
rows, err := db.Query(query, args ...)
func (bs *dbBase) GetAll(query string, args ...interface{}) (Result, error) {
rows, err := bs.Query(query, args ...)
if err != nil || rows == nil {
return nil, err
}
defer rows.Close()
return rowsToResult(rows)
}
// 数据库查询,获取查询结果记录,以关联数组结构返回
func (db *dbBase) GetOne(query string, args ...interface{}) (Record, error) {
list, err := db.GetAll(query, args ...)
func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) {
list, err := bs.GetAll(query, args ...)
if err != nil {
return nil, err
}
@ -110,18 +165,17 @@ func (db *dbBase) GetOne(query string, args ...interface{}) (Record, error) {
}
// 数据库查询获取查询结果记录自动映射数据到给定的struct对象中
func (db *dbBase) GetStruct(obj interface{}, query string, args ...interface{}) error {
one, err := db.GetOne(query, args...)
func (bs *dbBase) GetStruct(obj interface{}, query string, args ...interface{}) error {
one, err := bs.GetOne(query, args...)
if err != nil {
return err
}
return one.ToStruct(obj)
}
// 数据库查询,获取查询字段值
func (db *dbBase) GetValue(query string, args ...interface{}) (Value, error) {
one, err := db.GetOne(query, args ...)
func (bs *dbBase) GetValue(query string, args ...interface{}) (Value, error) {
one, err := bs.GetOne(query, args ...)
if err != nil {
return nil, err
}
@ -132,35 +186,20 @@ func (db *dbBase) GetValue(query string, args ...interface{}) (Value, error) {
}
// 数据库查询,获取查询数量
func (db *dbBase) GetCount(query string, args ...interface{}) (int, error) {
val, err := db.GetValue(query, args ...)
func (bs *dbBase) GetCount(query string, args ...interface{}) (int, error) {
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
}
value, err := bs.GetValue(query, args ...)
if err != nil {
return 0, err
}
return gconv.Int(val), nil
}
// 数据表查询其中tables可以是多个联表查询语句这种查询方式较复杂建议使用链式操作
func (db *dbBase) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) {
s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables)
if condition != nil {
s += fmt.Sprintf("WHERE %s ", db.formatCondition(condition))
}
if len(groupBy) > 0 {
s += fmt.Sprintf("GROUP BY %s ", groupBy)
}
if len(orderBy) > 0 {
s += fmt.Sprintf("ORDER BY %s ", orderBy)
}
if limit > 0 {
s += fmt.Sprintf("LIMIT %d,%d ", first, limit)
}
return db.GetAll(s, args ... )
return value.Int(), nil
}
// ping一下判断或保持数据库链接(master)
func (db *dbBase) PingMaster() error {
if master, err := db.Master(); err != nil {
func (bs *dbBase) PingMaster() error {
if master, err := bs.db.Master(); err != nil {
return err
} else {
return master.Ping()
@ -168,8 +207,8 @@ func (db *dbBase) PingMaster() error {
}
// ping一下判断或保持数据库链接(slave)
func (db *dbBase) PingSlave() error {
if slave, err := db.Slave(); err != nil {
func (bs *dbBase) PingSlave() error {
if slave, err := bs.db.Slave(); err != nil {
return err
} else {
return slave.Ping()
@ -178,13 +217,13 @@ func (db *dbBase) PingSlave() error {
// 事务操作,开启,会返回一个底层的事务操作对象链接如需要嵌套事务,那么可以使用该对象,否则请忽略
// 只有在tx.Commit/tx.Rollback时链接会自动Close
func (db *dbBase) Begin() (*TX, error) {
if master, err := db.Master(); err != nil {
func (bs *dbBase) Begin() (*TX, error) {
if master, err := bs.db.Master(); err != nil {
return nil, err
} else {
if tx, err := master.Begin(); err == nil {
return &TX {
db : db.db,
db : bs.db,
tx : tx,
master : master,
}, nil
@ -194,16 +233,31 @@ func (db *dbBase) Begin() (*TX, error) {
}
}
// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
func (bs *dbBase) Insert(table string, data Map) (sql.Result, error) {
return bs.db.doInsert(nil, table, data, OPTION_INSERT)
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
func (bs *dbBase) Replace(table string, data Map) (sql.Result, error) {
return bs.db.doInsert(nil, table, data, OPTION_REPLACE)
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
func (bs *dbBase) Save(table string, data Map) (sql.Result, error) {
return bs.db.doInsert(nil, table, data, OPTION_SAVE)
}
// insert、replace, save ignore操作
// 0: insert: 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
// 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
// 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
// 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做
func (db *dbBase) insert(table string, data Map, option uint8) (sql.Result, error) {
func (bs *dbBase) doInsert(link dbLink, table string, data Map, option int) (result sql.Result, err error) {
var fields []string
var values []string
var params []interface{}
charl, charr := db.db.getChars()
charl, charr := bs.db.getChars()
for k, v := range data {
fields = append(fields, charl + k + charr)
values = append(values, "?")
@ -223,48 +277,53 @@ func (db *dbBase) insert(table string, data Map, option uint8) (sql.Result, erro
}
updatestr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
}
return db.Exec(
fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
operation, table, strings.Join(fields, ","),
strings.Join(values, ","),
updatestr),
params...
)
if link == nil {
if link, err = bs.db.Master(); err != nil {
return nil, err
}
}
return bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
operation, table, strings.Join(fields, ","),
strings.Join(values, ","), updatestr),
params...)
}
// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
func (db *dbBase) Insert(table string, data Map) (sql.Result, error) {
return db.insert(table, data, OPTION_INSERT)
// CURD操作:批量数据指定批次量写入
func (bs *dbBase) BatchInsert(table string, list List, batch int) (sql.Result, error) {
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_INSERT)
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
func (db *dbBase) Replace(table string, data Map) (sql.Result, error) {
return db.insert(table, data, OPTION_REPLACE)
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
func (bs *dbBase) BatchReplace(table string, list List, batch int) (sql.Result, error) {
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_REPLACE)
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
func (db *dbBase) Save(table string, data Map) (sql.Result, error) {
return db.insert(table, data, OPTION_SAVE)
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
func (bs *dbBase) BatchSave(table string, list List, batch int) (sql.Result, error) {
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_SAVE)
}
// 批量写入数据
func (db *dbBase) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) {
func (bs *dbBase) doBatchInsert(link dbLink, table string, list List, batch int, option int) (result sql.Result, err error) {
var keys []string
var values []string
var bvalues []string
var params []interface{}
var result sql.Result
var size = len(list)
// 判断长度
if size < 1 {
if len(list) < 1 {
return result, errors.New("empty data list")
}
if link == nil {
if link, err = bs.db.Master(); err != nil {
return
}
}
// 首先获取字段名称及记录长度
for k, _ := range list[0] {
keys = append(keys, k)
values = append(values, "?")
}
charl, charr := db.db.getChars()
charl, charr := bs.db.getChars()
keyStr := charl + strings.Join(keys, charl + "," + charr) + charr
valueHolderStr := "(" + strings.Join(values, ",") + ")"
// 操作判断
@ -283,13 +342,13 @@ func (db *dbBase) batchInsert(table string, list List, batch int, option uint8)
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
}
// 构造批量写入数据格式(注意map的遍历是无序的)
for i := 0; i < size; i++ {
for i := 0; i < len(list); i++ {
for _, k := range keys {
params = append(params, list[i][k])
}
bvalues = append(bvalues, valueHolderStr)
if len(bvalues) == batch {
r, err := db.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
r, err := bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
operation, table, keyStr, strings.Join(bvalues, ","),
updatestr),
params...)
@ -303,7 +362,7 @@ func (db *dbBase) batchInsert(table string, list List, batch int, option uint8)
}
// 处理最后不构成指定批量的数据
if len(bvalues) > 0 {
r, err := db.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
r, err := bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
operation, table, keyStr, strings.Join(bvalues, ","),
updatestr),
params...)
@ -315,27 +374,22 @@ func (db *dbBase) batchInsert(table string, list List, batch int, option uint8)
return result, nil
}
// CURD操作:批量数据指定批次量写入
func (db *dbBase) BatchInsert(table string, list List, batch int) (sql.Result, error) {
return db.batchInsert(table, list, batch, OPTION_INSERT)
}
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
func (db *dbBase) BatchReplace(table string, list List, batch int) (sql.Result, error) {
return db.batchInsert(table, list, batch, OPTION_REPLACE)
}
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
func (db *dbBase) BatchSave(table string, list List, batch int) (sql.Result, error) {
return db.batchInsert(table, list, batch, OPTION_SAVE)
// CURD操作:数据更新统一采用sql预处理
// data参数支持字符串或者关联数组类型内部会自行做判断处理
func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
link, err := bs.db.Master()
if err != nil {
return nil, err
}
return bs.db.doUpdate(link, table, data, condition, args ...)
}
// CURD操作:数据更新统一采用sql预处理
// data参数支持字符串或者关联数组类型内部会自行做判断处理
func (db *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
var params []interface{}
var updates string
charl, charr := db.db.getChars()
func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition interface{}, args ...interface{}) (result sql.Result, err error) {
params := ([]interface{})(nil)
updates := ""
charl, charr := bs.db.getChars()
refValue := reflect.ValueOf(data)
if refValue.Kind() == reflect.Map {
var fields []string
@ -351,44 +405,29 @@ func (db *dbBase) Update(table string, data interface{}, condition interface{},
for _, v := range args {
params = append(params, gconv.String(v))
}
return db.Exec(fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, db.formatCondition(condition)), params...)
if link == nil {
if link, err = bs.db.Master(); err != nil {
return nil, err
}
}
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, formatCondition(condition)), params...)
}
// CURD操作:删除数据
func (db *dbBase) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
return db.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", table, db.formatCondition(condition)), args...)
func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
link, err := bs.db.Master()
if err != nil {
return nil, err
}
return bs.db.doDelete(link, table, condition, args ...)
}
// 格式化SQL查询条件
func (db *dbBase) formatCondition(condition interface{}) (where string) {
if reflect.ValueOf(condition).Kind() == reflect.Map {
ks := reflect.ValueOf(condition).MapKeys()
vs := reflect.ValueOf(condition)
for _, k := range ks {
key := gconv.String(k.Interface())
value := gconv.String(vs.MapIndex(k).Interface())
isNum := gstr.IsNumeric(value)
if len(where) > 0 {
where += " AND "
}
if isNum || value == "?" {
where += key + "=" + value
} else {
where += key + "='" + value + "'"
}
}
} else {
where += gconv.String(condition)
}
return
// CURD操作:删除数据
func (bs *dbBase) doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, formatCondition(condition)), args...)
}
// 获得缓存对象
func (db *dbBase) getCache() *gcache.Cache {
return db.cache
func (bs *dbBase) getCache() *gcache.Cache {
return bs.cache
}
// 记录执行的SQL
func (db *dbBase) putSql(s *Sql) {
db.sqls.Put(s)
}

View File

@ -123,19 +123,19 @@ func SetDefaultGroup (groupName string) {
}
// 设置数据库连接池中空闲链接的大小
func (db *dbBase) SetMaxIdleConns(n int) {
db.maxIdleConnCount.Set(n)
func (bs *dbBase) SetMaxIdleConns(n int) {
bs.maxIdleConnCount.Set(n)
}
// 设置数据库连接池最大打开的链接数量
func (db *dbBase) SetMaxOpenConns(n int) {
db.maxOpenConnCount.Set(n)
func (bs *dbBase) SetMaxOpenConns(n int) {
bs.maxOpenConnCount.Set(n)
}
// 设置数据库连接可重复利用的时间,超过该时间则被关闭废弃
// 如果 d <= 0 表示该链接会一直重复利用
func (db *dbBase) SetConnMaxLifetime(n int) {
db.maxConnLifetime.Set(n)
func (bs *dbBase) SetConnMaxLifetime(n int) {
bs.maxConnLifetime.Set(n)
}
// 节点配置转换为字符串
@ -150,14 +150,14 @@ func (node *ConfigNode) String() string {
}
// 是否开启调试服务
func (db *dbBase) SetDebug(debug bool) {
db.debug.Set(debug)
if debug && db.sqls == nil {
db.sqls = gring.New(gDEFAULT_DEBUG_SQL_LENGTH)
func (bs *dbBase) SetDebug(debug bool) {
bs.debug.Set(debug)
if debug && bs.sqls == nil {
bs.sqls = gring.New(gDEFAULT_DEBUG_SQL_LENGTH)
}
}
// 获取是否开启调试服务
func (db *dbBase) getDebug() bool {
return db.debug.Val()
func (bs *dbBase) getDebug() bool {
return bs.debug.Val()
}

View File

@ -13,65 +13,12 @@ import (
"gitee.com/johng/gf/g/container/gvar"
"gitee.com/johng/gf/g/os/glog"
"gitee.com/johng/gf/g/os/gtime"
"gitee.com/johng/gf/g/util/gconv"
"gitee.com/johng/gf/g/util/gstr"
_ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
"strings"
"reflect"
)
type dbLink interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
Exec(sql string, args ...interface{}) (sql.Result, error)
Prepare(sql string) (*sql.Stmt, error)
}
// 数据库sql查询操作主要执行查询
func doQuery(db DB, link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) {
query = db.handleSqlBeforeExec(query)
if db.getDebug() {
mTime1 := gtime.Millisecond()
rows, err = link.Query(query, args...)
mTime2 := gtime.Millisecond()
s := &Sql{
Sql : query,
Args : args,
Error : err,
Start : mTime1,
End : mTime2,
}
db.putSql(s)
printSql(s)
} else {
rows, err = link.Query(query, args ...)
}
if err == nil {
return rows, nil
} else {
err = formatError(err, query, args...)
}
return nil, err
}
// 执行一条sql并返回执行情况主要用于非查询操作
func doExec(db DB, link dbLink, query string, args ...interface{}) (result sql.Result, err error) {
query = db.handleSqlBeforeExec(query)
if db.getDebug() {
mTime1 := gtime.Millisecond()
result, err = link.Exec(query, args ...)
mTime2 := gtime.Millisecond()
s := &Sql{
Sql : query,
Args : args,
Error : err,
Start : mTime1,
End : mTime2,
}
db.putSql(s)
printSql(s)
} else {
result, err = link.Exec(query, args ...)
}
return result, formatError(err, query, args...)
}
// 将数据查询的列表数据*sql.Rows转换为Result类型
func rowsToResult(rows *sql.Rows) (Result, error) {
// 列名称列表
@ -103,34 +50,31 @@ func rowsToResult(rows *sql.Rows) (Result, error) {
return records, nil
}
func formatInsertQuery(db DB, table string, data Map, option uint8) (string, []interface{}) {
var fields []string
var values []string
var params []interface{}
charl, charr := db.getChars()
for k, v := range data {
fields = append(fields, charl + k + charr)
values = append(values, "?")
params = append(params, v)
}
operation := getInsertOperationByOption(option)
updatestr := ""
if option == OPTION_SAVE {
var updates []string
for k, _ := range data {
updates = append(updates,
fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
charl, k, charr,
charl, k, charr,
),
)
// 格式化SQL查询条件
func formatCondition(condition interface{}) (where string) {
if reflect.ValueOf(condition).Kind() == reflect.Map {
ks := reflect.ValueOf(condition).MapKeys()
vs := reflect.ValueOf(condition)
for _, k := range ks {
key := gconv.String(k.Interface())
value := gconv.String(vs.MapIndex(k).Interface())
isNum := gstr.IsNumeric(value)
if len(where) > 0 {
where += " AND "
}
if isNum || value == "?" {
where += key + "=" + value
} else {
where += key + "='" + value + "'"
}
}
updatestr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
} else {
where += gconv.String(condition)
}
return fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
operation, table, strings.Join(fields, ","),
strings.Join(values, ","), updatestr),
params
if len(where) == 0 {
where = "1"
}
return
}
// 打印SQL对象(仅在debug=true时有效)
@ -163,7 +107,7 @@ func formatError(err error, query string, args ...interface{}) error {
}
// 根据insert选项获得操作名称
func getInsertOperationByOption(option uint8) string {
func getInsertOperationByOption(option int) string {
oper := "INSERT"
switch option {
case OPTION_REPLACE:

View File

@ -17,8 +17,8 @@ import (
// 数据库链式操作模型对象
type Model struct {
tx *Tx // 数据库事务对象
db DB // 数据库操作对象
tx *TX // 数据库事务对象
tablesInit string // 初始化Model时的表名称(可以是多个)
tables string // 数据库操作表
fields string // 操作字段
@ -36,9 +36,9 @@ type Model struct {
}
// 链式操作,数据表字段,可支持多个表,以半角逗号连接
func (db *dbBase) Table(tables string) (*Model) {
return &Model{
db : db.db,
func (bs *dbBase) Table(tables string) (*Model) {
return &Model {
db : bs.db,
tablesInit : tables,
tables : tables,
fields : "*",
@ -46,12 +46,12 @@ func (db *dbBase) Table(tables string) (*Model) {
}
// 链式操作,数据表字段,可支持多个表,以半角逗号连接
func (db *dbBase) From(tables string) (*Model) {
return db.Table(tables)
func (bs *dbBase) From(tables string) (*Model) {
return bs.db.Table(tables)
}
// (事务)链式操作,数据表字段,可支持多个表,以半角逗号连接
func (tx *Tx) Table(tables string) (*Model) {
func (tx *TX) Table(tables string) (*Model) {
return &Model{
db : tx.db,
tx : tx,
@ -61,7 +61,7 @@ func (tx *Tx) Table(tables string) (*Model) {
}
// (事务)链式操作,数据表字段,可支持多个表,以半角逗号连接
func (tx *Tx) From(tables string) (*Model) {
func (tx *TX) From(tables string) (*Model) {
return tx.Table(tables)
}
@ -100,7 +100,7 @@ func (md *Model) Fields(fields string) (*Model) {
// 链式操作condition支持string & gdb.Map
func (md *Model) Where(where interface{}, args ...interface{}) (*Model) {
md.where = md.db.formatCondition(where)
md.where = formatCondition(where)
md.whereArgs = append(md.whereArgs, args...)
// 支持 Where("uid", 1)这种格式
if len(args) == 1 && strings.Index(md.where , "?") < 0 {
@ -111,14 +111,14 @@ func (md *Model) Where(where interface{}, args ...interface{}) (*Model) {
// 链式操作添加AND条件到Where中
func (md *Model) And(where interface{}, args ...interface{}) (*Model) {
md.where += " AND " + md.db.formatCondition(where)
md.where += " AND " + formatCondition(where)
md.whereArgs = append(md.whereArgs, args...)
return md
}
// 链式操作添加OR条件到Where中
func (md *Model) Or(where interface{}, args ...interface{}) (*Model) {
md.where += " OR " + md.db.formatCondition(where)
md.where += " OR " + formatCondition(where)
md.whereArgs = append(md.whereArgs, args...)
return md
}
@ -220,9 +220,9 @@ func (md *Model) Replace() (result sql.Result, err error) {
}
} else if dataMap, ok := md.data.(Map); ok {
if md.tx == nil {
return md.db.Insert(md.tables, dataMap)
return md.db.Replace(md.tables, dataMap)
} else {
return md.tx.Insert(md.tables, dataMap)
return md.tx.Replace(md.tables, dataMap)
}
}
return nil, errors.New("replacing into table with invalid data type")
@ -286,9 +286,6 @@ func (md *Model) Delete() (result sql.Result, err error) {
}
md.clear()
}()
if md.where == "" {
return nil, errors.New("where is required while deleting")
}
if md.tx == nil {
return md.db.Delete(md.tables, md.where, md.whereArgs...)
} else {
@ -320,13 +317,13 @@ func (md *Model) Cache(time int, name ... string) *Model {
// 链式操作select
func (md *Model) Select() (Result, error) {
defer md.clear()
return md.getAll(md.getFormattedSql(), md.whereArgs...)
return md.All()
}
// 链式操作,查询所有记录
func (md *Model) All() (Result, error) {
return md.Select()
defer md.clear()
return md.getAll(md.getFormattedSql(), md.whereArgs...)
}
// 链式操作,查询单条记录

View File

@ -28,12 +28,13 @@ type dbMssql struct {
}
// 创建SQL操作对象
func (db *dbMssql) open(c *ConfigNode) (*sql.DB, error) {
var source string
if c.Linkinfo != "" {
source = c.Linkinfo
func (db *dbMssql) Open(config *ConfigNode) (*sql.DB, error) {
source := ""
if config.Linkinfo != "" {
source = config.Linkinfo
} else {
source = fmt.Sprintf("user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable", c.User, c.Pass, c.Host, c.Port, c.Name)
source = fmt.Sprintf("user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable",
config.User, config.Pass, config.Host, config.Port, config.Name)
}
if db, err := sql.Open("sqlserver", source); err == nil {
return db, nil

View File

@ -18,12 +18,12 @@ type dbMysql struct {
}
// 创建SQL操作对象内部采用了lazy link处理
func (db *dbMysql) open (c *ConfigNode) (*sql.DB, error) {
func (db *dbMysql) Open (config *ConfigNode) (*sql.DB, error) {
var source string
if c.Linkinfo != "" {
source = c.Linkinfo
if config.Linkinfo != "" {
source = config.Linkinfo
} else {
source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", c.User, c.Pass, c.Host, c.Port, c.Name)
source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?multiStatements=true", config.User, config.Pass, config.Host, config.Port, config.Name)
}
if db, err := sql.Open("mysql", source); err == nil {
return db, nil

View File

@ -27,12 +27,12 @@ type dbOracle struct {
}
// 创建SQL操作对象
func (db *dbOracle) open(c *ConfigNode) (*sql.DB, error) {
func (db *dbOracle) Open(config *ConfigNode) (*sql.DB, error) {
var source string
if c.Linkinfo != "" {
source = c.Linkinfo
if config.Linkinfo != "" {
source = config.Linkinfo
} else {
source = fmt.Sprintf("%s/%s@%s", c.User, c.Pass, c.Name)
source = fmt.Sprintf("%s/%s@%s", config.User, config.Pass, config.Name)
}
if db, err := sql.Open("oci8", source); err == nil {
return db, nil

View File

@ -24,12 +24,12 @@ type dbPgsql struct {
}
// 创建SQL操作对象内部采用了lazy link处理
func (db *dbPgsql) open (c *ConfigNode) (*sql.DB, error) {
func (db *dbPgsql) Open (config *ConfigNode) (*sql.DB, error) {
var source string
if c.Linkinfo != "" {
source = c.Linkinfo
if config.Linkinfo != "" {
source = config.Linkinfo
} else {
source = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s", c.User, c.Pass, c.Host, c.Port, c.Name)
source = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s", config.User, config.Pass, config.Host, config.Port, config.Name)
}
if db, err := sql.Open("postgres", source); err == nil {
return db, nil

View File

@ -22,12 +22,12 @@ type dbSqlite struct {
*dbBase
}
func (db *dbSqlite) open(c *ConfigNode) (*sql.DB, error) {
func (db *dbSqlite) Open(config *ConfigNode) (*sql.DB, error) {
var source string
if c.Linkinfo != "" {
source = c.Linkinfo
if config.Linkinfo != "" {
source = config.Linkinfo
} else {
source = c.Name
source = config.Name
}
if db, err := sql.Open("sqlite3", source); err == nil {
return db, nil

View File

@ -7,15 +7,9 @@
package gdb
import (
"fmt"
"errors"
"strings"
"reflect"
"database/sql"
"gitee.com/johng/gf/g/os/gtime"
"gitee.com/johng/gf/g/util/gconv"
"gitee.com/johng/gf/g/util/gregex"
_ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
"gitee.com/johng/gf/g/container/gvar"
)
// 数据库事务对象
@ -37,49 +31,27 @@ func (tx *TX) Rollback() error {
// (事务)数据库sql查询操作主要执行查询
func (tx *TX) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
return doQuery(tx.db, tx.tx, query, args...)
return tx.db.doQuery(tx.tx, query, args...)
}
// (事务)执行一条sql并返回执行情况主要用于非查询操作
func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) {
return doExec(tx.db, tx.tx, query, args...)
return tx.db.doExec(tx.tx, query, args...)
}
// sql预处理执行完成后调用返回值sql.Stmt.Exec完成sql操作
func (tx *TX) Prepare(query string) (*sql.Stmt, error) {
return tx.db.doPrepare(tx.tx, query)
}
// 数据库查询,获取查询结果集,以列表结构返回
func (tx *TX) GetAll(query string, args ...interface{}) (Result, error) {
// 执行sql
rows, err := tx.Query(query, args ...)
if err != nil || rows == nil {
return nil, err
}
// 列名称列表
columns, err := rows.Columns()
if err != nil {
return nil, err
}
// 返回结构组装
values := make([]sql.RawBytes, len(columns))
scanArgs := make([]interface{}, len(values))
records := make(Result, 0)
for i := range values {
scanArgs[i] = &values[i]
}
for rows.Next() {
err = rows.Scan(scanArgs...)
if err != nil {
return records, err
}
row := make(Record)
// 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址
for i, col := range values {
v := make([]byte, len(col))
copy(v, col)
row[columns[i]] = gvar.New(v, false)
}
//fmt.Printf("%p\n", row["typeid"])
records = append(records, row)
}
return records, nil
defer rows.Close()
return rowsToResult(rows)
}
// 数据库查询,获取查询结果记录,以关联数组结构返回
@ -103,7 +75,6 @@ func (tx *TX) GetStruct(obj interface{}, query string, args ...interface{}) erro
return one.ToStruct(obj)
}
// 数据库查询,获取查询字段值
func (tx *TX) GetValue(query string, args ...interface{}) (Value, error) {
one, err := tx.GetOne(query, args ...)
@ -118,185 +89,54 @@ func (tx *TX) GetValue(query string, args ...interface{}) (Value, error) {
// 数据库查询,获取查询数量
func (tx *TX) GetCount(query string, args ...interface{}) (int, error) {
val, err := tx.GetValue(query, args ...)
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
}
value, err := tx.GetValue(query, args ...)
if err != nil {
return 0, err
}
return gconv.Int(val), nil
}
// 数据表查询其中tables可以是多个联表查询语句这种查询方式较复杂建议使用链式操作
func (tx *TX) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) {
s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables)
if condition != nil {
s += fmt.Sprintf("WHERE %s ", tx.db.formatCondition(condition))
}
if len(groupBy) > 0 {
s += fmt.Sprintf("GROUP BY %s ", groupBy)
}
if len(orderBy) > 0 {
s += fmt.Sprintf("ORDER BY %s ", orderBy)
}
if limit > 0 {
s += fmt.Sprintf("LIMIT %d,%d ", first, limit)
}
return tx.GetAll(s, args ... )
}
// sql预处理执行完成后调用返回值sql.Stmt.Exec完成sql操作
func (tx *TX) Prepare(query string) (*sql.Stmt, error) {
return tx.tx.Prepare(query)
}
// insert、replace, save ignore操作
// 0: insert: 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
// 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
// 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
// 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做
func (tx *TX) insert(table string, data Map, option uint8) (sql.Result, error) {
var keys []string
var values []string
var params []interface{}
for k, v := range data {
keys = append(keys, tx.db.charl + k + tx.db.charr)
values = append(values, "?")
params = append(params, v)
}
operation := tx.db.getInsertOperationByOption(option)
updatestr := ""
if option == OPTION_SAVE {
var updates []string
for k, _ := range data {
updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s)", tx.db.charl, k, tx.db.charr, k))
}
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
}
return tx.Exec(
fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
operation, table, strings.Join(keys, ","),
strings.Join(values, ","),
updatestr),
params...
)
return value.Int(), nil
}
// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
func (tx *TX) Insert(table string, data Map) (sql.Result, error) {
return tx.insert(table, data, OPTION_INSERT)
return tx.db.doInsert(tx.tx, table, data, OPTION_INSERT)
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
func (tx *TX) Replace(table string, data Map) (sql.Result, error) {
return tx.insert(table, data, OPTION_REPLACE)
return tx.db.doInsert(tx.tx, table, data, OPTION_REPLACE)
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
func (tx *TX) Save(table string, data Map) (sql.Result, error) {
return tx.insert(table, data, OPTION_SAVE)
}
// 批量写入数据
func (tx *TX) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) {
var keys []string
var values []string
var bvalues []string
var params []interface{}
var result sql.Result
var size = len(list)
// 判断长度
if size < 1 {
return result, errors.New("empty data list")
}
// 首先获取字段名称及记录长度
for k, _ := range list[0] {
keys = append(keys, k)
values = append(values, "?")
}
keyStr := tx.db.charl + strings.Join(keys, tx.db.charl + "," + tx.db.charr) + tx.db.charr
valueHolderStr := "(" + strings.Join(values, ",") + ")"
// 操作判断
operation := tx.db.getInsertOperationByOption(option)
updatestr := ""
if option == OPTION_SAVE {
var updates []string
for _, k := range keys {
updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s)", tx.db.charl, k, tx.db.charr, k))
}
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
}
// 构造批量写入数据格式(注意map的遍历是无序的)
for i := 0; i < size; i++ {
for _, k := range keys {
params = append(params, list[i][k])
}
bvalues = append(bvalues, valueHolderStr)
if len(bvalues) == batch {
r, err := tx.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
operation, table, keyStr, strings.Join(bvalues, ","),
updatestr),
params...)
if err != nil {
return result, err
}
result = r
params = params[:0]
bvalues = bvalues[:0]
}
}
// 处理最后不构成指定批量的数据
if len(bvalues) > 0 {
r, err := tx.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
operation, table, keyStr, strings.Join(bvalues, ","),
updatestr),
params...)
if err != nil {
return result, err
}
result = r
}
return result, nil
return tx.db.doInsert(tx.tx, table, data, OPTION_SAVE)
}
// CURD操作:批量数据指定批次量写入
func (tx *TX) BatchInsert(table string, list List, batch int) (sql.Result, error) {
return tx.batchInsert(table, list, batch, OPTION_INSERT)
return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_INSERT)
}
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
func (tx *TX) BatchReplace(table string, list List, batch int) (sql.Result, error) {
return tx.batchInsert(table, list, batch, OPTION_REPLACE)
return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_REPLACE)
}
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
func (tx *TX) BatchSave(table string, list List, batch int) (sql.Result, error) {
return tx.batchInsert(table, list, batch, OPTION_SAVE)
return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_SAVE)
}
// CURD操作:数据更新统一采用sql预处理
// data参数支持字符串或者关联数组类型内部会自行做判断处理
func (tx *TX) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
var params []interface{}
var updates string
refValue := reflect.ValueOf(data)
if refValue.Kind() == reflect.Map {
var fields []string
keys := refValue.MapKeys()
for _, k := range keys {
fields = append(fields, fmt.Sprintf("%s%s%s=?", tx.db.charl, k, tx.db.charr))
params = append(params, gconv.String(refValue.MapIndex(k).Interface()))
updates = strings.Join(fields, ",")
}
} else {
updates = gconv.String(data)
}
for _, v := range args {
params = append(params, gconv.String(v))
}
return tx.Exec(fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, tx.db.formatCondition(condition)), params...)
return tx.db.doUpdate(tx.tx, table, data, condition, args ...)
}
// CURD操作:删除数据
func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
return tx.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", table, tx.db.formatCondition(condition)), args...)
return tx.db.doDelete(tx.tx, table, condition, args ...)
}

View File

@ -0,0 +1,215 @@
package gdb_test
import (
"gitee.com/johng/gf/g"
"gitee.com/johng/gf/g/database/gdb"
"gitee.com/johng/gf/g/os/gtime"
"gitee.com/johng/gf/g/util/gtest"
"testing"
)
var (
// 数据库对象/接口
db gdb.DB
)
// 初始化连接参数。
// 测试前需要修改连接参数。
func init() {
gdb.AddDefaultConfigNode(gdb.ConfigNode{
Host: "127.0.0.1",
Port: "3306",
User: "root",
Pass: "12345678",
Name: "test",
Type: "mysql",
Role: "master",
Charset: "utf8",
Priority: 1,
})
if r, err := gdb.New(); err != nil {
gtest.Fatal(err)
} else {
db = r
}
// 准备测试数据结构
if _, err := db.Exec("DROP TABLE `user`"); err != nil {
gtest.Fatal(err)
}
if _, err := db.Exec(`
CREATE TABLE user (
id int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT '用户ID',
passport varchar(45) NOT NULL COMMENT '账号',
password char(32) NOT NULL COMMENT '密码',
nickname varchar(45) NOT NULL COMMENT '昵称',
create_time timestamp NOT NULL COMMENT '创建时间/注册时间',
PRIMARY KEY (id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
`); err != nil {
gtest.Fatal(err)
}
}
func TestDbBase_Query(t *testing.T) {
if _, err := db.Query("SELECT ?", 1); err != nil {
gtest.Fatal(err)
}
if _, err := db.Query("ERROR"); err == nil {
gtest.Fatal("FAIL")
}
}
func TestDbBase_Exec(t *testing.T) {
if _, err := db.Exec("SELECT ?", 1); err != nil {
gtest.Fatal(err)
}
if _, err := db.Exec("ERROR"); err == nil {
gtest.Fatal("FAIL")
}
}
func TestDbBase_Prepare(t *testing.T) {
st, err := db.Prepare("SELECT 100")
if err != nil {
gtest.Fatal(err)
}
rows, err := st.Query()
if err != nil {
gtest.Fatal(err)
}
array, err := rows.Columns()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(array[0], "100")
if err := rows.Close(); err != nil {
gtest.Fatal(err)
}
}
func TestDbBase_Insert(t *testing.T) {
if _, err := db.Insert("user", g.Map{
"id" : 1,
"passport" : "t1",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T1",
"create_time" : gtime.Now().String(),
}); err != nil {
gtest.Fatal(err)
}
}
func TestDbBase_BatchInsert(t *testing.T) {
if _, err := db.BatchInsert("user", g.List {
{
"id" : 2,
"passport" : "t2",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T2",
"create_time" : gtime.Now().String(),
},
{
"id" : 3,
"passport" : "t3",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T3",
"create_time" : gtime.Now().String(),
},
}, 10); err != nil {
gtest.Fatal(err)
}
}
func TestDbBase_Save(t *testing.T) {
if _, err := db.Save("user", g.Map{
"id" : 1,
"passport" : "t1",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T11",
"create_time" : gtime.Now().String(),
}); err != nil {
gtest.Fatal(err)
}
}
func TestDbBase_Replace(t *testing.T) {
if _, err := db.Save("user", g.Map{
"id" : 1,
"passport" : "t1",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T111",
"create_time" : gtime.Now().String(),
}); err != nil {
gtest.Fatal(err)
}
}
func TestDbBase_Update(t *testing.T) {
if result, err := db.Update("user", "create_time='2010-10-10 00:00:01'", "id=3"); err != nil {
gtest.Fatal(err)
} else {
n, _ := result.RowsAffected()
gtest.Assert(n, 1)
}
}
func TestDbBase_GetAll(t *testing.T) {
if result, err := db.GetAll("SELECT * FROM user WHERE id=?", 1); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(len(result), 1)
}
}
func TestDbBase_GetOne(t *testing.T) {
if record, err := db.GetOne("SELECT * FROM user WHERE passport=?", "t1"); err != nil {
gtest.Fatal(err)
} else {
if record == nil {
gtest.Fatal("FAIL")
}
gtest.Assert(record["nickname"].String(), "T111")
}
}
func TestDbBase_GetValue(t *testing.T) {
if value, err := db.GetValue("SELECT id FROM user WHERE passport=?", "t3"); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(value.Int(), 3)
}
}
func TestDbBase_GetCount(t *testing.T) {
if count, err := db.GetCount("SELECT * FROM user"); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(count, 3)
}
}
func TestDbBase_GetStruct(t *testing.T) {
type User struct {
Id int
Passport string
Password string
NickName string
CreateTime gtime.Time
}
user := new(User)
if err := db.GetStruct(user, "SELECT * FROM user WHERE id=?", 3); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(user.CreateTime.String(), "2010-10-10 00:00:01")
}
}
func TestDbBase_Delete(t *testing.T) {
if result, err := db.Delete("user", nil); err != nil {
gtest.Fatal(err)
} else {
n, _ := result.RowsAffected()
gtest.Assert(n, 3)
}
}

View File

@ -0,0 +1,177 @@
package gdb_test
import (
"gitee.com/johng/gf/g"
"gitee.com/johng/gf/g/os/gtime"
"gitee.com/johng/gf/g/util/gtest"
"testing"
)
func TestModel_Insert(t *testing.T) {
result, err := db.Table("user").Data(g.Map{
"id" : 1,
"passport" : "t1",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T1",
"create_time" : gtime.Now().String(),
}).Insert()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.LastInsertId()
gtest.Assert(n, 1)
}
func TestModel_Batch(t *testing.T) {
result, err := db.Table("user").Data(g.List{
{
"id" : 2,
"passport" : "t2",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T2",
"create_time" : gtime.Now().String(),
},
{
"id" : 3,
"passport" : "t3",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T3",
"create_time" : gtime.Now().String(),
},
}).Batch(10).Insert()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.RowsAffected()
gtest.Assert(n, 2)
}
func TestModel_Replace(t *testing.T) {
result, err := db.Table("user").Data(g.Map{
"id" : 1,
"passport" : "t11",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T11",
"create_time" : gtime.Now().String(),
}).Replace()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.RowsAffected()
gtest.Assert(n, 2)
}
func TestModel_Save(t *testing.T) {
result, err := db.Table("user").Data(g.Map{
"id" : 1,
"passport" : "t111",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T111",
"create_time" : gtime.Now().String(),
}).Save()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.RowsAffected()
gtest.Assert(n, 2)
}
func TestModel_Update(t *testing.T) {
result, err := db.Table("user").Data("passport", "t22").Where("passport=?", "t2").Update()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.RowsAffected()
gtest.Assert(n, 1)
}
func TestModel_All(t *testing.T) {
result, err := db.Table("user").All()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(len(result), 3)
}
func TestModel_One(t *testing.T) {
record, err := db.Table("user").Where("id", 1).One()
if err != nil {
gtest.Fatal(err)
}
if record == nil {
gtest.Fatal("FAIL")
}
gtest.Assert(record["nickname"].String(), "T111")
}
func TestModel_Value(t *testing.T) {
value, err := db.Table("user").Fields("nickname").Where("id", 1).Value()
if err != nil {
gtest.Fatal(err)
}
if value == nil {
gtest.Fatal("FAIL")
}
gtest.Assert(value.String(), "T111")
}
func TestModel_Count(t *testing.T) {
count, err := db.Table("user").Count()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(count, 3)
}
func TestModel_Select(t *testing.T) {
result, err := db.Table("user").Select()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(len(result), 3)
}
func TestModel_Struct(t *testing.T) {
type User struct {
Id int
Passport string
Password string
NickName string
CreateTime gtime.Time
}
user := new(User)
err := db.Table("user").Where("id=1").Struct(user)
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(user.NickName, "T111")
}
func TestModel_OrderBy(t *testing.T) {
result, err := db.Table("user").OrderBy("id DESC").Select()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(len(result), 3)
gtest.Assert(result[0]["nickname"].String(), "T3")
}
func TestModel_GroupBy(t *testing.T) {
result, err := db.Table("user").GroupBy("id").Select()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(len(result), 3)
gtest.Assert(result[0]["nickname"].String(), "T111")
}
func TestModel_Delete(t *testing.T) {
result, err := db.Table("user").Delete()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.RowsAffected()
gtest.Assert(n, 3)
}

View File

@ -0,0 +1,372 @@
package gdb_test
import (
"gitee.com/johng/gf/g"
"gitee.com/johng/gf/g/os/gtime"
"gitee.com/johng/gf/g/util/gtest"
"testing"
)
func TestTX_Query(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if rows, err := tx.Query("SELECT ?", 1); err != nil {
gtest.Fatal(err)
} else {
rows.Close()
}
if _, err := tx.Query("ERROR"); err == nil {
gtest.Fatal("FAIL")
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_Exec(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.Exec("SELECT ?", 1); err != nil {
gtest.Fatal(err)
}
if _, err := tx.Exec("ERROR"); err == nil {
gtest.Fatal("FAIL")
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_Commit(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_Rollback(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if err := tx.Rollback(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_Prepare(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
st, err := tx.Prepare("SELECT 100")
if err != nil {
gtest.Fatal(err)
}
rows, err := st.Query()
if err != nil {
gtest.Fatal(err)
}
array, err := rows.Columns()
if err != nil {
gtest.Fatal(err)
}
gtest.Assert(array[0], "100")
if err := rows.Close(); err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_Insert(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.Insert("user", g.Map {
"id" : 1,
"passport" : "t1",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T1",
"create_time" : gtime.Now().String(),
}); err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
if n, err := db.Table("user").Count(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(n, 1)
}
}
func TestTX_BatchInsert(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.BatchInsert("user", g.List {
{
"id" : 2,
"passport" : "t",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T2",
"create_time" : gtime.Now().String(),
},
{
"id" : 3,
"passport" : "t3",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T3",
"create_time" : gtime.Now().String(),
},
}, 10); err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
if n, err := db.Table("user").Count(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(n, 3)
}
}
func TestTX_BatchReplace(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.BatchReplace("user", g.List {
{
"id" : 2,
"passport" : "t2",
"password" : "p2",
"nickname" : "T2",
"create_time" : gtime.Now().String(),
},
{
"id" : 4,
"passport" : "t4",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T4",
"create_time" : gtime.Now().String(),
},
}, 10); err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
// 数据数量
if n, err := db.Table("user").Count(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(n, 4)
}
// 检查replace后的数值
if value, err := db.Table("user").Fields("password").Where("id", 2).Value(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(value.String(), "p2")
}
}
func TestTX_BatchSave(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.BatchSave("user", g.List {
{
"id" : 4,
"passport" : "t4",
"password" : "p4",
"nickname" : "T4",
"create_time" : gtime.Now().String(),
},
}, 10); err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
// 数据数量
if n, err := db.Table("user").Count(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(n, 4)
}
// 检查replace后的数值
if value, err := db.Table("user").Fields("password").Where("id", 4).Value(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(value.String(), "p4")
}
}
func TestTX_Replace(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.Replace("user", g.Map {
"id" : 1,
"passport" : "t11",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T11",
"create_time" : gtime.Now().String(),
}); err != nil {
gtest.Fatal(err)
}
if err := tx.Rollback(); err != nil {
gtest.Fatal(err)
}
if value, err := db.Table("user").Fields("nickname").Where("id", 1).Value(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(value.String(), "T1")
}
}
func TestTX_Save(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.Save("user", g.Map {
"id" : 1,
"passport" : "t11",
"password" : "25d55ad283aa400af464c76d713c07ad",
"nickname" : "T11",
"create_time" : gtime.Now().String(),
}); err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
if value, err := db.Table("user").Fields("nickname").Where("id", 1).Value(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(value.String(), "T11")
}
}
func TestTX_GetAll(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if result, err := tx.GetAll("SELECT * FROM user WHERE id=?", 1); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(len(result), 1)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_GetOne(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if record, err := tx.GetOne("SELECT * FROM user WHERE passport=?", "t2"); err != nil {
gtest.Fatal(err)
} else {
if record == nil {
gtest.Fatal("FAIL")
}
gtest.Assert(record["nickname"].String(), "T2")
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_GetValue(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if value, err := tx.GetValue("SELECT id FROM user WHERE passport=?", "t3"); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(value.Int(), 3)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_GetCount(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if count, err := tx.GetCount("SELECT * FROM user"); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(count, 4)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_GetStruct(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
type User struct {
Id int
Passport string
Password string
NickName string
CreateTime gtime.Time
}
user := new(User)
if err := tx.GetStruct(user, "SELECT * FROM user WHERE id=?", 1); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(user.NickName, "T11")
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
}
func TestTX_Delete(t *testing.T) {
tx, err := db.Begin()
if err != nil {
gtest.Fatal(err)
}
if _, err := tx.Delete("user", nil); err != nil {
gtest.Fatal(err)
}
if err := tx.Commit(); err != nil {
gtest.Fatal(err)
}
if n, err := db.Table("user").Count(); err != nil {
gtest.Fatal(err)
} else {
gtest.Assert(n, 0)
}
}

View File

@ -120,7 +120,7 @@ func Config(file...string) *gcfg.Config {
}
// 数据库操作对象,使用了连接池
func Database(name...string) *gdb.Db {
func Database(name...string) gdb.DB {
config := Config()
group := gdb.DEFAULT_GROUP_NAME
if len(name) > 0 {
@ -195,7 +195,7 @@ func Database(name...string) *gdb.Db {
return nil
})
if db != nil {
return db.(*gdb.Db)
return db.(gdb.DB)
}
return nil
}

View File

@ -44,12 +44,12 @@ func Config(file...string) *gcfg.Config {
}
// 数据库操作对象,使用了连接池
func Database(name...string) *gdb.Db {
func Database(name...string) gdb.DB {
return gins.Database(name...)
}
// (别名)Database
func DB(name...string) *gdb.Db {
func DB(name...string) gdb.DB {
return gins.Database(name...)
}

View File

@ -31,7 +31,7 @@ type Logger struct {
file *gtype.String // 日志文件名称格式
level *gtype.Int // 日志输出等级
btSkip *gtype.Int // 错误产生时的backtrace回调信息skip条数
btEnabled *gtype.Bool // 是否当打印错误时同时开启backtrace打印
btStatus *gtype.Int // 是否当打印错误时同时开启backtrace打印(默认-1表示默认打印逻辑 - 错误才打印)
printHeader *gtype.Bool // 是否不打印前缀信息(时间,级别等)
alsoStdPrint *gtype.Bool // 控制台打印开关,当输出到文件/自定义输出时也同时打印到终端
}
@ -65,7 +65,7 @@ func New() *Logger {
file : gtype.NewString(gDEFAULT_FILE_FORMAT),
level : gtype.NewInt(defaultLevel.Val()),
btSkip : gtype.NewInt(),
btEnabled : gtype.NewBool(true),
btStatus : gtype.NewInt(-1),
printHeader : gtype.NewBool(true),
alsoStdPrint : gtype.NewBool(true),
}
@ -80,7 +80,7 @@ func (l *Logger) Clone() *Logger {
file : l.file.Clone(),
level : l.level.Clone(),
btSkip : l.btSkip.Clone(),
btEnabled : l.btEnabled.Clone(),
btStatus : l.btStatus.Clone(),
printHeader : l.printHeader.Clone(),
alsoStdPrint : l.alsoStdPrint.Clone(),
}
@ -106,7 +106,12 @@ func (l *Logger) SetDebug(debug bool) {
}
func (l *Logger) SetBacktrace(enabled bool) {
l.btEnabled.Set(enabled)
if enabled {
l.btStatus.Set(1)
} else {
l.btStatus.Set(0)
}
}
// 设置BacktraceSkip
@ -214,27 +219,37 @@ func (l *Logger) doStdLockPrint(std io.Writer, s string) {
// 核心打印数据方法(标准输出)
func (l *Logger) stdPrint(s string) {
if l.btStatus.Val() == 1 {
s = l.appendBacktrace(s)
}
l.print(os.Stdout, s)
}
// 核心打印数据方法(标准错误)
func (l *Logger) errPrint(s string) {
// 记录调用回溯信息
if l.btEnabled.Val() {
tracestr := l.GetBacktrace()
if tracestr != "" {
backtrace := "Backtrace:" + ln + tracestr
if s[len(s) - 1] == byte('\n') {
s = s + backtrace + ln
} else {
s = s + ln + backtrace + ln
}
}
status := l.btStatus.Val()
if status == -1 || status == 1 {
s = l.appendBacktrace(s)
}
// 防止串日志情况这里不使用stderr而是使用stdout
l.print(os.Stdout, s)
}
// 输出内容中添加回溯信息
func (l *Logger) appendBacktrace(s string) string {
trace := l.GetBacktrace()
if trace != "" {
backtrace := "Backtrace:" + ln + trace
if s[len(s) - 1] == byte('\n') {
s = s + backtrace + ln
} else {
s = s + ln + backtrace + ln
}
}
return s
}
// 直接打印回溯信息参数skip表示调用端往上多少级开始回溯
func (l *Logger) PrintBacktrace(skip...int) {
l.Println(l.GetBacktrace(skip...))

View File

@ -157,12 +157,10 @@ func bindVarToStruct(elem reflect.Value, name string, value interface{}) (err er
structFieldValue := elem.FieldByName(name)
// 键名与对象属性匹配检测map中如果有struct不存在的属性那么不做处理直接return
if !structFieldValue.IsValid() {
//return errors.New(fmt.Sprintf(`invalid struct attribute of name "%s"`, name))
return nil
}
// CanSet的属性必须为公开属性(首字母大写)
if !structFieldValue.CanSet() {
//return errors.New(fmt.Sprintf(`struct attribute of name "%s" cannot be set`, name))
return nil
}
// 必须将value转换为struct属性的数据类型这里必须用到gconv包
@ -181,12 +179,10 @@ func bindVarToStructByIndex(elem reflect.Value, index int, value interface{}) (e
structFieldValue := elem.FieldByIndex([]int{index})
// 键名与对象属性匹配检测
if !structFieldValue.IsValid() {
//return errors.New(fmt.Sprintf("invalid struct attribute at index %d", index))
return nil
}
// CanSet的属性必须为公开属性(首字母大写)
if !structFieldValue.CanSet() {
//return errors.New(fmt.Sprintf("struct attribute cannot be set at index %d", index))
return nil
}
// 必须将value转换为struct属性的数据类型这里必须用到gconv包

30
g/util/gtest/gtest.go Normal file
View File

@ -0,0 +1,30 @@
// Copyright 2018 gf Author(https://gitee.com/johng/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://gitee.com/johng/gf.
// Package gtest provides useful test utils.
// 测试模块.
package gtest
import (
"fmt"
"gitee.com/johng/gf/g/os/glog"
"gitee.com/johng/gf/g/util/gconv"
"os"
)
// 断言判断
func Assert(value, expect interface{}) {
if gconv.String(value) != gconv.String(expect) {
glog.Backtrace(true, 1).Printfln(`[ASSERT] VALUE: %v, EXPECT: %v`, value, expect)
os.Exit(1)
}
}
// 提示错误并退出
func Fatal(message...interface{}) {
glog.Backtrace(true, 1).Println(`[FATAL] `, fmt.Sprint(message...))
os.Exit(1)
}

View File

@ -9,7 +9,7 @@ import (
// 本文件用于gf框架的mysql数据库操作示例不作为单元测试使用
var db *gdb.Db
var db gdb.DB
// 初始化配置及创建数据库
func init () {
@ -17,7 +17,7 @@ func init () {
Host : "127.0.0.1",
Port : "3306",
User : "root",
Pass : "8692651",
Pass : "12345678",
Name : "test",
Type : "mysql",
Role : "master",

View File

@ -1,28 +1,16 @@
package main
import (
"gitee.com/johng/gf/g/database/gdb"
"gitee.com/johng/gf/g"
"time"
)
func main() {
gdb.AddDefaultConfigNode(gdb.ConfigNode {
Host : "127.0.0.1",
Port : "3306",
User : "root",
Pass : "12345678",
Name : "test",
Type : "mysql",
Role : "master",
Charset : "utf8",
MaxIdleConnCount : 10,
MaxOpenConnCount : 10,
MaxConnLifetime : 10,
})
db, err := gdb.New()
if err != nil {
panic(err)
}
db := g.DB()
db.SetMaxIdleConns(10)
db.SetMaxOpenConns(10)
db.SetConnMaxLifetime(10)
// 开启调试模式以便于记录所有执行的SQL
db.SetDebug(true)

View File

@ -2,24 +2,11 @@ package main
import (
"fmt"
"gitee.com/johng/gf/g/database/gdb"
"gitee.com/johng/gf/g"
)
func main() {
gdb.AddDefaultConfigNode(gdb.ConfigNode {
Host : "192.168.1.11",
Port : "3306",
User : "root",
Pass : "8692651",
Name : "test",
Type : "mysql",
Role : "master",
Charset : "utf8",
})
db, err := gdb.New()
if err != nil {
panic(err)
}
db := g.DB()
// 开启调试模式以便于记录所有执行的SQL
db.SetDebug(true)

View File

@ -1,41 +0,0 @@
package test
import (
"strings"
"testing"
)
var s = "/name/john///."
var c = "./"
func t1(s string) string {
if len(s) == 0 {
return s
}
for _, cut := range c {
for s[len(s) - 1] == uint8(cut) {
s = s[:len(s) - 1]
if len(s) == 0 {
return s
}
}
}
return s
}
func t2(s string) string {
return strings.TrimRight(s, c)
}
func Benchmark_t1(b *testing.B) {
for i := 0; i < b.N; i++ {
t1(s)
}
}
func Benchmark_t2(b *testing.B) {
for i := 0; i < b.N; i++ {
t2(s)
}
}

View File

@ -1,25 +1,15 @@
package main
import "fmt"
import (
"fmt"
"gitee.com/johng/gf/g/util/gregex"
)
type User struct {
Uid int
}
func New() *User {
return &User{
100,
}
}
func (user *User) Clear() {
user = New()
}
func main() {
user := New()
user.Uid = 10000
fmt.Println(user)
user.Clear()
fmt.Println(user)
query := "select * from user"
q, err := gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
fmt.Println(err)
fmt.Println(q)
}

View File

@ -4,6 +4,7 @@ go:
- 1.8.x
- 1.9.x
- 1.10.x
- 1.11.x
- master
before_install:

View File

@ -35,6 +35,7 @@ Hanno Braun <mail at hannobraun.com>
Henri Yandell <flamefew at gmail.com>
Hirotaka Yamamoto <ymmt2005 at gmail.com>
ICHINOSE Shogo <shogo82148 at gmail.com>
Ilia Cimpoes <ichimpoesh at gmail.com>
INADA Naoki <songofacandy at gmail.com>
Jacek Szwec <szwec.jacek at gmail.com>
James Harr <james.harr at gmail.com>
@ -72,7 +73,9 @@ Shuode Li <elemount at qq.com>
Soroush Pour <me at soroushjp.com>
Stan Putrya <root.vagner at gmail.com>
Stanley Gunawan <gunawan.stanley at gmail.com>
Steven Hartland <steven.hartland at multiplay.co.uk>
Thomas Wodarek <wodarekwebpage at gmail.com>
Tom Jenkinson <tom at tjenkinson.me>
Xiangyu Hu <xiangyu.hu at outlook.com>
Xiaobing Jiang <s7v7nislands at gmail.com>
Xiuming Chen <cc at cxm.cc>
@ -88,3 +91,4 @@ Keybase Inc.
Percona LLC
Pivotal Inc.
Stripe Inc.
Multiplay Ltd.

View File

@ -58,7 +58,7 @@ _Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface.
Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`:
```go
import "database/sql"
import _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
import _ "github.com/go-sql-driver/mysql"
db, err := sql.Open("mysql", "user:password@/dbname")
```
@ -328,11 +328,11 @@ Timeout for establishing connections, aka dial timeout. The value must be a deci
```
Type: bool / string
Valid Values: true, false, skip-verify, <name>
Valid Values: true, false, skip-verify, preferred, <name>
Default: false
```
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use `preferred` to use TLS only when advertised by the server, this is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
##### `writeTimeout`
@ -431,7 +431,7 @@ See [context support in the database/sql package](https://golang.org/doc/go1.8#d
### `LOAD DATA LOCAL INFILE` support
For this feature you need direct access to the package. Therefore you must change the import path (no `_`):
```go
import "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
import "github.com/go-sql-driver/mysql"
```
Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)).

View File

@ -234,64 +234,64 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro
if err != nil {
return err
}
return mc.writeAuthSwitchPacket(enc, false)
return mc.writeAuthSwitchPacket(enc)
}
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) {
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
switch plugin {
case "caching_sha2_password":
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
return authResp, false, nil
return authResp, nil
case "mysql_old_password":
if !mc.cfg.AllowOldPasswords {
return nil, false, ErrOldPassword
return nil, ErrOldPassword
}
// Note: there are edge cases where this should work but doesn't;
// this is currently "wontfix":
// https://github.com/go-sql-driver/mysql/issues/184
authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd)
return authResp, true, nil
authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0)
return authResp, nil
case "mysql_clear_password":
if !mc.cfg.AllowCleartextPasswords {
return nil, false, ErrCleartextPassword
return nil, ErrCleartextPassword
}
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
return []byte(mc.cfg.Passwd), true, nil
return append([]byte(mc.cfg.Passwd), 0), nil
case "mysql_native_password":
if !mc.cfg.AllowNativePasswords {
return nil, false, ErrNativePassword
return nil, ErrNativePassword
}
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
// Native password authentication only need and will need 20-byte challenge.
authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
return authResp, false, nil
return authResp, nil
case "sha256_password":
if len(mc.cfg.Passwd) == 0 {
return nil, true, nil
return []byte{0}, nil
}
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
// write cleartext auth packet
return []byte(mc.cfg.Passwd), true, nil
return append([]byte(mc.cfg.Passwd), 0), nil
}
pubKey := mc.cfg.pubKey
if pubKey == nil {
// request public key from server
return []byte{1}, false, nil
return []byte{1}, nil
}
// encrypted password
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
return enc, false, err
return enc, err
default:
errLog.Print("unknown auth plugin:", plugin)
return nil, false, ErrUnknownPlugin
return nil, ErrUnknownPlugin
}
}
@ -315,11 +315,11 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
plugin = newPlugin
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
return err
}
if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil {
if err = mc.writeAuthSwitchPacket(authResp); err != nil {
return err
}
@ -352,7 +352,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
case cachingSha2PasswordPerformFullAuthentication:
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
// write cleartext auth packet
err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true)
err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
if err != nil {
return err
}
@ -360,13 +360,15 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
pubKey := mc.cfg.pubKey
if pubKey == nil {
// request public key from server
data := mc.buf.takeSmallBuffer(4 + 1)
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
return err
}
data[4] = cachingSha2PasswordRequestPublicKey
mc.writePacket(data)
// parse public key
data, err := mc.readPacket()
if err != nil {
if data, err = mc.readPacket(); err != nil {
return err
}

View File

@ -85,11 +85,11 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) {
plugin := "caching_sha2_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -130,11 +130,11 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) {
plugin := "caching_sha2_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -172,11 +172,11 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) {
plugin := "caching_sha2_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -228,11 +228,11 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) {
plugin := "caching_sha2_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -280,11 +280,11 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) {
plugin := "caching_sha2_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -336,7 +336,7 @@ func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) {
plugin := "mysql_clear_password"
// Send Client Authentication Packet
_, _, err := mc.auth(authData, plugin)
_, err := mc.auth(authData, plugin)
if err != ErrCleartextPassword {
t.Errorf("expected ErrCleartextPassword, got %v", err)
}
@ -353,11 +353,11 @@ func TestAuthFastCleartextPassword(t *testing.T) {
plugin := "mysql_clear_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -367,8 +367,8 @@ func TestAuthFastCleartextPassword(t *testing.T) {
authRespEnd := authRespStart + 1 + len(authResp)
writtenAuthRespLen := conn.written[authRespStart]
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
expectedAuthResp := []byte{115, 101, 99, 114, 101, 116}
if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0}
if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
}
conn.written = nil
@ -396,11 +396,11 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
plugin := "mysql_clear_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -410,9 +410,9 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
authRespEnd := authRespStart + 1 + len(authResp)
writtenAuthRespLen := conn.written[authRespStart]
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
if writtenAuthRespLen != 0 {
t.Fatalf("unexpected written auth response (%d bytes): %v",
writtenAuthRespLen, writtenAuthResp)
expectedAuthResp := []byte{0}
if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
}
conn.written = nil
@ -439,7 +439,7 @@ func TestAuthFastNativePasswordNotAllowed(t *testing.T) {
plugin := "mysql_native_password"
// Send Client Authentication Packet
_, _, err := mc.auth(authData, plugin)
_, err := mc.auth(authData, plugin)
if err != ErrNativePassword {
t.Errorf("expected ErrNativePassword, got %v", err)
}
@ -455,11 +455,11 @@ func TestAuthFastNativePassword(t *testing.T) {
plugin := "mysql_native_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -498,11 +498,11 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) {
plugin := "mysql_native_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -540,11 +540,11 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
plugin := "sha256_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -554,7 +554,8 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
authRespEnd := authRespStart + 1 + len(authResp)
writtenAuthRespLen := conn.written[authRespStart]
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
if writtenAuthRespLen != 0 {
expectedAuthResp := []byte{0}
if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
}
conn.written = nil
@ -587,11 +588,11 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) {
plugin := "sha256_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -636,11 +637,11 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) {
plugin := "sha256_password"
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
@ -669,7 +670,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
plugin := "sha256_password"
// send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
t.Fatal(err)
}
@ -677,18 +678,18 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
// unset TLS config to prevent the actual establishment of a TLS wrapper
mc.cfg.tls = nil
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
err = mc.writeHandshakeResponsePacket(authResp, plugin)
if err != nil {
t.Fatal(err)
}
// check written auth response
authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
authRespEnd := authRespStart + 1 + len(authResp) + 1
authRespEnd := authRespStart + 1 + len(authResp)
writtenAuthRespLen := conn.written[authRespStart]
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0}
if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
}
conn.written = nil
@ -1064,6 +1065,22 @@ func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) {
}
}
// Same to TestAuthSwitchOldPasswordNotAllowed, but use OldAuthSwitch request.
func TestOldAuthSwitchNotAllowed(t *testing.T) {
conn, mc := newRWMockConn(2)
// OldAuthSwitch request
conn.data = []byte{1, 0, 0, 2, 0xfe}
conn.maxReads = 1
authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
84, 96, 101, 92, 123, 121, 107}
plugin := "mysql_native_password"
err := mc.handleAuthResult(authData, plugin)
if err != ErrOldPassword {
t.Errorf("expected ErrOldPassword, got %v", err)
}
}
func TestAuthSwitchOldPassword(t *testing.T) {
conn, mc := newRWMockConn(2)
mc.cfg.AllowOldPasswords = true
@ -1092,6 +1109,32 @@ func TestAuthSwitchOldPassword(t *testing.T) {
}
}
// Same to TestAuthSwitchOldPassword, but use OldAuthSwitch request.
func TestOldAuthSwitch(t *testing.T) {
conn, mc := newRWMockConn(2)
mc.cfg.AllowOldPasswords = true
mc.cfg.Passwd = "secret"
// OldAuthSwitch request
conn.data = []byte{1, 0, 0, 2, 0xfe}
// auth response
conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}}
conn.maxReads = 2
authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
84, 96, 101, 92, 123, 121, 107}
plugin := "mysql_native_password"
if err := mc.handleAuthResult(authData, plugin); err != nil {
t.Errorf("got error: %v", err)
}
expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0}
if !bytes.Equal(conn.written, expectedReply) {
t.Errorf("got unexpected data: %v", conn.written)
}
}
func TestAuthSwitchOldPasswordEmpty(t *testing.T) {
conn, mc := newRWMockConn(2)
mc.cfg.AllowOldPasswords = true
@ -1120,6 +1163,33 @@ func TestAuthSwitchOldPasswordEmpty(t *testing.T) {
}
}
// Same to TestAuthSwitchOldPasswordEmpty, but use OldAuthSwitch request.
func TestOldAuthSwitchPasswordEmpty(t *testing.T) {
conn, mc := newRWMockConn(2)
mc.cfg.AllowOldPasswords = true
mc.cfg.Passwd = ""
// OldAuthSwitch request.
conn.data = []byte{1, 0, 0, 2, 0xfe}
// auth response
conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}}
conn.maxReads = 2
authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
84, 96, 101, 92, 123, 121, 107}
plugin := "mysql_native_password"
if err := mc.handleAuthResult(authData, plugin); err != nil {
t.Errorf("got error: %v", err)
}
expectedReply := []byte{1, 0, 0, 3, 0}
if !bytes.Equal(conn.written, expectedReply) {
t.Errorf("got unexpected data: %v", conn.written)
}
}
func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) {
conn, mc := newRWMockConn(2)
mc.cfg.Passwd = ""

View File

@ -22,17 +22,17 @@ const defaultBufSize = 4096
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case.
type buffer struct {
buf []byte
buf []byte // buf is a byte buffer who's length and capacity are equal.
nc net.Conn
idx int
length int
timeout time.Duration
}
// newBuffer allocates and returns a new buffer.
func newBuffer(nc net.Conn) buffer {
var b [defaultBufSize]byte
return buffer{
buf: b[:],
buf: make([]byte, defaultBufSize),
nc: nc,
}
}
@ -105,43 +105,56 @@ func (b *buffer) readNext(need int) ([]byte, error) {
return b.buf[offset:b.idx], nil
}
// returns a buffer with the requested size.
// takeBuffer returns a buffer with the requested size.
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) []byte {
func (b *buffer) takeBuffer(length int) ([]byte, error) {
if b.length > 0 {
return nil
return nil, ErrBusyBuffer
}
// test (cheap) general case first
if length <= defaultBufSize || length <= cap(b.buf) {
return b.buf[:length]
if length <= cap(b.buf) {
return b.buf[:length], nil
}
if length < maxPacketSize {
b.buf = make([]byte, length)
return b.buf
return b.buf, nil
}
return make([]byte, length)
// buffer is larger than we want to store.
return make([]byte, length), nil
}
// shortcut which can be used if the requested buffer is guaranteed to be
// smaller than defaultBufSize
// takeSmallBuffer is shortcut which can be used if length is
// known to be smaller than defaultBufSize.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) []byte {
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
if b.length > 0 {
return nil
return nil, ErrBusyBuffer
}
return b.buf[:length]
return b.buf[:length], nil
}
// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// cap and len of the returned buffer will be equal.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() []byte {
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
if b.length > 0 {
return nil
return nil, ErrBusyBuffer
}
return b.buf
return b.buf, nil
}
// store stores buf, an updated buffer, if its suitable to do so.
func (b *buffer) store(buf []byte) error {
if b.length > 0 {
return ErrBusyBuffer
} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
b.buf = buf[:cap(buf)]
}
return nil
}

View File

@ -19,16 +19,6 @@ import (
"time"
)
// a copy of context.Context for Go 1.7 and earlier
type mysqlContext interface {
Done() <-chan struct{}
Err() error
// defined in context.Context, but not used in this driver:
// Deadline() (deadline time.Time, ok bool)
// Value(key interface{}) interface{}
}
type mysqlConn struct {
buf buffer
netConn net.Conn
@ -45,7 +35,7 @@ type mysqlConn struct {
// for context support (Go 1.8+)
watching bool
watcher chan<- mysqlContext
watcher chan<- context.Context
closech chan struct{}
finished chan<- struct{}
canceled atomicError // set non-nil if conn is canceled
@ -192,10 +182,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
return "", driver.ErrSkip
}
buf := mc.buf.takeCompleteBuffer()
if buf == nil {
buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return "", ErrInvalidConn
}
buf = buf[:0]
@ -475,7 +465,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
defer mc.finish()
if err = mc.writeCommandPacket(comPing); err != nil {
return
return mc.markBadConn(err)
}
return mc.readResultOK()
@ -595,33 +585,32 @@ func (mc *mysqlConn) watchCancel(ctx context.Context) error {
mc.cleanup()
return nil
}
// When ctx is already cancelled, don't watch it.
if err := ctx.Err(); err != nil {
return err
}
// When ctx is not cancellable, don't watch it.
if ctx.Done() == nil {
return nil
}
mc.watching = true
select {
default:
case <-ctx.Done():
return ctx.Err()
}
// When watcher is not alive, can't watch it.
if mc.watcher == nil {
return nil
}
mc.watching = true
mc.watcher <- ctx
return nil
}
func (mc *mysqlConn) startWatcher() {
watcher := make(chan mysqlContext, 1)
watcher := make(chan context.Context, 1)
mc.watcher = watcher
finished := make(chan struct{})
mc.finished = finished
go func() {
for {
var ctx mysqlContext
var ctx context.Context
select {
case ctx = <-watcher:
case <-mc.closech:

View File

@ -9,7 +9,10 @@
package mysql
import (
"context"
"database/sql/driver"
"errors"
"net"
"testing"
)
@ -79,3 +82,76 @@ func TestCheckNamedValue(t *testing.T) {
t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value)
}
}
// TestCleanCancel tests passed context is cancelled at start.
// No packet should be sent. Connection should keep current status.
func TestCleanCancel(t *testing.T) {
mc := &mysqlConn{
closech: make(chan struct{}),
}
mc.startWatcher()
defer mc.cleanup()
ctx, cancel := context.WithCancel(context.Background())
cancel()
for i := 0; i < 3; i++ { // Repeat same behavior
err := mc.Ping(ctx)
if err != context.Canceled {
t.Errorf("expected context.Canceled, got %#v", err)
}
if mc.closed.IsSet() {
t.Error("expected mc is not closed, closed actually")
}
if mc.watching {
t.Error("expected watching is false, but true")
}
}
}
func TestPingMarkBadConnection(t *testing.T) {
nc := badConnection{err: errors.New("boom")}
ms := &mysqlConn{
netConn: nc,
buf: newBuffer(nc),
maxAllowedPacket: defaultMaxAllowedPacket,
}
err := ms.Ping(context.Background())
if err != driver.ErrBadConn {
t.Errorf("expected driver.ErrBadConn, got %#v", err)
}
}
func TestPingErrInvalidConn(t *testing.T) {
nc := badConnection{err: errors.New("failed to write"), n: 10}
ms := &mysqlConn{
netConn: nc,
buf: newBuffer(nc),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
}
err := ms.Ping(context.Background())
if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %#v", err)
}
}
type badConnection struct {
n int
err error
net.Conn
}
func (bc badConnection) Write(b []byte) (n int, err error) {
return bc.n, bc.err
}
func (bc badConnection) Close() error {
return nil
}

View File

@ -9,7 +9,7 @@
// The driver should be used via the database/sql package:
//
// import "database/sql"
// import _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
// import _ "github.com/go-sql-driver/mysql"
//
// db, err := sql.Open("mysql", "user:password@/dbname")
//
@ -50,7 +50,7 @@ func RegisterDial(net string, dial DialFunc) {
// Open new Connection.
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
// the DSN string is formated
// the DSN string is formatted
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
var err error
@ -77,6 +77,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
}
if err != nil {
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
errLog.Print("net.Error from Dial()': ", nerr.Error())
return nil, driver.ErrBadConn
}
return nil, err
}
@ -110,18 +114,18 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
}
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
authResp, err := mc.auth(authData, plugin)
if err != nil {
// try the default auth plugin, if using the requested plugin failed
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
plugin = defaultAuthPlugin
authResp, addNUL, err = mc.auth(authData, plugin)
authResp, err = mc.auth(authData, plugin)
if err != nil {
mc.cleanup()
return nil, err
}
}
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
mc.cleanup()
return nil, err
}

View File

@ -85,6 +85,23 @@ type DBTest struct {
db *sql.DB
}
type netErrorMock struct {
temporary bool
timeout bool
}
func (e netErrorMock) Temporary() bool {
return e.temporary
}
func (e netErrorMock) Timeout() bool {
return e.timeout
}
func (e netErrorMock) Error() string {
return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout)
}
func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
@ -1287,7 +1304,7 @@ func TestFoundRows(t *testing.T) {
}
func TestTLS(t *testing.T) {
tlsTest := func(dbt *DBTest) {
tlsTestReq := func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
if err == ErrNoTLS {
dbt.Skip("server does not support TLS")
@ -1304,19 +1321,27 @@ func TestTLS(t *testing.T) {
dbt.Fatal(err.Error())
}
if value == nil {
dbt.Fatal("no Cipher")
if (*value == nil) || (len(*value) == 0) {
dbt.Fatalf("no Cipher")
} else {
dbt.Logf("Cipher: %s", *value)
}
}
}
tlsTestOpt := func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
dbt.Fatalf("error on Ping: %s", err.Error())
}
}
runTests(t, dsn+"&tls=skip-verify", tlsTest)
runTests(t, dsn+"&tls=preferred", tlsTestOpt)
runTests(t, dsn+"&tls=skip-verify", tlsTestReq)
// Verify that registering / using a custom cfg works
RegisterTLSConfig("custom-skip-verify", &tls.Config{
InsecureSkipVerify: true,
})
runTests(t, dsn+"&tls=custom-skip-verify", tlsTest)
runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq)
}
func TestReuseClosedConnection(t *testing.T) {
@ -1801,6 +1826,38 @@ func TestConcurrent(t *testing.T) {
})
}
func testDialError(t *testing.T, dialErr error, expectErr error) {
RegisterDial("mydial", func(addr string) (net.Conn, error) {
return nil, dialErr
})
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db.Close()
_, err = db.Exec("DO 1")
if err != expectErr {
t.Fatalf("was expecting %s. Got: %s", dialErr, err)
}
}
func TestDialUnknownError(t *testing.T) {
testErr := fmt.Errorf("test")
testDialError(t, testErr, testErr)
}
func TestDialNonRetryableNetErr(t *testing.T) {
testErr := netErrorMock{}
testDialError(t, testErr, testErr)
}
func TestDialTemporaryNetErr(t *testing.T) {
testErr := netErrorMock{temporary: true}
testDialError(t, testErr, driver.ErrBadConn)
}
// Tests custom dial functions
func TestCustomDial(t *testing.T) {
if !available {

View File

@ -560,7 +560,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
} else {
cfg.TLSConfig = "false"
}
} else if vl := strings.ToLower(value); vl == "skip-verify" {
} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
cfg.TLSConfig = vl
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else {

View File

@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
mc.sequence++
// packets with length 0 terminate a previous packet which is a
// multiple of (2^24)1 bytes long
// multiple of (2^24)-1 bytes long
if pktLen == 0 {
// there was no previous packet
if prevData == nil {
@ -194,7 +194,11 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
return nil, "", ErrOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
return nil, "", ErrNoTLS
if mc.cfg.TLSConfig == "preferred" {
mc.cfg.tls = nil
} else {
return nil, "", ErrNoTLS
}
}
pos += 2
@ -243,7 +247,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error {
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
// Adjust client flags based on server support
clientFlags := clientProtocol41 |
clientSecureConn |
@ -269,7 +273,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
// encode length of the auth plugin data
var authRespLEIBuf [9]byte
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp)))
authRespLen := len(authResp)
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
if len(authRespLEI) > 1 {
// if the length can not be written in 1 byte, it must be written as a
// length encoded integer
@ -277,9 +282,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
}
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
if addNUL {
pktLen++
}
// To specify a db name
if n := len(mc.cfg.DBName); n > 0 {
@ -288,10 +290,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
}
// Calculate packet length and get buffer with that size
data := mc.buf.takeSmallBuffer(pktLen + 4)
if data == nil {
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}
@ -350,10 +352,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
// Auth Data [length encoded integer]
pos += copy(data[pos:], authRespLEI)
pos += copy(data[pos:], authResp)
if addNUL {
data[pos] = 0x00
pos++
}
// Databasename [null terminated string]
if len(mc.cfg.DBName) > 0 {
@ -364,30 +362,24 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
pos += copy(data[pos:], plugin)
data[pos] = 0x00
pos++
// Send Auth packet
return mc.writePacket(data)
return mc.writePacket(data[:pos])
}
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error {
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
pktLen := 4 + len(authData)
if addNUL {
pktLen++
}
data := mc.buf.takeSmallBuffer(pktLen)
if data == nil {
data, err := mc.buf.takeSmallBuffer(pktLen)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}
// Add the auth data [EOF]
copy(data[4:], authData)
if addNUL {
data[pktLen-1] = 0x00
}
return mc.writePacket(data)
}
@ -399,10 +391,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
// Reset Packet Sequence
mc.sequence = 0
data := mc.buf.takeSmallBuffer(4 + 1)
if data == nil {
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}
@ -418,10 +410,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
mc.sequence = 0
pktLen := 1 + len(arg)
data := mc.buf.takeBuffer(pktLen + 4)
if data == nil {
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}
@ -439,10 +431,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
// Reset Packet Sequence
mc.sequence = 0
data := mc.buf.takeSmallBuffer(4 + 1 + 4)
if data == nil {
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}
@ -479,7 +471,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
return data[1:], "", err
case iEOF:
if len(data) < 1 {
if len(data) == 1 {
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
return nil, "mysql_old_password", nil
}
@ -895,7 +887,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
const minPktLen = 4 + 1 + 4 + 1 + 4
mc := stmt.mc
// Determine threshould dynamically to avoid packet size shortage.
// Determine threshold dynamically to avoid packet size shortage.
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
if longDataSize < 64 {
longDataSize = 64
@ -905,15 +897,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
mc.sequence = 0
var data []byte
var err error
if len(args) == 0 {
data = mc.buf.takeBuffer(minPktLen)
data, err = mc.buf.takeBuffer(minPktLen)
} else {
data = mc.buf.takeCompleteBuffer()
data, err = mc.buf.takeCompleteBuffer()
// In this case the len(data) == cap(data) which is used to optimise the flow below.
}
if data == nil {
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
errLog.Print(err)
return errBadConnNoWrite
}
@ -939,7 +933,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
pos := minPktLen
var nullMask []byte
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
// buffer has to be extended but we don't know by how much so
// we depend on append after all data with known sizes fit.
// We stop at that because we deal with a lot of columns here
@ -948,10 +942,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
copy(tmp[:pos], data[:pos])
data = tmp
nullMask = data[pos : pos+maskLen]
// No need to clean nullMask as make ensures that.
pos += maskLen
} else {
nullMask = data[pos : pos+maskLen]
for i := 0; i < maskLen; i++ {
for i := range nullMask {
nullMask[i] = 0
}
pos += maskLen
@ -1088,7 +1083,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In that case we must build the data packet with the new values buffer
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
mc.buf.buf = data
if err = mc.buf.store(data); err != nil {
errLog.Print(err)
return errBadConnNoWrite
}
}
pos += len(paramValues)