mirror of
https://gitee.com/johng/gf
synced 2026-06-06 02:25:47 +08:00
refract gdb package, add complete unit test cases, almost there
This commit is contained in:
@ -30,27 +30,39 @@ const (
|
||||
|
||||
// 数据库操作接口
|
||||
type DB interface {
|
||||
// 建立数据库连接方法(开发者一般不需要直接调用)
|
||||
Open(config *ConfigNode) (*sql.DB, error)
|
||||
|
||||
// SQL操作方法
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
Exec(sql string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(sql string, execOnMaster...bool) (*sql.Stmt, error)
|
||||
|
||||
doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error)
|
||||
doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error)
|
||||
doPrepare(link dbLink, query string) (*sql.Stmt, error)
|
||||
doInsert(link dbLink, table string, data Map, option int) (result sql.Result, err error)
|
||||
doBatchInsert(link dbLink, table string, list List, batch int, option int) (result sql.Result, err error)
|
||||
doUpdate(link dbLink, table string, data interface{}, condition interface{}, args ...interface{}) (result sql.Result, err error)
|
||||
doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error)
|
||||
|
||||
// 数据库查询
|
||||
GetAll(query string, args ...interface{}) (Result, error)
|
||||
GetOne(query string, args ...interface{}) (Record, error)
|
||||
GetValue(query string, args ...interface{}) (Value, error)
|
||||
GetCount(query string, args ...interface{}) (int, error)
|
||||
GetStruct(obj interface{}, query string, args ...interface{}) error
|
||||
|
||||
// Ping
|
||||
// 创建底层数据库master/slave链接对象
|
||||
Master() (*sql.DB, error)
|
||||
Slave() (*sql.DB, error)
|
||||
|
||||
// Ping
|
||||
PingMaster() error
|
||||
PingSlave() error
|
||||
|
||||
// 连接属性设置
|
||||
SetMaxIdleConns(n int)
|
||||
SetMaxOpenConns(n int)
|
||||
SetConnMaxLifetime(n int)
|
||||
|
||||
// 开启事务操作
|
||||
Begin() (*Tx, error)
|
||||
Begin() (*TX, error)
|
||||
|
||||
// 数据表插入/更新/保存操作
|
||||
Insert(table string, data Map) (sql.Result, error)
|
||||
@ -72,17 +84,26 @@ type DB interface {
|
||||
|
||||
// 设置管理
|
||||
SetDebug(debug bool)
|
||||
GetQueriedSqls() []*Sql
|
||||
PrintQueriedSqls()
|
||||
SetMaxIdleConns(n int)
|
||||
SetMaxOpenConns(n int)
|
||||
SetConnMaxLifetime(n int)
|
||||
|
||||
// 内部方法接口
|
||||
open(c *ConfigNode) (*sql.DB, error)
|
||||
getCache() (*gcache.Cache)
|
||||
getChars() (charLeft string, charRight string)
|
||||
getDebug() bool
|
||||
putSql(s *Sql)
|
||||
formatCondition(condition interface{}) (where string)
|
||||
handleSqlBeforeExec(sql string) string
|
||||
}
|
||||
|
||||
// 执行底层数据库操作的核心接口
|
||||
type dbLink interface {
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
Exec(sql string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(sql string) (*sql.Stmt, error)
|
||||
}
|
||||
|
||||
// 数据库链接对象
|
||||
type dbBase struct {
|
||||
db DB // 数据库对象
|
||||
@ -228,36 +249,36 @@ func getConfigNodeByPriority(cg ConfigGroup) *ConfigNode {
|
||||
}
|
||||
|
||||
// 获得底层数据库链接对象
|
||||
func (db *dbBase) getSqlDb(master bool) (sqlDb *sql.DB, err error) {
|
||||
func (bs *dbBase) getSqlDb(master bool) (sqlDb *sql.DB, err error) {
|
||||
// 负载均衡
|
||||
node, err := getConfigNodeByGroup(db.group, master)
|
||||
node, err := getConfigNodeByGroup(bs.group, master)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 检查缓存连接池对象
|
||||
cacheKey := node.String()
|
||||
if v := db.cache.Get(cacheKey); v != nil {
|
||||
if v := bs.cache.Get(cacheKey); v != nil {
|
||||
return v.(*sql.DB), nil
|
||||
}
|
||||
v := db.cache.GetOrSetFuncLock(node.String(), func() interface{} {
|
||||
sqlDb, err = db.db.open(node)
|
||||
v := bs.cache.GetOrSetFuncLock(node.String(), func() interface{} {
|
||||
sqlDb, err = bs.db.Open(node)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if n := db.maxIdleConnCount.Val(); n > 0 {
|
||||
if n := bs.maxIdleConnCount.Val(); n > 0 {
|
||||
sqlDb.SetMaxIdleConns(n)
|
||||
} else if node.MaxIdleConnCount > 0 {
|
||||
sqlDb.SetMaxIdleConns(node.MaxIdleConnCount)
|
||||
}
|
||||
|
||||
if n := db.maxOpenConnCount.Val(); n > 0 {
|
||||
if n := bs.maxOpenConnCount.Val(); n > 0 {
|
||||
sqlDb.SetMaxOpenConns(n)
|
||||
} else if node.MaxOpenConnCount > 0 {
|
||||
sqlDb.SetMaxOpenConns(node.MaxOpenConnCount)
|
||||
}
|
||||
|
||||
if n := db.maxConnLifetime.Val(); n > 0 {
|
||||
if n := bs.maxConnLifetime.Val(); n > 0 {
|
||||
sqlDb.SetConnMaxLifetime(time.Duration(n) * time.Second)
|
||||
} else if node.MaxConnLifetime > 0 {
|
||||
sqlDb.SetConnMaxLifetime(time.Duration(node.MaxConnLifetime) * time.Second)
|
||||
@ -271,11 +292,11 @@ func (db *dbBase) getSqlDb(master bool) (sqlDb *sql.DB, err error) {
|
||||
}
|
||||
|
||||
// 创建底层数据库master链接对象
|
||||
func (db *dbBase) Master() (*sql.DB, error) {
|
||||
return db.getSqlDb(true)
|
||||
func (bs *dbBase) Master() (*sql.DB, error) {
|
||||
return bs.getSqlDb(true)
|
||||
}
|
||||
|
||||
// 创建底层数据库slave链接对象
|
||||
func (db *dbBase) Slave() (*sql.DB, error) {
|
||||
return db.getSqlDb(false)
|
||||
func (bs *dbBase) Slave() (*sql.DB, error) {
|
||||
return bs.getSqlDb(false)
|
||||
}
|
||||
|
||||
@ -14,7 +14,7 @@ import (
|
||||
"gitee.com/johng/gf/g/os/gcache"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"gitee.com/johng/gf/g/util/gstr"
|
||||
"gitee.com/johng/gf/g/util/gregex"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
@ -24,13 +24,13 @@ const (
|
||||
)
|
||||
|
||||
// 获取已经执行的SQL列表(仅在debug=true时有效)
|
||||
func (db *dbBase) GetQueriedSqls() []*Sql {
|
||||
if db.sqls == nil {
|
||||
func (bs *dbBase) GetQueriedSqls() []*Sql {
|
||||
if bs.sqls == nil {
|
||||
return nil
|
||||
}
|
||||
sqls := make([]*Sql, 0)
|
||||
db.sqls.Prev()
|
||||
db.sqls.RLockIteratorPrev(func(value interface{}) bool {
|
||||
bs.sqls.Prev()
|
||||
bs.sqls.RLockIteratorPrev(func(value interface{}) bool {
|
||||
if value == nil {
|
||||
return false
|
||||
}
|
||||
@ -41,8 +41,8 @@ func (db *dbBase) GetQueriedSqls() []*Sql {
|
||||
}
|
||||
|
||||
// 打印已经执行的SQL列表(仅在debug=true时有效)
|
||||
func (db *dbBase) PrintQueriedSqls() {
|
||||
sqls := db.GetQueriedSqls()
|
||||
func (bs *dbBase) PrintQueriedSqls() {
|
||||
sqls := bs.GetQueriedSqls()
|
||||
for k, v := range sqls {
|
||||
fmt.Println(len(sqls) - k, ":")
|
||||
fmt.Println(" Sql :", v.Sql)
|
||||
@ -55,51 +55,106 @@ func (db *dbBase) PrintQueriedSqls() {
|
||||
}
|
||||
|
||||
// 数据库sql查询操作,主要执行查询
|
||||
func (db *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
link, err := db.Slave()
|
||||
func (bs *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
link, err := bs.db.Slave()
|
||||
if err != nil {
|
||||
return nil,err
|
||||
}
|
||||
return doQuery(db.db, link, query, args...)
|
||||
return bs.db.doQuery(link, query, args...)
|
||||
}
|
||||
|
||||
// 数据库sql查询操作,主要执行查询
|
||||
func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
query = bs.db.handleSqlBeforeExec(query)
|
||||
if bs.db.getDebug() {
|
||||
mTime1 := gtime.Millisecond()
|
||||
rows, err = link.Query(query, args...)
|
||||
mTime2 := gtime.Millisecond()
|
||||
s := &Sql {
|
||||
Sql : query,
|
||||
Args : args,
|
||||
Error : err,
|
||||
Start : mTime1,
|
||||
End : mTime2,
|
||||
}
|
||||
bs.sqls.Put(s)
|
||||
printSql(s)
|
||||
} else {
|
||||
rows, err = link.Query(query, args ...)
|
||||
}
|
||||
if err == nil {
|
||||
return rows, nil
|
||||
} else {
|
||||
err = formatError(err, query, args...)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 执行一条sql,并返回执行情况,主要用于非查询操作
|
||||
func (db *dbBase) Exec(query string, args ...interface{}) (result sql.Result, err error) {
|
||||
link, err := db.Master()
|
||||
func (bs *dbBase) Exec(query string, args ...interface{}) (result sql.Result, err error) {
|
||||
link, err := bs.db.Master()
|
||||
if err != nil {
|
||||
return nil,err
|
||||
}
|
||||
return doExec(db.db, link, query, args...)
|
||||
return bs.db.doExec(link, query, args...)
|
||||
}
|
||||
|
||||
// 执行一条sql,并返回执行情况,主要用于非查询操作
|
||||
func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
query = bs.db.handleSqlBeforeExec(query)
|
||||
if bs.db.getDebug() {
|
||||
mTime1 := gtime.Millisecond()
|
||||
result, err = link.Exec(query, args ...)
|
||||
mTime2 := gtime.Millisecond()
|
||||
s := &Sql{
|
||||
Sql : query,
|
||||
Args : args,
|
||||
Error : err,
|
||||
Start : mTime1,
|
||||
End : mTime2,
|
||||
}
|
||||
bs.sqls.Put(s)
|
||||
printSql(s)
|
||||
} else {
|
||||
result, err = link.Exec(query, args ...)
|
||||
}
|
||||
return result, formatError(err, query, args...)
|
||||
}
|
||||
|
||||
// SQL预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作; 默认执行在Slave上, 通过第二个参数指定执行在Master上
|
||||
func (db *dbBase) Prepare(query string, execOnMaster...bool) (*sql.Stmt, error) {
|
||||
func (bs *dbBase) Prepare(query string, execOnMaster...bool) (*sql.Stmt, error) {
|
||||
err := (error)(nil)
|
||||
sqldb := (*sql.DB)(nil)
|
||||
link := (dbLink)(nil)
|
||||
if len(execOnMaster) > 0 && execOnMaster[0] {
|
||||
if sqldb, err = db.Master(); err != nil {
|
||||
if link, err = bs.db.Master(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if sqldb, err = db.Slave(); err != nil {
|
||||
if link, err = bs.db.Slave(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return sqldb.Prepare(query)
|
||||
return bs.db.doPrepare(link, query)
|
||||
}
|
||||
|
||||
// SQL预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作
|
||||
func (bs *dbBase) doPrepare(link dbLink, query string) (*sql.Stmt, error) {
|
||||
return link.Prepare(query)
|
||||
}
|
||||
|
||||
// 数据库查询,获取查询结果集,以列表结构返回
|
||||
func (db *dbBase) GetAll(query string, args ...interface{}) (Result, error) {
|
||||
rows, err := db.Query(query, args ...)
|
||||
func (bs *dbBase) GetAll(query string, args ...interface{}) (Result, error) {
|
||||
rows, err := bs.Query(query, args ...)
|
||||
if err != nil || rows == nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return rowsToResult(rows)
|
||||
}
|
||||
|
||||
// 数据库查询,获取查询结果记录,以关联数组结构返回
|
||||
func (db *dbBase) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
list, err := db.GetAll(query, args ...)
|
||||
func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
list, err := bs.GetAll(query, args ...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -110,18 +165,17 @@ func (db *dbBase) GetOne(query string, args ...interface{}) (Record, error) {
|
||||
}
|
||||
|
||||
// 数据库查询,获取查询结果记录,自动映射数据到给定的struct对象中
|
||||
func (db *dbBase) GetStruct(obj interface{}, query string, args ...interface{}) error {
|
||||
one, err := db.GetOne(query, args...)
|
||||
func (bs *dbBase) GetStruct(obj interface{}, query string, args ...interface{}) error {
|
||||
one, err := bs.GetOne(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return one.ToStruct(obj)
|
||||
}
|
||||
|
||||
|
||||
// 数据库查询,获取查询字段值
|
||||
func (db *dbBase) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
one, err := db.GetOne(query, args ...)
|
||||
func (bs *dbBase) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
one, err := bs.GetOne(query, args ...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -132,35 +186,20 @@ func (db *dbBase) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
}
|
||||
|
||||
// 数据库查询,获取查询数量
|
||||
func (db *dbBase) GetCount(query string, args ...interface{}) (int, error) {
|
||||
val, err := db.GetValue(query, args ...)
|
||||
func (bs *dbBase) GetCount(query string, args ...interface{}) (int, error) {
|
||||
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
|
||||
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
|
||||
}
|
||||
value, err := bs.GetValue(query, args ...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return gconv.Int(val), nil
|
||||
}
|
||||
|
||||
// 数据表查询,其中tables可以是多个联表查询语句,这种查询方式较复杂,建议使用链式操作
|
||||
func (db *dbBase) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) {
|
||||
s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables)
|
||||
if condition != nil {
|
||||
s += fmt.Sprintf("WHERE %s ", db.formatCondition(condition))
|
||||
}
|
||||
if len(groupBy) > 0 {
|
||||
s += fmt.Sprintf("GROUP BY %s ", groupBy)
|
||||
}
|
||||
if len(orderBy) > 0 {
|
||||
s += fmt.Sprintf("ORDER BY %s ", orderBy)
|
||||
}
|
||||
if limit > 0 {
|
||||
s += fmt.Sprintf("LIMIT %d,%d ", first, limit)
|
||||
}
|
||||
return db.GetAll(s, args ... )
|
||||
return value.Int(), nil
|
||||
}
|
||||
|
||||
// ping一下,判断或保持数据库链接(master)
|
||||
func (db *dbBase) PingMaster() error {
|
||||
if master, err := db.Master(); err != nil {
|
||||
func (bs *dbBase) PingMaster() error {
|
||||
if master, err := bs.db.Master(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return master.Ping()
|
||||
@ -168,8 +207,8 @@ func (db *dbBase) PingMaster() error {
|
||||
}
|
||||
|
||||
// ping一下,判断或保持数据库链接(slave)
|
||||
func (db *dbBase) PingSlave() error {
|
||||
if slave, err := db.Slave(); err != nil {
|
||||
func (bs *dbBase) PingSlave() error {
|
||||
if slave, err := bs.db.Slave(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return slave.Ping()
|
||||
@ -178,13 +217,13 @@ func (db *dbBase) PingSlave() error {
|
||||
|
||||
// 事务操作,开启,会返回一个底层的事务操作对象链接如需要嵌套事务,那么可以使用该对象,否则请忽略
|
||||
// 只有在tx.Commit/tx.Rollback时,链接会自动Close
|
||||
func (db *dbBase) Begin() (*TX, error) {
|
||||
if master, err := db.Master(); err != nil {
|
||||
func (bs *dbBase) Begin() (*TX, error) {
|
||||
if master, err := bs.db.Master(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
if tx, err := master.Begin(); err == nil {
|
||||
return &TX {
|
||||
db : db.db,
|
||||
db : bs.db,
|
||||
tx : tx,
|
||||
master : master,
|
||||
}, nil
|
||||
@ -194,16 +233,31 @@ func (db *dbBase) Begin() (*TX, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
|
||||
func (bs *dbBase) Insert(table string, data Map) (sql.Result, error) {
|
||||
return bs.db.doInsert(nil, table, data, OPTION_INSERT)
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
func (bs *dbBase) Replace(table string, data Map) (sql.Result, error) {
|
||||
return bs.db.doInsert(nil, table, data, OPTION_REPLACE)
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
func (bs *dbBase) Save(table string, data Map) (sql.Result, error) {
|
||||
return bs.db.doInsert(nil, table, data, OPTION_SAVE)
|
||||
}
|
||||
|
||||
// insert、replace, save, ignore操作
|
||||
// 0: insert: 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
|
||||
// 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
// 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
// 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做
|
||||
func (db *dbBase) insert(table string, data Map, option uint8) (sql.Result, error) {
|
||||
func (bs *dbBase) doInsert(link dbLink, table string, data Map, option int) (result sql.Result, err error) {
|
||||
var fields []string
|
||||
var values []string
|
||||
var params []interface{}
|
||||
charl, charr := db.db.getChars()
|
||||
charl, charr := bs.db.getChars()
|
||||
for k, v := range data {
|
||||
fields = append(fields, charl + k + charr)
|
||||
values = append(values, "?")
|
||||
@ -223,48 +277,53 @@ func (db *dbBase) insert(table string, data Map, option uint8) (sql.Result, erro
|
||||
}
|
||||
updatestr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
|
||||
}
|
||||
return db.Exec(
|
||||
fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
|
||||
operation, table, strings.Join(fields, ","),
|
||||
strings.Join(values, ","),
|
||||
updatestr),
|
||||
params...
|
||||
)
|
||||
if link == nil {
|
||||
if link, err = bs.db.Master(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
|
||||
operation, table, strings.Join(fields, ","),
|
||||
strings.Join(values, ","), updatestr),
|
||||
params...)
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
|
||||
func (db *dbBase) Insert(table string, data Map) (sql.Result, error) {
|
||||
return db.insert(table, data, OPTION_INSERT)
|
||||
// CURD操作:批量数据指定批次量写入
|
||||
func (bs *dbBase) BatchInsert(table string, list List, batch int) (sql.Result, error) {
|
||||
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_INSERT)
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
func (db *dbBase) Replace(table string, data Map) (sql.Result, error) {
|
||||
return db.insert(table, data, OPTION_REPLACE)
|
||||
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
func (bs *dbBase) BatchReplace(table string, list List, batch int) (sql.Result, error) {
|
||||
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_REPLACE)
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
func (db *dbBase) Save(table string, data Map) (sql.Result, error) {
|
||||
return db.insert(table, data, OPTION_SAVE)
|
||||
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
func (bs *dbBase) BatchSave(table string, list List, batch int) (sql.Result, error) {
|
||||
return bs.db.doBatchInsert(nil, table, list, batch, OPTION_SAVE)
|
||||
}
|
||||
|
||||
// 批量写入数据
|
||||
func (db *dbBase) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) {
|
||||
func (bs *dbBase) doBatchInsert(link dbLink, table string, list List, batch int, option int) (result sql.Result, err error) {
|
||||
var keys []string
|
||||
var values []string
|
||||
var bvalues []string
|
||||
var params []interface{}
|
||||
var result sql.Result
|
||||
var size = len(list)
|
||||
// 判断长度
|
||||
if size < 1 {
|
||||
if len(list) < 1 {
|
||||
return result, errors.New("empty data list")
|
||||
}
|
||||
if link == nil {
|
||||
if link, err = bs.db.Master(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
// 首先获取字段名称及记录长度
|
||||
for k, _ := range list[0] {
|
||||
keys = append(keys, k)
|
||||
values = append(values, "?")
|
||||
}
|
||||
charl, charr := db.db.getChars()
|
||||
charl, charr := bs.db.getChars()
|
||||
keyStr := charl + strings.Join(keys, charl + "," + charr) + charr
|
||||
valueHolderStr := "(" + strings.Join(values, ",") + ")"
|
||||
// 操作判断
|
||||
@ -283,13 +342,13 @@ func (db *dbBase) batchInsert(table string, list List, batch int, option uint8)
|
||||
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
|
||||
}
|
||||
// 构造批量写入数据格式(注意map的遍历是无序的)
|
||||
for i := 0; i < size; i++ {
|
||||
for i := 0; i < len(list); i++ {
|
||||
for _, k := range keys {
|
||||
params = append(params, list[i][k])
|
||||
}
|
||||
bvalues = append(bvalues, valueHolderStr)
|
||||
if len(bvalues) == batch {
|
||||
r, err := db.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
|
||||
r, err := bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
|
||||
operation, table, keyStr, strings.Join(bvalues, ","),
|
||||
updatestr),
|
||||
params...)
|
||||
@ -303,7 +362,7 @@ func (db *dbBase) batchInsert(table string, list List, batch int, option uint8)
|
||||
}
|
||||
// 处理最后不构成指定批量的数据
|
||||
if len(bvalues) > 0 {
|
||||
r, err := db.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
|
||||
r, err := bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
|
||||
operation, table, keyStr, strings.Join(bvalues, ","),
|
||||
updatestr),
|
||||
params...)
|
||||
@ -315,27 +374,22 @@ func (db *dbBase) batchInsert(table string, list List, batch int, option uint8)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CURD操作:批量数据指定批次量写入
|
||||
func (db *dbBase) BatchInsert(table string, list List, batch int) (sql.Result, error) {
|
||||
return db.batchInsert(table, list, batch, OPTION_INSERT)
|
||||
}
|
||||
|
||||
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
func (db *dbBase) BatchReplace(table string, list List, batch int) (sql.Result, error) {
|
||||
return db.batchInsert(table, list, batch, OPTION_REPLACE)
|
||||
}
|
||||
|
||||
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
func (db *dbBase) BatchSave(table string, list List, batch int) (sql.Result, error) {
|
||||
return db.batchInsert(table, list, batch, OPTION_SAVE)
|
||||
// CURD操作:数据更新,统一采用sql预处理
|
||||
// data参数支持字符串或者关联数组类型,内部会自行做判断处理
|
||||
func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
link, err := bs.db.Master()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bs.db.doUpdate(link, table, data, condition, args ...)
|
||||
}
|
||||
|
||||
// CURD操作:数据更新,统一采用sql预处理
|
||||
// data参数支持字符串或者关联数组类型,内部会自行做判断处理
|
||||
func (db *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
var params []interface{}
|
||||
var updates string
|
||||
charl, charr := db.db.getChars()
|
||||
func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition interface{}, args ...interface{}) (result sql.Result, err error) {
|
||||
params := ([]interface{})(nil)
|
||||
updates := ""
|
||||
charl, charr := bs.db.getChars()
|
||||
refValue := reflect.ValueOf(data)
|
||||
if refValue.Kind() == reflect.Map {
|
||||
var fields []string
|
||||
@ -351,44 +405,29 @@ func (db *dbBase) Update(table string, data interface{}, condition interface{},
|
||||
for _, v := range args {
|
||||
params = append(params, gconv.String(v))
|
||||
}
|
||||
return db.Exec(fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, db.formatCondition(condition)), params...)
|
||||
if link == nil {
|
||||
if link, err = bs.db.Master(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, formatCondition(condition)), params...)
|
||||
}
|
||||
|
||||
// CURD操作:删除数据
|
||||
func (db *dbBase) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
return db.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", table, db.formatCondition(condition)), args...)
|
||||
func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
|
||||
link, err := bs.db.Master()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bs.db.doDelete(link, table, condition, args ...)
|
||||
}
|
||||
|
||||
// 格式化SQL查询条件
|
||||
func (db *dbBase) formatCondition(condition interface{}) (where string) {
|
||||
if reflect.ValueOf(condition).Kind() == reflect.Map {
|
||||
ks := reflect.ValueOf(condition).MapKeys()
|
||||
vs := reflect.ValueOf(condition)
|
||||
for _, k := range ks {
|
||||
key := gconv.String(k.Interface())
|
||||
value := gconv.String(vs.MapIndex(k).Interface())
|
||||
isNum := gstr.IsNumeric(value)
|
||||
if len(where) > 0 {
|
||||
where += " AND "
|
||||
}
|
||||
if isNum || value == "?" {
|
||||
where += key + "=" + value
|
||||
} else {
|
||||
where += key + "='" + value + "'"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
where += gconv.String(condition)
|
||||
}
|
||||
return
|
||||
// CURD操作:删除数据
|
||||
func (bs *dbBase) doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
|
||||
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, formatCondition(condition)), args...)
|
||||
}
|
||||
|
||||
// 获得缓存对象
|
||||
func (db *dbBase) getCache() *gcache.Cache {
|
||||
return db.cache
|
||||
func (bs *dbBase) getCache() *gcache.Cache {
|
||||
return bs.cache
|
||||
}
|
||||
|
||||
// 记录执行的SQL
|
||||
func (db *dbBase) putSql(s *Sql) {
|
||||
db.sqls.Put(s)
|
||||
}
|
||||
@ -123,19 +123,19 @@ func SetDefaultGroup (groupName string) {
|
||||
}
|
||||
|
||||
// 设置数据库连接池中空闲链接的大小
|
||||
func (db *dbBase) SetMaxIdleConns(n int) {
|
||||
db.maxIdleConnCount.Set(n)
|
||||
func (bs *dbBase) SetMaxIdleConns(n int) {
|
||||
bs.maxIdleConnCount.Set(n)
|
||||
}
|
||||
|
||||
// 设置数据库连接池最大打开的链接数量
|
||||
func (db *dbBase) SetMaxOpenConns(n int) {
|
||||
db.maxOpenConnCount.Set(n)
|
||||
func (bs *dbBase) SetMaxOpenConns(n int) {
|
||||
bs.maxOpenConnCount.Set(n)
|
||||
}
|
||||
|
||||
// 设置数据库连接可重复利用的时间,超过该时间则被关闭废弃
|
||||
// 如果 d <= 0 表示该链接会一直重复利用
|
||||
func (db *dbBase) SetConnMaxLifetime(n int) {
|
||||
db.maxConnLifetime.Set(n)
|
||||
func (bs *dbBase) SetConnMaxLifetime(n int) {
|
||||
bs.maxConnLifetime.Set(n)
|
||||
}
|
||||
|
||||
// 节点配置转换为字符串
|
||||
@ -150,14 +150,14 @@ func (node *ConfigNode) String() string {
|
||||
}
|
||||
|
||||
// 是否开启调试服务
|
||||
func (db *dbBase) SetDebug(debug bool) {
|
||||
db.debug.Set(debug)
|
||||
if debug && db.sqls == nil {
|
||||
db.sqls = gring.New(gDEFAULT_DEBUG_SQL_LENGTH)
|
||||
func (bs *dbBase) SetDebug(debug bool) {
|
||||
bs.debug.Set(debug)
|
||||
if debug && bs.sqls == nil {
|
||||
bs.sqls = gring.New(gDEFAULT_DEBUG_SQL_LENGTH)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取是否开启调试服务
|
||||
func (db *dbBase) getDebug() bool {
|
||||
return db.debug.Val()
|
||||
func (bs *dbBase) getDebug() bool {
|
||||
return bs.debug.Val()
|
||||
}
|
||||
@ -13,65 +13,12 @@ import (
|
||||
"gitee.com/johng/gf/g/container/gvar"
|
||||
"gitee.com/johng/gf/g/os/glog"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"gitee.com/johng/gf/g/util/gstr"
|
||||
_ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
|
||||
"strings"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type dbLink interface {
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
Exec(sql string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(sql string) (*sql.Stmt, error)
|
||||
}
|
||||
|
||||
// 数据库sql查询操作,主要执行查询
|
||||
func doQuery(db DB, link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
query = db.handleSqlBeforeExec(query)
|
||||
if db.getDebug() {
|
||||
mTime1 := gtime.Millisecond()
|
||||
rows, err = link.Query(query, args...)
|
||||
mTime2 := gtime.Millisecond()
|
||||
s := &Sql{
|
||||
Sql : query,
|
||||
Args : args,
|
||||
Error : err,
|
||||
Start : mTime1,
|
||||
End : mTime2,
|
||||
}
|
||||
db.putSql(s)
|
||||
printSql(s)
|
||||
} else {
|
||||
rows, err = link.Query(query, args ...)
|
||||
}
|
||||
if err == nil {
|
||||
return rows, nil
|
||||
} else {
|
||||
err = formatError(err, query, args...)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 执行一条sql,并返回执行情况,主要用于非查询操作
|
||||
func doExec(db DB, link dbLink, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
query = db.handleSqlBeforeExec(query)
|
||||
if db.getDebug() {
|
||||
mTime1 := gtime.Millisecond()
|
||||
result, err = link.Exec(query, args ...)
|
||||
mTime2 := gtime.Millisecond()
|
||||
s := &Sql{
|
||||
Sql : query,
|
||||
Args : args,
|
||||
Error : err,
|
||||
Start : mTime1,
|
||||
End : mTime2,
|
||||
}
|
||||
db.putSql(s)
|
||||
printSql(s)
|
||||
} else {
|
||||
result, err = link.Exec(query, args ...)
|
||||
}
|
||||
return result, formatError(err, query, args...)
|
||||
}
|
||||
|
||||
// 将数据查询的列表数据*sql.Rows转换为Result类型
|
||||
func rowsToResult(rows *sql.Rows) (Result, error) {
|
||||
// 列名称列表
|
||||
@ -103,34 +50,31 @@ func rowsToResult(rows *sql.Rows) (Result, error) {
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func formatInsertQuery(db DB, table string, data Map, option uint8) (string, []interface{}) {
|
||||
var fields []string
|
||||
var values []string
|
||||
var params []interface{}
|
||||
charl, charr := db.getChars()
|
||||
for k, v := range data {
|
||||
fields = append(fields, charl + k + charr)
|
||||
values = append(values, "?")
|
||||
params = append(params, v)
|
||||
}
|
||||
operation := getInsertOperationByOption(option)
|
||||
updatestr := ""
|
||||
if option == OPTION_SAVE {
|
||||
var updates []string
|
||||
for k, _ := range data {
|
||||
updates = append(updates,
|
||||
fmt.Sprintf("%s%s%s=VALUES(%s%s%s)",
|
||||
charl, k, charr,
|
||||
charl, k, charr,
|
||||
),
|
||||
)
|
||||
// 格式化SQL查询条件
|
||||
func formatCondition(condition interface{}) (where string) {
|
||||
if reflect.ValueOf(condition).Kind() == reflect.Map {
|
||||
ks := reflect.ValueOf(condition).MapKeys()
|
||||
vs := reflect.ValueOf(condition)
|
||||
for _, k := range ks {
|
||||
key := gconv.String(k.Interface())
|
||||
value := gconv.String(vs.MapIndex(k).Interface())
|
||||
isNum := gstr.IsNumeric(value)
|
||||
if len(where) > 0 {
|
||||
where += " AND "
|
||||
}
|
||||
if isNum || value == "?" {
|
||||
where += key + "=" + value
|
||||
} else {
|
||||
where += key + "='" + value + "'"
|
||||
}
|
||||
}
|
||||
updatestr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
|
||||
} else {
|
||||
where += gconv.String(condition)
|
||||
}
|
||||
return fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
|
||||
operation, table, strings.Join(fields, ","),
|
||||
strings.Join(values, ","), updatestr),
|
||||
params
|
||||
if len(where) == 0 {
|
||||
where = "1"
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 打印SQL对象(仅在debug=true时有效)
|
||||
@ -163,7 +107,7 @@ func formatError(err error, query string, args ...interface{}) error {
|
||||
}
|
||||
|
||||
// 根据insert选项获得操作名称
|
||||
func getInsertOperationByOption(option uint8) string {
|
||||
func getInsertOperationByOption(option int) string {
|
||||
oper := "INSERT"
|
||||
switch option {
|
||||
case OPTION_REPLACE:
|
||||
|
||||
@ -17,8 +17,8 @@ import (
|
||||
|
||||
// 数据库链式操作模型对象
|
||||
type Model struct {
|
||||
tx *Tx // 数据库事务对象
|
||||
db DB // 数据库操作对象
|
||||
tx *TX // 数据库事务对象
|
||||
tablesInit string // 初始化Model时的表名称(可以是多个)
|
||||
tables string // 数据库操作表
|
||||
fields string // 操作字段
|
||||
@ -36,9 +36,9 @@ type Model struct {
|
||||
}
|
||||
|
||||
// 链式操作,数据表字段,可支持多个表,以半角逗号连接
|
||||
func (db *dbBase) Table(tables string) (*Model) {
|
||||
return &Model{
|
||||
db : db.db,
|
||||
func (bs *dbBase) Table(tables string) (*Model) {
|
||||
return &Model {
|
||||
db : bs.db,
|
||||
tablesInit : tables,
|
||||
tables : tables,
|
||||
fields : "*",
|
||||
@ -46,12 +46,12 @@ func (db *dbBase) Table(tables string) (*Model) {
|
||||
}
|
||||
|
||||
// 链式操作,数据表字段,可支持多个表,以半角逗号连接
|
||||
func (db *dbBase) From(tables string) (*Model) {
|
||||
return db.Table(tables)
|
||||
func (bs *dbBase) From(tables string) (*Model) {
|
||||
return bs.db.Table(tables)
|
||||
}
|
||||
|
||||
// (事务)链式操作,数据表字段,可支持多个表,以半角逗号连接
|
||||
func (tx *Tx) Table(tables string) (*Model) {
|
||||
func (tx *TX) Table(tables string) (*Model) {
|
||||
return &Model{
|
||||
db : tx.db,
|
||||
tx : tx,
|
||||
@ -61,7 +61,7 @@ func (tx *Tx) Table(tables string) (*Model) {
|
||||
}
|
||||
|
||||
// (事务)链式操作,数据表字段,可支持多个表,以半角逗号连接
|
||||
func (tx *Tx) From(tables string) (*Model) {
|
||||
func (tx *TX) From(tables string) (*Model) {
|
||||
return tx.Table(tables)
|
||||
}
|
||||
|
||||
@ -100,7 +100,7 @@ func (md *Model) Fields(fields string) (*Model) {
|
||||
|
||||
// 链式操作,condition,支持string & gdb.Map
|
||||
func (md *Model) Where(where interface{}, args ...interface{}) (*Model) {
|
||||
md.where = md.db.formatCondition(where)
|
||||
md.where = formatCondition(where)
|
||||
md.whereArgs = append(md.whereArgs, args...)
|
||||
// 支持 Where("uid", 1)这种格式
|
||||
if len(args) == 1 && strings.Index(md.where , "?") < 0 {
|
||||
@ -111,14 +111,14 @@ func (md *Model) Where(where interface{}, args ...interface{}) (*Model) {
|
||||
|
||||
// 链式操作,添加AND条件到Where中
|
||||
func (md *Model) And(where interface{}, args ...interface{}) (*Model) {
|
||||
md.where += " AND " + md.db.formatCondition(where)
|
||||
md.where += " AND " + formatCondition(where)
|
||||
md.whereArgs = append(md.whereArgs, args...)
|
||||
return md
|
||||
}
|
||||
|
||||
// 链式操作,添加OR条件到Where中
|
||||
func (md *Model) Or(where interface{}, args ...interface{}) (*Model) {
|
||||
md.where += " OR " + md.db.formatCondition(where)
|
||||
md.where += " OR " + formatCondition(where)
|
||||
md.whereArgs = append(md.whereArgs, args...)
|
||||
return md
|
||||
}
|
||||
@ -220,9 +220,9 @@ func (md *Model) Replace() (result sql.Result, err error) {
|
||||
}
|
||||
} else if dataMap, ok := md.data.(Map); ok {
|
||||
if md.tx == nil {
|
||||
return md.db.Insert(md.tables, dataMap)
|
||||
return md.db.Replace(md.tables, dataMap)
|
||||
} else {
|
||||
return md.tx.Insert(md.tables, dataMap)
|
||||
return md.tx.Replace(md.tables, dataMap)
|
||||
}
|
||||
}
|
||||
return nil, errors.New("replacing into table with invalid data type")
|
||||
@ -286,9 +286,6 @@ func (md *Model) Delete() (result sql.Result, err error) {
|
||||
}
|
||||
md.clear()
|
||||
}()
|
||||
if md.where == "" {
|
||||
return nil, errors.New("where is required while deleting")
|
||||
}
|
||||
if md.tx == nil {
|
||||
return md.db.Delete(md.tables, md.where, md.whereArgs...)
|
||||
} else {
|
||||
@ -320,13 +317,13 @@ func (md *Model) Cache(time int, name ... string) *Model {
|
||||
|
||||
// 链式操作,select
|
||||
func (md *Model) Select() (Result, error) {
|
||||
defer md.clear()
|
||||
return md.getAll(md.getFormattedSql(), md.whereArgs...)
|
||||
return md.All()
|
||||
}
|
||||
|
||||
// 链式操作,查询所有记录
|
||||
func (md *Model) All() (Result, error) {
|
||||
return md.Select()
|
||||
defer md.clear()
|
||||
return md.getAll(md.getFormattedSql(), md.whereArgs...)
|
||||
}
|
||||
|
||||
// 链式操作,查询单条记录
|
||||
|
||||
@ -28,12 +28,13 @@ type dbMssql struct {
|
||||
}
|
||||
|
||||
// 创建SQL操作对象
|
||||
func (db *dbMssql) open(c *ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if c.Linkinfo != "" {
|
||||
source = c.Linkinfo
|
||||
func (db *dbMssql) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
source := ""
|
||||
if config.Linkinfo != "" {
|
||||
source = config.Linkinfo
|
||||
} else {
|
||||
source = fmt.Sprintf("user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable", c.User, c.Pass, c.Host, c.Port, c.Name)
|
||||
source = fmt.Sprintf("user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable",
|
||||
config.User, config.Pass, config.Host, config.Port, config.Name)
|
||||
}
|
||||
if db, err := sql.Open("sqlserver", source); err == nil {
|
||||
return db, nil
|
||||
|
||||
@ -18,12 +18,12 @@ type dbMysql struct {
|
||||
}
|
||||
|
||||
// 创建SQL操作对象,内部采用了lazy link处理
|
||||
func (db *dbMysql) open (c *ConfigNode) (*sql.DB, error) {
|
||||
func (db *dbMysql) Open (config *ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if c.Linkinfo != "" {
|
||||
source = c.Linkinfo
|
||||
if config.Linkinfo != "" {
|
||||
source = config.Linkinfo
|
||||
} else {
|
||||
source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", c.User, c.Pass, c.Host, c.Port, c.Name)
|
||||
source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?multiStatements=true", config.User, config.Pass, config.Host, config.Port, config.Name)
|
||||
}
|
||||
if db, err := sql.Open("mysql", source); err == nil {
|
||||
return db, nil
|
||||
|
||||
@ -27,12 +27,12 @@ type dbOracle struct {
|
||||
}
|
||||
|
||||
// 创建SQL操作对象
|
||||
func (db *dbOracle) open(c *ConfigNode) (*sql.DB, error) {
|
||||
func (db *dbOracle) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if c.Linkinfo != "" {
|
||||
source = c.Linkinfo
|
||||
if config.Linkinfo != "" {
|
||||
source = config.Linkinfo
|
||||
} else {
|
||||
source = fmt.Sprintf("%s/%s@%s", c.User, c.Pass, c.Name)
|
||||
source = fmt.Sprintf("%s/%s@%s", config.User, config.Pass, config.Name)
|
||||
}
|
||||
if db, err := sql.Open("oci8", source); err == nil {
|
||||
return db, nil
|
||||
|
||||
@ -24,12 +24,12 @@ type dbPgsql struct {
|
||||
}
|
||||
|
||||
// 创建SQL操作对象,内部采用了lazy link处理
|
||||
func (db *dbPgsql) open (c *ConfigNode) (*sql.DB, error) {
|
||||
func (db *dbPgsql) Open (config *ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if c.Linkinfo != "" {
|
||||
source = c.Linkinfo
|
||||
if config.Linkinfo != "" {
|
||||
source = config.Linkinfo
|
||||
} else {
|
||||
source = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s", c.User, c.Pass, c.Host, c.Port, c.Name)
|
||||
source = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s", config.User, config.Pass, config.Host, config.Port, config.Name)
|
||||
}
|
||||
if db, err := sql.Open("postgres", source); err == nil {
|
||||
return db, nil
|
||||
|
||||
@ -22,12 +22,12 @@ type dbSqlite struct {
|
||||
*dbBase
|
||||
}
|
||||
|
||||
func (db *dbSqlite) open(c *ConfigNode) (*sql.DB, error) {
|
||||
func (db *dbSqlite) Open(config *ConfigNode) (*sql.DB, error) {
|
||||
var source string
|
||||
if c.Linkinfo != "" {
|
||||
source = c.Linkinfo
|
||||
if config.Linkinfo != "" {
|
||||
source = config.Linkinfo
|
||||
} else {
|
||||
source = c.Name
|
||||
source = config.Name
|
||||
}
|
||||
if db, err := sql.Open("sqlite3", source); err == nil {
|
||||
return db, nil
|
||||
|
||||
@ -7,15 +7,9 @@
|
||||
package gdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"errors"
|
||||
"strings"
|
||||
"reflect"
|
||||
"database/sql"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"gitee.com/johng/gf/g/util/gregex"
|
||||
_ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
|
||||
"gitee.com/johng/gf/g/container/gvar"
|
||||
)
|
||||
|
||||
// 数据库事务对象
|
||||
@ -37,49 +31,27 @@ func (tx *TX) Rollback() error {
|
||||
|
||||
// (事务)数据库sql查询操作,主要执行查询
|
||||
func (tx *TX) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
return doQuery(tx.db, tx.tx, query, args...)
|
||||
return tx.db.doQuery(tx.tx, query, args...)
|
||||
}
|
||||
|
||||
// (事务)执行一条sql,并返回执行情况,主要用于非查询操作
|
||||
func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return doExec(tx.db, tx.tx, query, args...)
|
||||
return tx.db.doExec(tx.tx, query, args...)
|
||||
}
|
||||
|
||||
// sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作
|
||||
func (tx *TX) Prepare(query string) (*sql.Stmt, error) {
|
||||
return tx.db.doPrepare(tx.tx, query)
|
||||
}
|
||||
|
||||
// 数据库查询,获取查询结果集,以列表结构返回
|
||||
func (tx *TX) GetAll(query string, args ...interface{}) (Result, error) {
|
||||
// 执行sql
|
||||
rows, err := tx.Query(query, args ...)
|
||||
if err != nil || rows == nil {
|
||||
return nil, err
|
||||
}
|
||||
// 列名称列表
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 返回结构组装
|
||||
values := make([]sql.RawBytes, len(columns))
|
||||
scanArgs := make([]interface{}, len(values))
|
||||
records := make(Result, 0)
|
||||
for i := range values {
|
||||
scanArgs[i] = &values[i]
|
||||
}
|
||||
for rows.Next() {
|
||||
err = rows.Scan(scanArgs...)
|
||||
if err != nil {
|
||||
return records, err
|
||||
}
|
||||
row := make(Record)
|
||||
// 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址
|
||||
for i, col := range values {
|
||||
v := make([]byte, len(col))
|
||||
copy(v, col)
|
||||
row[columns[i]] = gvar.New(v, false)
|
||||
}
|
||||
//fmt.Printf("%p\n", row["typeid"])
|
||||
records = append(records, row)
|
||||
}
|
||||
return records, nil
|
||||
defer rows.Close()
|
||||
return rowsToResult(rows)
|
||||
}
|
||||
|
||||
// 数据库查询,获取查询结果记录,以关联数组结构返回
|
||||
@ -103,7 +75,6 @@ func (tx *TX) GetStruct(obj interface{}, query string, args ...interface{}) erro
|
||||
return one.ToStruct(obj)
|
||||
}
|
||||
|
||||
|
||||
// 数据库查询,获取查询字段值
|
||||
func (tx *TX) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
one, err := tx.GetOne(query, args ...)
|
||||
@ -118,185 +89,54 @@ func (tx *TX) GetValue(query string, args ...interface{}) (Value, error) {
|
||||
|
||||
// 数据库查询,获取查询数量
|
||||
func (tx *TX) GetCount(query string, args ...interface{}) (int, error) {
|
||||
val, err := tx.GetValue(query, args ...)
|
||||
if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) {
|
||||
query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
|
||||
}
|
||||
value, err := tx.GetValue(query, args ...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return gconv.Int(val), nil
|
||||
}
|
||||
|
||||
// 数据表查询,其中tables可以是多个联表查询语句,这种查询方式较复杂,建议使用链式操作
|
||||
func (tx *TX) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) {
|
||||
s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables)
|
||||
if condition != nil {
|
||||
s += fmt.Sprintf("WHERE %s ", tx.db.formatCondition(condition))
|
||||
}
|
||||
if len(groupBy) > 0 {
|
||||
s += fmt.Sprintf("GROUP BY %s ", groupBy)
|
||||
}
|
||||
if len(orderBy) > 0 {
|
||||
s += fmt.Sprintf("ORDER BY %s ", orderBy)
|
||||
}
|
||||
if limit > 0 {
|
||||
s += fmt.Sprintf("LIMIT %d,%d ", first, limit)
|
||||
}
|
||||
return tx.GetAll(s, args ... )
|
||||
}
|
||||
|
||||
// sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作
|
||||
func (tx *TX) Prepare(query string) (*sql.Stmt, error) {
|
||||
return tx.tx.Prepare(query)
|
||||
}
|
||||
|
||||
// insert、replace, save, ignore操作
|
||||
// 0: insert: 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
|
||||
// 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
// 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
// 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做
|
||||
func (tx *TX) insert(table string, data Map, option uint8) (sql.Result, error) {
|
||||
var keys []string
|
||||
var values []string
|
||||
var params []interface{}
|
||||
for k, v := range data {
|
||||
keys = append(keys, tx.db.charl + k + tx.db.charr)
|
||||
values = append(values, "?")
|
||||
params = append(params, v)
|
||||
}
|
||||
operation := tx.db.getInsertOperationByOption(option)
|
||||
updatestr := ""
|
||||
if option == OPTION_SAVE {
|
||||
var updates []string
|
||||
for k, _ := range data {
|
||||
updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s)", tx.db.charl, k, tx.db.charr, k))
|
||||
}
|
||||
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
|
||||
}
|
||||
return tx.Exec(
|
||||
fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s",
|
||||
operation, table, strings.Join(keys, ","),
|
||||
strings.Join(values, ","),
|
||||
updatestr),
|
||||
params...
|
||||
)
|
||||
return value.Int(), nil
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
|
||||
func (tx *TX) Insert(table string, data Map) (sql.Result, error) {
|
||||
return tx.insert(table, data, OPTION_INSERT)
|
||||
return tx.db.doInsert(tx.tx, table, data, OPTION_INSERT)
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
func (tx *TX) Replace(table string, data Map) (sql.Result, error) {
|
||||
return tx.insert(table, data, OPTION_REPLACE)
|
||||
return tx.db.doInsert(tx.tx, table, data, OPTION_REPLACE)
|
||||
}
|
||||
|
||||
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
func (tx *TX) Save(table string, data Map) (sql.Result, error) {
|
||||
return tx.insert(table, data, OPTION_SAVE)
|
||||
}
|
||||
|
||||
// 批量写入数据
|
||||
func (tx *TX) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) {
|
||||
var keys []string
|
||||
var values []string
|
||||
var bvalues []string
|
||||
var params []interface{}
|
||||
var result sql.Result
|
||||
var size = len(list)
|
||||
// 判断长度
|
||||
if size < 1 {
|
||||
return result, errors.New("empty data list")
|
||||
}
|
||||
// 首先获取字段名称及记录长度
|
||||
for k, _ := range list[0] {
|
||||
keys = append(keys, k)
|
||||
values = append(values, "?")
|
||||
}
|
||||
keyStr := tx.db.charl + strings.Join(keys, tx.db.charl + "," + tx.db.charr) + tx.db.charr
|
||||
valueHolderStr := "(" + strings.Join(values, ",") + ")"
|
||||
// 操作判断
|
||||
operation := tx.db.getInsertOperationByOption(option)
|
||||
updatestr := ""
|
||||
if option == OPTION_SAVE {
|
||||
var updates []string
|
||||
for _, k := range keys {
|
||||
updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s)", tx.db.charl, k, tx.db.charr, k))
|
||||
}
|
||||
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
|
||||
}
|
||||
// 构造批量写入数据格式(注意map的遍历是无序的)
|
||||
for i := 0; i < size; i++ {
|
||||
for _, k := range keys {
|
||||
params = append(params, list[i][k])
|
||||
}
|
||||
bvalues = append(bvalues, valueHolderStr)
|
||||
if len(bvalues) == batch {
|
||||
r, err := tx.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
|
||||
operation, table, keyStr, strings.Join(bvalues, ","),
|
||||
updatestr),
|
||||
params...)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
result = r
|
||||
params = params[:0]
|
||||
bvalues = bvalues[:0]
|
||||
}
|
||||
}
|
||||
// 处理最后不构成指定批量的数据
|
||||
if len(bvalues) > 0 {
|
||||
r, err := tx.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s",
|
||||
operation, table, keyStr, strings.Join(bvalues, ","),
|
||||
updatestr),
|
||||
params...)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
result = r
|
||||
}
|
||||
return result, nil
|
||||
return tx.db.doInsert(tx.tx, table, data, OPTION_SAVE)
|
||||
}
|
||||
|
||||
// CURD操作:批量数据指定批次量写入
|
||||
func (tx *TX) BatchInsert(table string, list List, batch int) (sql.Result, error) {
|
||||
return tx.batchInsert(table, list, batch, OPTION_INSERT)
|
||||
return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_INSERT)
|
||||
}
|
||||
|
||||
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
|
||||
func (tx *TX) BatchReplace(table string, list List, batch int) (sql.Result, error) {
|
||||
return tx.batchInsert(table, list, batch, OPTION_REPLACE)
|
||||
return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_REPLACE)
|
||||
}
|
||||
|
||||
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
|
||||
func (tx *TX) BatchSave(table string, list List, batch int) (sql.Result, error) {
|
||||
return tx.batchInsert(table, list, batch, OPTION_SAVE)
|
||||
return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_SAVE)
|
||||
}
|
||||
|
||||
// CURD操作:数据更新,统一采用sql预处理
|
||||
// data参数支持字符串或者关联数组类型,内部会自行做判断处理
|
||||
func (tx *TX) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
var params []interface{}
|
||||
var updates string
|
||||
refValue := reflect.ValueOf(data)
|
||||
if refValue.Kind() == reflect.Map {
|
||||
var fields []string
|
||||
keys := refValue.MapKeys()
|
||||
for _, k := range keys {
|
||||
fields = append(fields, fmt.Sprintf("%s%s%s=?", tx.db.charl, k, tx.db.charr))
|
||||
params = append(params, gconv.String(refValue.MapIndex(k).Interface()))
|
||||
updates = strings.Join(fields, ",")
|
||||
}
|
||||
} else {
|
||||
updates = gconv.String(data)
|
||||
}
|
||||
for _, v := range args {
|
||||
params = append(params, gconv.String(v))
|
||||
}
|
||||
return tx.Exec(fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, tx.db.formatCondition(condition)), params...)
|
||||
return tx.db.doUpdate(tx.tx, table, data, condition, args ...)
|
||||
}
|
||||
|
||||
// CURD操作:删除数据
|
||||
func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
return tx.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", table, tx.db.formatCondition(condition)), args...)
|
||||
return tx.db.doDelete(tx.tx, table, condition, args ...)
|
||||
}
|
||||
|
||||
|
||||
215
g/database/gdb/gdb_unit_1_test.go
Normal file
215
g/database/gdb/gdb_unit_1_test.go
Normal file
@ -0,0 +1,215 @@
|
||||
package gdb_test
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g"
|
||||
"gitee.com/johng/gf/g/database/gdb"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"gitee.com/johng/gf/g/util/gtest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
// 数据库对象/接口
|
||||
db gdb.DB
|
||||
)
|
||||
|
||||
// 初始化连接参数。
|
||||
// 测试前需要修改连接参数。
|
||||
func init() {
|
||||
gdb.AddDefaultConfigNode(gdb.ConfigNode{
|
||||
Host: "127.0.0.1",
|
||||
Port: "3306",
|
||||
User: "root",
|
||||
Pass: "12345678",
|
||||
Name: "test",
|
||||
Type: "mysql",
|
||||
Role: "master",
|
||||
Charset: "utf8",
|
||||
Priority: 1,
|
||||
})
|
||||
if r, err := gdb.New(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
db = r
|
||||
}
|
||||
// 准备测试数据结构
|
||||
if _, err := db.Exec("DROP TABLE `user`"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := db.Exec(`
|
||||
CREATE TABLE user (
|
||||
id int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT '用户ID',
|
||||
passport varchar(45) NOT NULL COMMENT '账号',
|
||||
password char(32) NOT NULL COMMENT '密码',
|
||||
nickname varchar(45) NOT NULL COMMENT '昵称',
|
||||
create_time timestamp NOT NULL COMMENT '创建时间/注册时间',
|
||||
PRIMARY KEY (id)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
|
||||
`); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Query(t *testing.T) {
|
||||
if _, err := db.Query("SELECT ?", 1); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := db.Query("ERROR"); err == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Exec(t *testing.T) {
|
||||
if _, err := db.Exec("SELECT ?", 1); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := db.Exec("ERROR"); err == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Prepare(t *testing.T) {
|
||||
st, err := db.Prepare("SELECT 100")
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
rows, err := st.Query()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
array, err := rows.Columns()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(array[0], "100")
|
||||
if err := rows.Close(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Insert(t *testing.T) {
|
||||
if _, err := db.Insert("user", g.Map{
|
||||
"id" : 1,
|
||||
"passport" : "t1",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T1",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_BatchInsert(t *testing.T) {
|
||||
if _, err := db.BatchInsert("user", g.List {
|
||||
{
|
||||
"id" : 2,
|
||||
"passport" : "t2",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T2",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
{
|
||||
"id" : 3,
|
||||
"passport" : "t3",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T3",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
}, 10); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Save(t *testing.T) {
|
||||
if _, err := db.Save("user", g.Map{
|
||||
"id" : 1,
|
||||
"passport" : "t1",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T11",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Replace(t *testing.T) {
|
||||
if _, err := db.Save("user", g.Map{
|
||||
"id" : 1,
|
||||
"passport" : "t1",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T111",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Update(t *testing.T) {
|
||||
if result, err := db.Update("user", "create_time='2010-10-10 00:00:01'", "id=3"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_GetAll(t *testing.T) {
|
||||
if result, err := db.GetAll("SELECT * FROM user WHERE id=?", 1); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(len(result), 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_GetOne(t *testing.T) {
|
||||
if record, err := db.GetOne("SELECT * FROM user WHERE passport=?", "t1"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
if record == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
gtest.Assert(record["nickname"].String(), "T111")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_GetValue(t *testing.T) {
|
||||
if value, err := db.GetValue("SELECT id FROM user WHERE passport=?", "t3"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(value.Int(), 3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_GetCount(t *testing.T) {
|
||||
if count, err := db.GetCount("SELECT * FROM user"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(count, 3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_GetStruct(t *testing.T) {
|
||||
type User struct {
|
||||
Id int
|
||||
Passport string
|
||||
Password string
|
||||
NickName string
|
||||
CreateTime gtime.Time
|
||||
}
|
||||
user := new(User)
|
||||
if err := db.GetStruct(user, "SELECT * FROM user WHERE id=?", 3); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(user.CreateTime.String(), "2010-10-10 00:00:01")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDbBase_Delete(t *testing.T) {
|
||||
if result, err := db.Delete("user", nil); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 3)
|
||||
}
|
||||
}
|
||||
|
||||
177
g/database/gdb/gdb_unit_2_test.go
Normal file
177
g/database/gdb/gdb_unit_2_test.go
Normal file
@ -0,0 +1,177 @@
|
||||
package gdb_test
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"gitee.com/johng/gf/g/util/gtest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestModel_Insert(t *testing.T) {
|
||||
result, err := db.Table("user").Data(g.Map{
|
||||
"id" : 1,
|
||||
"passport" : "t1",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T1",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}).Insert()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.LastInsertId()
|
||||
gtest.Assert(n, 1)
|
||||
}
|
||||
|
||||
func TestModel_Batch(t *testing.T) {
|
||||
result, err := db.Table("user").Data(g.List{
|
||||
{
|
||||
"id" : 2,
|
||||
"passport" : "t2",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T2",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
{
|
||||
"id" : 3,
|
||||
"passport" : "t3",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T3",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
}).Batch(10).Insert()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 2)
|
||||
}
|
||||
|
||||
func TestModel_Replace(t *testing.T) {
|
||||
result, err := db.Table("user").Data(g.Map{
|
||||
"id" : 1,
|
||||
"passport" : "t11",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T11",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}).Replace()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 2)
|
||||
}
|
||||
|
||||
func TestModel_Save(t *testing.T) {
|
||||
result, err := db.Table("user").Data(g.Map{
|
||||
"id" : 1,
|
||||
"passport" : "t111",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T111",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}).Save()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 2)
|
||||
}
|
||||
|
||||
func TestModel_Update(t *testing.T) {
|
||||
result, err := db.Table("user").Data("passport", "t22").Where("passport=?", "t2").Update()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 1)
|
||||
}
|
||||
|
||||
func TestModel_All(t *testing.T) {
|
||||
result, err := db.Table("user").All()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(len(result), 3)
|
||||
}
|
||||
|
||||
func TestModel_One(t *testing.T) {
|
||||
record, err := db.Table("user").Where("id", 1).One()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if record == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
gtest.Assert(record["nickname"].String(), "T111")
|
||||
}
|
||||
|
||||
func TestModel_Value(t *testing.T) {
|
||||
value, err := db.Table("user").Fields("nickname").Where("id", 1).Value()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if value == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
gtest.Assert(value.String(), "T111")
|
||||
}
|
||||
|
||||
func TestModel_Count(t *testing.T) {
|
||||
count, err := db.Table("user").Count()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(count, 3)
|
||||
}
|
||||
|
||||
func TestModel_Select(t *testing.T) {
|
||||
result, err := db.Table("user").Select()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(len(result), 3)
|
||||
}
|
||||
|
||||
func TestModel_Struct(t *testing.T) {
|
||||
type User struct {
|
||||
Id int
|
||||
Passport string
|
||||
Password string
|
||||
NickName string
|
||||
CreateTime gtime.Time
|
||||
}
|
||||
user := new(User)
|
||||
err := db.Table("user").Where("id=1").Struct(user)
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(user.NickName, "T111")
|
||||
}
|
||||
|
||||
func TestModel_OrderBy(t *testing.T) {
|
||||
result, err := db.Table("user").OrderBy("id DESC").Select()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(len(result), 3)
|
||||
gtest.Assert(result[0]["nickname"].String(), "T3")
|
||||
}
|
||||
|
||||
func TestModel_GroupBy(t *testing.T) {
|
||||
result, err := db.Table("user").GroupBy("id").Select()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(len(result), 3)
|
||||
gtest.Assert(result[0]["nickname"].String(), "T111")
|
||||
}
|
||||
|
||||
func TestModel_Delete(t *testing.T) {
|
||||
result, err := db.Table("user").Delete()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 3)
|
||||
}
|
||||
|
||||
|
||||
372
g/database/gdb/gdb_unit_3_test.go
Normal file
372
g/database/gdb/gdb_unit_3_test.go
Normal file
@ -0,0 +1,372 @@
|
||||
package gdb_test
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g"
|
||||
"gitee.com/johng/gf/g/os/gtime"
|
||||
"gitee.com/johng/gf/g/util/gtest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTX_Query(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if rows, err := tx.Query("SELECT ?", 1); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
rows.Close()
|
||||
}
|
||||
if _, err := tx.Query("ERROR"); err == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Exec(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.Exec("SELECT ?", 1); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.Exec("ERROR"); err == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Commit(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Rollback(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Rollback(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Prepare(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
st, err := tx.Prepare("SELECT 100")
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
rows, err := st.Query()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
array, err := rows.Columns()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
gtest.Assert(array[0], "100")
|
||||
if err := rows.Close(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Insert(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.Insert("user", g.Map {
|
||||
"id" : 1,
|
||||
"passport" : "t1",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T1",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if n, err := db.Table("user").Count(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(n, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_BatchInsert(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.BatchInsert("user", g.List {
|
||||
{
|
||||
"id" : 2,
|
||||
"passport" : "t",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T2",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
{
|
||||
"id" : 3,
|
||||
"passport" : "t3",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T3",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
}, 10); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if n, err := db.Table("user").Count(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(n, 3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_BatchReplace(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.BatchReplace("user", g.List {
|
||||
{
|
||||
"id" : 2,
|
||||
"passport" : "t2",
|
||||
"password" : "p2",
|
||||
"nickname" : "T2",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
{
|
||||
"id" : 4,
|
||||
"passport" : "t4",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T4",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
}, 10); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
// 数据数量
|
||||
if n, err := db.Table("user").Count(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(n, 4)
|
||||
}
|
||||
// 检查replace后的数值
|
||||
if value, err := db.Table("user").Fields("password").Where("id", 2).Value(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(value.String(), "p2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_BatchSave(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.BatchSave("user", g.List {
|
||||
{
|
||||
"id" : 4,
|
||||
"passport" : "t4",
|
||||
"password" : "p4",
|
||||
"nickname" : "T4",
|
||||
"create_time" : gtime.Now().String(),
|
||||
},
|
||||
}, 10); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
// 数据数量
|
||||
if n, err := db.Table("user").Count(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(n, 4)
|
||||
}
|
||||
// 检查replace后的数值
|
||||
if value, err := db.Table("user").Fields("password").Where("id", 4).Value(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(value.String(), "p4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Replace(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.Replace("user", g.Map {
|
||||
"id" : 1,
|
||||
"passport" : "t11",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T11",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Rollback(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if value, err := db.Table("user").Fields("nickname").Where("id", 1).Value(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(value.String(), "T1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Save(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.Save("user", g.Map {
|
||||
"id" : 1,
|
||||
"passport" : "t11",
|
||||
"password" : "25d55ad283aa400af464c76d713c07ad",
|
||||
"nickname" : "T11",
|
||||
"create_time" : gtime.Now().String(),
|
||||
}); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if value, err := db.Table("user").Fields("nickname").Where("id", 1).Value(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(value.String(), "T11")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_GetAll(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if result, err := tx.GetAll("SELECT * FROM user WHERE id=?", 1); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(len(result), 1)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_GetOne(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if record, err := tx.GetOne("SELECT * FROM user WHERE passport=?", "t2"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
if record == nil {
|
||||
gtest.Fatal("FAIL")
|
||||
}
|
||||
gtest.Assert(record["nickname"].String(), "T2")
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_GetValue(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if value, err := tx.GetValue("SELECT id FROM user WHERE passport=?", "t3"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(value.Int(), 3)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_GetCount(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if count, err := tx.GetCount("SELECT * FROM user"); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(count, 4)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_GetStruct(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
type User struct {
|
||||
Id int
|
||||
Passport string
|
||||
Password string
|
||||
NickName string
|
||||
CreateTime gtime.Time
|
||||
}
|
||||
user := new(User)
|
||||
if err := tx.GetStruct(user, "SELECT * FROM user WHERE id=?", 1); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(user.NickName, "T11")
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTX_Delete(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if _, err := tx.Delete("user", nil); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
if n, err := db.Table("user").Count(); err != nil {
|
||||
gtest.Fatal(err)
|
||||
} else {
|
||||
gtest.Assert(n, 0)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -120,7 +120,7 @@ func Config(file...string) *gcfg.Config {
|
||||
}
|
||||
|
||||
// 数据库操作对象,使用了连接池
|
||||
func Database(name...string) *gdb.Db {
|
||||
func Database(name...string) gdb.DB {
|
||||
config := Config()
|
||||
group := gdb.DEFAULT_GROUP_NAME
|
||||
if len(name) > 0 {
|
||||
@ -195,7 +195,7 @@ func Database(name...string) *gdb.Db {
|
||||
return nil
|
||||
})
|
||||
if db != nil {
|
||||
return db.(*gdb.Db)
|
||||
return db.(gdb.DB)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -44,12 +44,12 @@ func Config(file...string) *gcfg.Config {
|
||||
}
|
||||
|
||||
// 数据库操作对象,使用了连接池
|
||||
func Database(name...string) *gdb.Db {
|
||||
func Database(name...string) gdb.DB {
|
||||
return gins.Database(name...)
|
||||
}
|
||||
|
||||
// (别名)Database
|
||||
func DB(name...string) *gdb.Db {
|
||||
func DB(name...string) gdb.DB {
|
||||
return gins.Database(name...)
|
||||
}
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ type Logger struct {
|
||||
file *gtype.String // 日志文件名称格式
|
||||
level *gtype.Int // 日志输出等级
|
||||
btSkip *gtype.Int // 错误产生时的backtrace回调信息skip条数
|
||||
btEnabled *gtype.Bool // 是否当打印错误时同时开启backtrace打印
|
||||
btStatus *gtype.Int // 是否当打印错误时同时开启backtrace打印(默认-1,表示默认打印逻辑 - 错误才打印)
|
||||
printHeader *gtype.Bool // 是否不打印前缀信息(时间,级别等)
|
||||
alsoStdPrint *gtype.Bool // 控制台打印开关,当输出到文件/自定义输出时也同时打印到终端
|
||||
}
|
||||
@ -65,7 +65,7 @@ func New() *Logger {
|
||||
file : gtype.NewString(gDEFAULT_FILE_FORMAT),
|
||||
level : gtype.NewInt(defaultLevel.Val()),
|
||||
btSkip : gtype.NewInt(),
|
||||
btEnabled : gtype.NewBool(true),
|
||||
btStatus : gtype.NewInt(-1),
|
||||
printHeader : gtype.NewBool(true),
|
||||
alsoStdPrint : gtype.NewBool(true),
|
||||
}
|
||||
@ -80,7 +80,7 @@ func (l *Logger) Clone() *Logger {
|
||||
file : l.file.Clone(),
|
||||
level : l.level.Clone(),
|
||||
btSkip : l.btSkip.Clone(),
|
||||
btEnabled : l.btEnabled.Clone(),
|
||||
btStatus : l.btStatus.Clone(),
|
||||
printHeader : l.printHeader.Clone(),
|
||||
alsoStdPrint : l.alsoStdPrint.Clone(),
|
||||
}
|
||||
@ -106,7 +106,12 @@ func (l *Logger) SetDebug(debug bool) {
|
||||
}
|
||||
|
||||
func (l *Logger) SetBacktrace(enabled bool) {
|
||||
l.btEnabled.Set(enabled)
|
||||
if enabled {
|
||||
l.btStatus.Set(1)
|
||||
} else {
|
||||
l.btStatus.Set(0)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// 设置BacktraceSkip
|
||||
@ -214,27 +219,37 @@ func (l *Logger) doStdLockPrint(std io.Writer, s string) {
|
||||
|
||||
// 核心打印数据方法(标准输出)
|
||||
func (l *Logger) stdPrint(s string) {
|
||||
if l.btStatus.Val() == 1 {
|
||||
s = l.appendBacktrace(s)
|
||||
}
|
||||
l.print(os.Stdout, s)
|
||||
}
|
||||
|
||||
// 核心打印数据方法(标准错误)
|
||||
func (l *Logger) errPrint(s string) {
|
||||
// 记录调用回溯信息
|
||||
if l.btEnabled.Val() {
|
||||
tracestr := l.GetBacktrace()
|
||||
if tracestr != "" {
|
||||
backtrace := "Backtrace:" + ln + tracestr
|
||||
if s[len(s) - 1] == byte('\n') {
|
||||
s = s + backtrace + ln
|
||||
} else {
|
||||
s = s + ln + backtrace + ln
|
||||
}
|
||||
}
|
||||
status := l.btStatus.Val()
|
||||
if status == -1 || status == 1 {
|
||||
s = l.appendBacktrace(s)
|
||||
}
|
||||
// 防止串日志情况,这里不使用stderr,而是使用stdout
|
||||
l.print(os.Stdout, s)
|
||||
}
|
||||
|
||||
// 输出内容中添加回溯信息
|
||||
func (l *Logger) appendBacktrace(s string) string {
|
||||
trace := l.GetBacktrace()
|
||||
if trace != "" {
|
||||
backtrace := "Backtrace:" + ln + trace
|
||||
if s[len(s) - 1] == byte('\n') {
|
||||
s = s + backtrace + ln
|
||||
} else {
|
||||
s = s + ln + backtrace + ln
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// 直接打印回溯信息,参数skip表示调用端往上多少级开始回溯
|
||||
func (l *Logger) PrintBacktrace(skip...int) {
|
||||
l.Println(l.GetBacktrace(skip...))
|
||||
|
||||
@ -157,12 +157,10 @@ func bindVarToStruct(elem reflect.Value, name string, value interface{}) (err er
|
||||
structFieldValue := elem.FieldByName(name)
|
||||
// 键名与对象属性匹配检测,map中如果有struct不存在的属性,那么不做处理,直接return
|
||||
if !structFieldValue.IsValid() {
|
||||
//return errors.New(fmt.Sprintf(`invalid struct attribute of name "%s"`, name))
|
||||
return nil
|
||||
}
|
||||
// CanSet的属性必须为公开属性(首字母大写)
|
||||
if !structFieldValue.CanSet() {
|
||||
//return errors.New(fmt.Sprintf(`struct attribute of name "%s" cannot be set`, name))
|
||||
return nil
|
||||
}
|
||||
// 必须将value转换为struct属性的数据类型,这里必须用到gconv包
|
||||
@ -181,12 +179,10 @@ func bindVarToStructByIndex(elem reflect.Value, index int, value interface{}) (e
|
||||
structFieldValue := elem.FieldByIndex([]int{index})
|
||||
// 键名与对象属性匹配检测
|
||||
if !structFieldValue.IsValid() {
|
||||
//return errors.New(fmt.Sprintf("invalid struct attribute at index %d", index))
|
||||
return nil
|
||||
}
|
||||
// CanSet的属性必须为公开属性(首字母大写)
|
||||
if !structFieldValue.CanSet() {
|
||||
//return errors.New(fmt.Sprintf("struct attribute cannot be set at index %d", index))
|
||||
return nil
|
||||
}
|
||||
// 必须将value转换为struct属性的数据类型,这里必须用到gconv包
|
||||
|
||||
30
g/util/gtest/gtest.go
Normal file
30
g/util/gtest/gtest.go
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright 2018 gf Author(https://gitee.com/johng/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://gitee.com/johng/gf.
|
||||
|
||||
// Package gtest provides useful test utils.
|
||||
// 测试模块.
|
||||
package gtest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/os/glog"
|
||||
"gitee.com/johng/gf/g/util/gconv"
|
||||
"os"
|
||||
)
|
||||
|
||||
// 断言判断
|
||||
func Assert(value, expect interface{}) {
|
||||
if gconv.String(value) != gconv.String(expect) {
|
||||
glog.Backtrace(true, 1).Printfln(`[ASSERT] VALUE: %v, EXPECT: %v`, value, expect)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// 提示错误并退出
|
||||
func Fatal(message...interface{}) {
|
||||
glog.Backtrace(true, 1).Println(`[FATAL] `, fmt.Sprint(message...))
|
||||
os.Exit(1)
|
||||
}
|
||||
@ -9,7 +9,7 @@ import (
|
||||
|
||||
// 本文件用于gf框架的mysql数据库操作示例,不作为单元测试使用
|
||||
|
||||
var db *gdb.Db
|
||||
var db gdb.DB
|
||||
|
||||
// 初始化配置及创建数据库
|
||||
func init () {
|
||||
@ -17,7 +17,7 @@ func init () {
|
||||
Host : "127.0.0.1",
|
||||
Port : "3306",
|
||||
User : "root",
|
||||
Pass : "8692651",
|
||||
Pass : "12345678",
|
||||
Name : "test",
|
||||
Type : "mysql",
|
||||
Role : "master",
|
||||
|
||||
@ -1,28 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"gitee.com/johng/gf/g/database/gdb"
|
||||
"gitee.com/johng/gf/g"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
gdb.AddDefaultConfigNode(gdb.ConfigNode {
|
||||
Host : "127.0.0.1",
|
||||
Port : "3306",
|
||||
User : "root",
|
||||
Pass : "12345678",
|
||||
Name : "test",
|
||||
Type : "mysql",
|
||||
Role : "master",
|
||||
Charset : "utf8",
|
||||
MaxIdleConnCount : 10,
|
||||
MaxOpenConnCount : 10,
|
||||
MaxConnLifetime : 10,
|
||||
})
|
||||
db, err := gdb.New()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
db := g.DB()
|
||||
db.SetMaxIdleConns(10)
|
||||
db.SetMaxOpenConns(10)
|
||||
db.SetConnMaxLifetime(10)
|
||||
|
||||
// 开启调试模式,以便于记录所有执行的SQL
|
||||
db.SetDebug(true)
|
||||
|
||||
|
||||
@ -2,24 +2,11 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/database/gdb"
|
||||
"gitee.com/johng/gf/g"
|
||||
)
|
||||
|
||||
func main() {
|
||||
gdb.AddDefaultConfigNode(gdb.ConfigNode {
|
||||
Host : "192.168.1.11",
|
||||
Port : "3306",
|
||||
User : "root",
|
||||
Pass : "8692651",
|
||||
Name : "test",
|
||||
Type : "mysql",
|
||||
Role : "master",
|
||||
Charset : "utf8",
|
||||
})
|
||||
db, err := gdb.New()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
db := g.DB()
|
||||
// 开启调试模式,以便于记录所有执行的SQL
|
||||
db.SetDebug(true)
|
||||
|
||||
|
||||
@ -1,41 +0,0 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var s = "/name/john///."
|
||||
var c = "./"
|
||||
|
||||
func t1(s string) string {
|
||||
if len(s) == 0 {
|
||||
return s
|
||||
}
|
||||
for _, cut := range c {
|
||||
for s[len(s) - 1] == uint8(cut) {
|
||||
s = s[:len(s) - 1]
|
||||
if len(s) == 0 {
|
||||
return s
|
||||
}
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func t2(s string) string {
|
||||
return strings.TrimRight(s, c)
|
||||
}
|
||||
|
||||
func Benchmark_t1(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
t1(s)
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_t2(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
t2(s)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,25 +1,15 @@
|
||||
package main
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"gitee.com/johng/gf/g/util/gregex"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Uid int
|
||||
}
|
||||
|
||||
func New() *User {
|
||||
return &User{
|
||||
100,
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) Clear() {
|
||||
user = New()
|
||||
}
|
||||
|
||||
func main() {
|
||||
user := New()
|
||||
user.Uid = 10000
|
||||
fmt.Println(user)
|
||||
user.Clear()
|
||||
fmt.Println(user)
|
||||
query := "select * from user"
|
||||
q, err := gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query)
|
||||
fmt.Println(err)
|
||||
fmt.Println(q)
|
||||
}
|
||||
@ -4,6 +4,7 @@ go:
|
||||
- 1.8.x
|
||||
- 1.9.x
|
||||
- 1.10.x
|
||||
- 1.11.x
|
||||
- master
|
||||
|
||||
before_install:
|
||||
|
||||
@ -35,6 +35,7 @@ Hanno Braun <mail at hannobraun.com>
|
||||
Henri Yandell <flamefew at gmail.com>
|
||||
Hirotaka Yamamoto <ymmt2005 at gmail.com>
|
||||
ICHINOSE Shogo <shogo82148 at gmail.com>
|
||||
Ilia Cimpoes <ichimpoesh at gmail.com>
|
||||
INADA Naoki <songofacandy at gmail.com>
|
||||
Jacek Szwec <szwec.jacek at gmail.com>
|
||||
James Harr <james.harr at gmail.com>
|
||||
@ -72,7 +73,9 @@ Shuode Li <elemount at qq.com>
|
||||
Soroush Pour <me at soroushjp.com>
|
||||
Stan Putrya <root.vagner at gmail.com>
|
||||
Stanley Gunawan <gunawan.stanley at gmail.com>
|
||||
Steven Hartland <steven.hartland at multiplay.co.uk>
|
||||
Thomas Wodarek <wodarekwebpage at gmail.com>
|
||||
Tom Jenkinson <tom at tjenkinson.me>
|
||||
Xiangyu Hu <xiangyu.hu at outlook.com>
|
||||
Xiaobing Jiang <s7v7nislands at gmail.com>
|
||||
Xiuming Chen <cc at cxm.cc>
|
||||
@ -88,3 +91,4 @@ Keybase Inc.
|
||||
Percona LLC
|
||||
Pivotal Inc.
|
||||
Stripe Inc.
|
||||
Multiplay Ltd.
|
||||
|
||||
@ -58,7 +58,7 @@ _Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface.
|
||||
Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`:
|
||||
```go
|
||||
import "database/sql"
|
||||
import _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
|
||||
import _ "github.com/go-sql-driver/mysql"
|
||||
|
||||
db, err := sql.Open("mysql", "user:password@/dbname")
|
||||
```
|
||||
@ -328,11 +328,11 @@ Timeout for establishing connections, aka dial timeout. The value must be a deci
|
||||
|
||||
```
|
||||
Type: bool / string
|
||||
Valid Values: true, false, skip-verify, <name>
|
||||
Valid Values: true, false, skip-verify, preferred, <name>
|
||||
Default: false
|
||||
```
|
||||
|
||||
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
|
||||
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use `preferred` to use TLS only when advertised by the server, this is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
|
||||
|
||||
|
||||
##### `writeTimeout`
|
||||
@ -431,7 +431,7 @@ See [context support in the database/sql package](https://golang.org/doc/go1.8#d
|
||||
### `LOAD DATA LOCAL INFILE` support
|
||||
For this feature you need direct access to the package. Therefore you must change the import path (no `_`):
|
||||
```go
|
||||
import "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
|
||||
import "github.com/go-sql-driver/mysql"
|
||||
```
|
||||
|
||||
Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)).
|
||||
|
||||
@ -234,64 +234,64 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return mc.writeAuthSwitchPacket(enc, false)
|
||||
return mc.writeAuthSwitchPacket(enc)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) {
|
||||
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
|
||||
switch plugin {
|
||||
case "caching_sha2_password":
|
||||
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
|
||||
return authResp, false, nil
|
||||
return authResp, nil
|
||||
|
||||
case "mysql_old_password":
|
||||
if !mc.cfg.AllowOldPasswords {
|
||||
return nil, false, ErrOldPassword
|
||||
return nil, ErrOldPassword
|
||||
}
|
||||
// Note: there are edge cases where this should work but doesn't;
|
||||
// this is currently "wontfix":
|
||||
// https://github.com/go-sql-driver/mysql/issues/184
|
||||
authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd)
|
||||
return authResp, true, nil
|
||||
authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0)
|
||||
return authResp, nil
|
||||
|
||||
case "mysql_clear_password":
|
||||
if !mc.cfg.AllowCleartextPasswords {
|
||||
return nil, false, ErrCleartextPassword
|
||||
return nil, ErrCleartextPassword
|
||||
}
|
||||
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
|
||||
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
|
||||
return []byte(mc.cfg.Passwd), true, nil
|
||||
return append([]byte(mc.cfg.Passwd), 0), nil
|
||||
|
||||
case "mysql_native_password":
|
||||
if !mc.cfg.AllowNativePasswords {
|
||||
return nil, false, ErrNativePassword
|
||||
return nil, ErrNativePassword
|
||||
}
|
||||
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
|
||||
// Native password authentication only need and will need 20-byte challenge.
|
||||
authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
|
||||
return authResp, false, nil
|
||||
return authResp, nil
|
||||
|
||||
case "sha256_password":
|
||||
if len(mc.cfg.Passwd) == 0 {
|
||||
return nil, true, nil
|
||||
return []byte{0}, nil
|
||||
}
|
||||
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
|
||||
// write cleartext auth packet
|
||||
return []byte(mc.cfg.Passwd), true, nil
|
||||
return append([]byte(mc.cfg.Passwd), 0), nil
|
||||
}
|
||||
|
||||
pubKey := mc.cfg.pubKey
|
||||
if pubKey == nil {
|
||||
// request public key from server
|
||||
return []byte{1}, false, nil
|
||||
return []byte{1}, nil
|
||||
}
|
||||
|
||||
// encrypted password
|
||||
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
|
||||
return enc, false, err
|
||||
return enc, err
|
||||
|
||||
default:
|
||||
errLog.Print("unknown auth plugin:", plugin)
|
||||
return nil, false, ErrUnknownPlugin
|
||||
return nil, ErrUnknownPlugin
|
||||
}
|
||||
}
|
||||
|
||||
@ -315,11 +315,11 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
|
||||
|
||||
plugin = newPlugin
|
||||
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil {
|
||||
if err = mc.writeAuthSwitchPacket(authResp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -352,7 +352,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
|
||||
case cachingSha2PasswordPerformFullAuthentication:
|
||||
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
|
||||
// write cleartext auth packet
|
||||
err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true)
|
||||
err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -360,13 +360,15 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
|
||||
pubKey := mc.cfg.pubKey
|
||||
if pubKey == nil {
|
||||
// request public key from server
|
||||
data := mc.buf.takeSmallBuffer(4 + 1)
|
||||
data, err := mc.buf.takeSmallBuffer(4 + 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data[4] = cachingSha2PasswordRequestPublicKey
|
||||
mc.writePacket(data)
|
||||
|
||||
// parse public key
|
||||
data, err := mc.readPacket()
|
||||
if err != nil {
|
||||
if data, err = mc.readPacket(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@ -85,11 +85,11 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) {
|
||||
plugin := "caching_sha2_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -130,11 +130,11 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) {
|
||||
plugin := "caching_sha2_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -172,11 +172,11 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) {
|
||||
plugin := "caching_sha2_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -228,11 +228,11 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) {
|
||||
plugin := "caching_sha2_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -280,11 +280,11 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) {
|
||||
plugin := "caching_sha2_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -336,7 +336,7 @@ func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) {
|
||||
plugin := "mysql_clear_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
_, _, err := mc.auth(authData, plugin)
|
||||
_, err := mc.auth(authData, plugin)
|
||||
if err != ErrCleartextPassword {
|
||||
t.Errorf("expected ErrCleartextPassword, got %v", err)
|
||||
}
|
||||
@ -353,11 +353,11 @@ func TestAuthFastCleartextPassword(t *testing.T) {
|
||||
plugin := "mysql_clear_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -367,8 +367,8 @@ func TestAuthFastCleartextPassword(t *testing.T) {
|
||||
authRespEnd := authRespStart + 1 + len(authResp)
|
||||
writtenAuthRespLen := conn.written[authRespStart]
|
||||
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
|
||||
expectedAuthResp := []byte{115, 101, 99, 114, 101, 116}
|
||||
if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
|
||||
expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0}
|
||||
if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
|
||||
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
|
||||
}
|
||||
conn.written = nil
|
||||
@ -396,11 +396,11 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
|
||||
plugin := "mysql_clear_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -410,9 +410,9 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
|
||||
authRespEnd := authRespStart + 1 + len(authResp)
|
||||
writtenAuthRespLen := conn.written[authRespStart]
|
||||
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
|
||||
if writtenAuthRespLen != 0 {
|
||||
t.Fatalf("unexpected written auth response (%d bytes): %v",
|
||||
writtenAuthRespLen, writtenAuthResp)
|
||||
expectedAuthResp := []byte{0}
|
||||
if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
|
||||
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
|
||||
}
|
||||
conn.written = nil
|
||||
|
||||
@ -439,7 +439,7 @@ func TestAuthFastNativePasswordNotAllowed(t *testing.T) {
|
||||
plugin := "mysql_native_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
_, _, err := mc.auth(authData, plugin)
|
||||
_, err := mc.auth(authData, plugin)
|
||||
if err != ErrNativePassword {
|
||||
t.Errorf("expected ErrNativePassword, got %v", err)
|
||||
}
|
||||
@ -455,11 +455,11 @@ func TestAuthFastNativePassword(t *testing.T) {
|
||||
plugin := "mysql_native_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -498,11 +498,11 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) {
|
||||
plugin := "mysql_native_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -540,11 +540,11 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
|
||||
plugin := "sha256_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -554,7 +554,8 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
|
||||
authRespEnd := authRespStart + 1 + len(authResp)
|
||||
writtenAuthRespLen := conn.written[authRespStart]
|
||||
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
|
||||
if writtenAuthRespLen != 0 {
|
||||
expectedAuthResp := []byte{0}
|
||||
if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
|
||||
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
|
||||
}
|
||||
conn.written = nil
|
||||
@ -587,11 +588,11 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) {
|
||||
plugin := "sha256_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -636,11 +637,11 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) {
|
||||
plugin := "sha256_password"
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -669,7 +670,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
|
||||
plugin := "sha256_password"
|
||||
|
||||
// send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -677,18 +678,18 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
|
||||
// unset TLS config to prevent the actual establishment of a TLS wrapper
|
||||
mc.cfg.tls = nil
|
||||
|
||||
err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin)
|
||||
err = mc.writeHandshakeResponsePacket(authResp, plugin)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// check written auth response
|
||||
authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1
|
||||
authRespEnd := authRespStart + 1 + len(authResp) + 1
|
||||
authRespEnd := authRespStart + 1 + len(authResp)
|
||||
writtenAuthRespLen := conn.written[authRespStart]
|
||||
writtenAuthResp := conn.written[authRespStart+1 : authRespEnd]
|
||||
expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0}
|
||||
if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
|
||||
if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) {
|
||||
t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp)
|
||||
}
|
||||
conn.written = nil
|
||||
@ -1064,6 +1065,22 @@ func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Same to TestAuthSwitchOldPasswordNotAllowed, but use OldAuthSwitch request.
|
||||
func TestOldAuthSwitchNotAllowed(t *testing.T) {
|
||||
conn, mc := newRWMockConn(2)
|
||||
|
||||
// OldAuthSwitch request
|
||||
conn.data = []byte{1, 0, 0, 2, 0xfe}
|
||||
conn.maxReads = 1
|
||||
authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
|
||||
84, 96, 101, 92, 123, 121, 107}
|
||||
plugin := "mysql_native_password"
|
||||
err := mc.handleAuthResult(authData, plugin)
|
||||
if err != ErrOldPassword {
|
||||
t.Errorf("expected ErrOldPassword, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthSwitchOldPassword(t *testing.T) {
|
||||
conn, mc := newRWMockConn(2)
|
||||
mc.cfg.AllowOldPasswords = true
|
||||
@ -1092,6 +1109,32 @@ func TestAuthSwitchOldPassword(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Same to TestAuthSwitchOldPassword, but use OldAuthSwitch request.
|
||||
func TestOldAuthSwitch(t *testing.T) {
|
||||
conn, mc := newRWMockConn(2)
|
||||
mc.cfg.AllowOldPasswords = true
|
||||
mc.cfg.Passwd = "secret"
|
||||
|
||||
// OldAuthSwitch request
|
||||
conn.data = []byte{1, 0, 0, 2, 0xfe}
|
||||
|
||||
// auth response
|
||||
conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}}
|
||||
conn.maxReads = 2
|
||||
|
||||
authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
|
||||
84, 96, 101, 92, 123, 121, 107}
|
||||
plugin := "mysql_native_password"
|
||||
|
||||
if err := mc.handleAuthResult(authData, plugin); err != nil {
|
||||
t.Errorf("got error: %v", err)
|
||||
}
|
||||
|
||||
expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0}
|
||||
if !bytes.Equal(conn.written, expectedReply) {
|
||||
t.Errorf("got unexpected data: %v", conn.written)
|
||||
}
|
||||
}
|
||||
func TestAuthSwitchOldPasswordEmpty(t *testing.T) {
|
||||
conn, mc := newRWMockConn(2)
|
||||
mc.cfg.AllowOldPasswords = true
|
||||
@ -1120,6 +1163,33 @@ func TestAuthSwitchOldPasswordEmpty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Same to TestAuthSwitchOldPasswordEmpty, but use OldAuthSwitch request.
|
||||
func TestOldAuthSwitchPasswordEmpty(t *testing.T) {
|
||||
conn, mc := newRWMockConn(2)
|
||||
mc.cfg.AllowOldPasswords = true
|
||||
mc.cfg.Passwd = ""
|
||||
|
||||
// OldAuthSwitch request.
|
||||
conn.data = []byte{1, 0, 0, 2, 0xfe}
|
||||
|
||||
// auth response
|
||||
conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}}
|
||||
conn.maxReads = 2
|
||||
|
||||
authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35,
|
||||
84, 96, 101, 92, 123, 121, 107}
|
||||
plugin := "mysql_native_password"
|
||||
|
||||
if err := mc.handleAuthResult(authData, plugin); err != nil {
|
||||
t.Errorf("got error: %v", err)
|
||||
}
|
||||
|
||||
expectedReply := []byte{1, 0, 0, 3, 0}
|
||||
if !bytes.Equal(conn.written, expectedReply) {
|
||||
t.Errorf("got unexpected data: %v", conn.written)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) {
|
||||
conn, mc := newRWMockConn(2)
|
||||
mc.cfg.Passwd = ""
|
||||
|
||||
@ -22,17 +22,17 @@ const defaultBufSize = 4096
|
||||
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
|
||||
// Also highly optimized for this particular use case.
|
||||
type buffer struct {
|
||||
buf []byte
|
||||
buf []byte // buf is a byte buffer who's length and capacity are equal.
|
||||
nc net.Conn
|
||||
idx int
|
||||
length int
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// newBuffer allocates and returns a new buffer.
|
||||
func newBuffer(nc net.Conn) buffer {
|
||||
var b [defaultBufSize]byte
|
||||
return buffer{
|
||||
buf: b[:],
|
||||
buf: make([]byte, defaultBufSize),
|
||||
nc: nc,
|
||||
}
|
||||
}
|
||||
@ -105,43 +105,56 @@ func (b *buffer) readNext(need int) ([]byte, error) {
|
||||
return b.buf[offset:b.idx], nil
|
||||
}
|
||||
|
||||
// returns a buffer with the requested size.
|
||||
// takeBuffer returns a buffer with the requested size.
|
||||
// If possible, a slice from the existing buffer is returned.
|
||||
// Otherwise a bigger buffer is made.
|
||||
// Only one buffer (total) can be used at a time.
|
||||
func (b *buffer) takeBuffer(length int) []byte {
|
||||
func (b *buffer) takeBuffer(length int) ([]byte, error) {
|
||||
if b.length > 0 {
|
||||
return nil
|
||||
return nil, ErrBusyBuffer
|
||||
}
|
||||
|
||||
// test (cheap) general case first
|
||||
if length <= defaultBufSize || length <= cap(b.buf) {
|
||||
return b.buf[:length]
|
||||
if length <= cap(b.buf) {
|
||||
return b.buf[:length], nil
|
||||
}
|
||||
|
||||
if length < maxPacketSize {
|
||||
b.buf = make([]byte, length)
|
||||
return b.buf
|
||||
return b.buf, nil
|
||||
}
|
||||
return make([]byte, length)
|
||||
|
||||
// buffer is larger than we want to store.
|
||||
return make([]byte, length), nil
|
||||
}
|
||||
|
||||
// shortcut which can be used if the requested buffer is guaranteed to be
|
||||
// smaller than defaultBufSize
|
||||
// takeSmallBuffer is shortcut which can be used if length is
|
||||
// known to be smaller than defaultBufSize.
|
||||
// Only one buffer (total) can be used at a time.
|
||||
func (b *buffer) takeSmallBuffer(length int) []byte {
|
||||
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
|
||||
if b.length > 0 {
|
||||
return nil
|
||||
return nil, ErrBusyBuffer
|
||||
}
|
||||
return b.buf[:length]
|
||||
return b.buf[:length], nil
|
||||
}
|
||||
|
||||
// takeCompleteBuffer returns the complete existing buffer.
|
||||
// This can be used if the necessary buffer size is unknown.
|
||||
// cap and len of the returned buffer will be equal.
|
||||
// Only one buffer (total) can be used at a time.
|
||||
func (b *buffer) takeCompleteBuffer() []byte {
|
||||
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
|
||||
if b.length > 0 {
|
||||
return nil
|
||||
return nil, ErrBusyBuffer
|
||||
}
|
||||
return b.buf
|
||||
return b.buf, nil
|
||||
}
|
||||
|
||||
// store stores buf, an updated buffer, if its suitable to do so.
|
||||
func (b *buffer) store(buf []byte) error {
|
||||
if b.length > 0 {
|
||||
return ErrBusyBuffer
|
||||
} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
|
||||
b.buf = buf[:cap(buf)]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -19,16 +19,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// a copy of context.Context for Go 1.7 and earlier
|
||||
type mysqlContext interface {
|
||||
Done() <-chan struct{}
|
||||
Err() error
|
||||
|
||||
// defined in context.Context, but not used in this driver:
|
||||
// Deadline() (deadline time.Time, ok bool)
|
||||
// Value(key interface{}) interface{}
|
||||
}
|
||||
|
||||
type mysqlConn struct {
|
||||
buf buffer
|
||||
netConn net.Conn
|
||||
@ -45,7 +35,7 @@ type mysqlConn struct {
|
||||
|
||||
// for context support (Go 1.8+)
|
||||
watching bool
|
||||
watcher chan<- mysqlContext
|
||||
watcher chan<- context.Context
|
||||
closech chan struct{}
|
||||
finished chan<- struct{}
|
||||
canceled atomicError // set non-nil if conn is canceled
|
||||
@ -192,10 +182,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
|
||||
return "", driver.ErrSkip
|
||||
}
|
||||
|
||||
buf := mc.buf.takeCompleteBuffer()
|
||||
if buf == nil {
|
||||
buf, err := mc.buf.takeCompleteBuffer()
|
||||
if err != nil {
|
||||
// can not take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
errLog.Print(err)
|
||||
return "", ErrInvalidConn
|
||||
}
|
||||
buf = buf[:0]
|
||||
@ -475,7 +465,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
|
||||
defer mc.finish()
|
||||
|
||||
if err = mc.writeCommandPacket(comPing); err != nil {
|
||||
return
|
||||
return mc.markBadConn(err)
|
||||
}
|
||||
|
||||
return mc.readResultOK()
|
||||
@ -595,33 +585,32 @@ func (mc *mysqlConn) watchCancel(ctx context.Context) error {
|
||||
mc.cleanup()
|
||||
return nil
|
||||
}
|
||||
// When ctx is already cancelled, don't watch it.
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
// When ctx is not cancellable, don't watch it.
|
||||
if ctx.Done() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mc.watching = true
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
// When watcher is not alive, can't watch it.
|
||||
if mc.watcher == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mc.watching = true
|
||||
mc.watcher <- ctx
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) startWatcher() {
|
||||
watcher := make(chan mysqlContext, 1)
|
||||
watcher := make(chan context.Context, 1)
|
||||
mc.watcher = watcher
|
||||
finished := make(chan struct{})
|
||||
mc.finished = finished
|
||||
go func() {
|
||||
for {
|
||||
var ctx mysqlContext
|
||||
var ctx context.Context
|
||||
select {
|
||||
case ctx = <-watcher:
|
||||
case <-mc.closech:
|
||||
|
||||
@ -9,7 +9,10 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -79,3 +82,76 @@ func TestCheckNamedValue(t *testing.T) {
|
||||
t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCleanCancel tests passed context is cancelled at start.
|
||||
// No packet should be sent. Connection should keep current status.
|
||||
func TestCleanCancel(t *testing.T) {
|
||||
mc := &mysqlConn{
|
||||
closech: make(chan struct{}),
|
||||
}
|
||||
mc.startWatcher()
|
||||
defer mc.cleanup()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
for i := 0; i < 3; i++ { // Repeat same behavior
|
||||
err := mc.Ping(ctx)
|
||||
if err != context.Canceled {
|
||||
t.Errorf("expected context.Canceled, got %#v", err)
|
||||
}
|
||||
|
||||
if mc.closed.IsSet() {
|
||||
t.Error("expected mc is not closed, closed actually")
|
||||
}
|
||||
|
||||
if mc.watching {
|
||||
t.Error("expected watching is false, but true")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPingMarkBadConnection(t *testing.T) {
|
||||
nc := badConnection{err: errors.New("boom")}
|
||||
ms := &mysqlConn{
|
||||
netConn: nc,
|
||||
buf: newBuffer(nc),
|
||||
maxAllowedPacket: defaultMaxAllowedPacket,
|
||||
}
|
||||
|
||||
err := ms.Ping(context.Background())
|
||||
|
||||
if err != driver.ErrBadConn {
|
||||
t.Errorf("expected driver.ErrBadConn, got %#v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPingErrInvalidConn(t *testing.T) {
|
||||
nc := badConnection{err: errors.New("failed to write"), n: 10}
|
||||
ms := &mysqlConn{
|
||||
netConn: nc,
|
||||
buf: newBuffer(nc),
|
||||
maxAllowedPacket: defaultMaxAllowedPacket,
|
||||
closech: make(chan struct{}),
|
||||
}
|
||||
|
||||
err := ms.Ping(context.Background())
|
||||
|
||||
if err != ErrInvalidConn {
|
||||
t.Errorf("expected ErrInvalidConn, got %#v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type badConnection struct {
|
||||
n int
|
||||
err error
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (bc badConnection) Write(b []byte) (n int, err error) {
|
||||
return bc.n, bc.err
|
||||
}
|
||||
|
||||
func (bc badConnection) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -9,7 +9,7 @@
|
||||
// The driver should be used via the database/sql package:
|
||||
//
|
||||
// import "database/sql"
|
||||
// import _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
|
||||
// import _ "github.com/go-sql-driver/mysql"
|
||||
//
|
||||
// db, err := sql.Open("mysql", "user:password@/dbname")
|
||||
//
|
||||
@ -50,7 +50,7 @@ func RegisterDial(net string, dial DialFunc) {
|
||||
|
||||
// Open new Connection.
|
||||
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
|
||||
// the DSN string is formated
|
||||
// the DSN string is formatted
|
||||
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
||||
var err error
|
||||
|
||||
@ -77,6 +77,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
||||
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
|
||||
}
|
||||
if err != nil {
|
||||
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
|
||||
errLog.Print("net.Error from Dial()': ", nerr.Error())
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -110,18 +114,18 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
||||
}
|
||||
|
||||
// Send Client Authentication Packet
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
authResp, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
// try the default auth plugin, if using the requested plugin failed
|
||||
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
|
||||
plugin = defaultAuthPlugin
|
||||
authResp, addNUL, err = mc.auth(authData, plugin)
|
||||
authResp, err = mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
|
||||
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -85,6 +85,23 @@ type DBTest struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
type netErrorMock struct {
|
||||
temporary bool
|
||||
timeout bool
|
||||
}
|
||||
|
||||
func (e netErrorMock) Temporary() bool {
|
||||
return e.temporary
|
||||
}
|
||||
|
||||
func (e netErrorMock) Timeout() bool {
|
||||
return e.timeout
|
||||
}
|
||||
|
||||
func (e netErrorMock) Error() string {
|
||||
return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout)
|
||||
}
|
||||
|
||||
func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
|
||||
if !available {
|
||||
t.Skipf("MySQL server not running on %s", netAddr)
|
||||
@ -1287,7 +1304,7 @@ func TestFoundRows(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTLS(t *testing.T) {
|
||||
tlsTest := func(dbt *DBTest) {
|
||||
tlsTestReq := func(dbt *DBTest) {
|
||||
if err := dbt.db.Ping(); err != nil {
|
||||
if err == ErrNoTLS {
|
||||
dbt.Skip("server does not support TLS")
|
||||
@ -1304,19 +1321,27 @@ func TestTLS(t *testing.T) {
|
||||
dbt.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if value == nil {
|
||||
dbt.Fatal("no Cipher")
|
||||
if (*value == nil) || (len(*value) == 0) {
|
||||
dbt.Fatalf("no Cipher")
|
||||
} else {
|
||||
dbt.Logf("Cipher: %s", *value)
|
||||
}
|
||||
}
|
||||
}
|
||||
tlsTestOpt := func(dbt *DBTest) {
|
||||
if err := dbt.db.Ping(); err != nil {
|
||||
dbt.Fatalf("error on Ping: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
runTests(t, dsn+"&tls=skip-verify", tlsTest)
|
||||
runTests(t, dsn+"&tls=preferred", tlsTestOpt)
|
||||
runTests(t, dsn+"&tls=skip-verify", tlsTestReq)
|
||||
|
||||
// Verify that registering / using a custom cfg works
|
||||
RegisterTLSConfig("custom-skip-verify", &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
runTests(t, dsn+"&tls=custom-skip-verify", tlsTest)
|
||||
runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq)
|
||||
}
|
||||
|
||||
func TestReuseClosedConnection(t *testing.T) {
|
||||
@ -1801,6 +1826,38 @@ func TestConcurrent(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func testDialError(t *testing.T, dialErr error, expectErr error) {
|
||||
RegisterDial("mydial", func(addr string) (net.Conn, error) {
|
||||
return nil, dialErr
|
||||
})
|
||||
|
||||
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting: %s", err.Error())
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
_, err = db.Exec("DO 1")
|
||||
if err != expectErr {
|
||||
t.Fatalf("was expecting %s. Got: %s", dialErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialUnknownError(t *testing.T) {
|
||||
testErr := fmt.Errorf("test")
|
||||
testDialError(t, testErr, testErr)
|
||||
}
|
||||
|
||||
func TestDialNonRetryableNetErr(t *testing.T) {
|
||||
testErr := netErrorMock{}
|
||||
testDialError(t, testErr, testErr)
|
||||
}
|
||||
|
||||
func TestDialTemporaryNetErr(t *testing.T) {
|
||||
testErr := netErrorMock{temporary: true}
|
||||
testDialError(t, testErr, driver.ErrBadConn)
|
||||
}
|
||||
|
||||
// Tests custom dial functions
|
||||
func TestCustomDial(t *testing.T) {
|
||||
if !available {
|
||||
|
||||
@ -560,7 +560,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
|
||||
} else {
|
||||
cfg.TLSConfig = "false"
|
||||
}
|
||||
} else if vl := strings.ToLower(value); vl == "skip-verify" {
|
||||
} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
|
||||
cfg.TLSConfig = vl
|
||||
cfg.tls = &tls.Config{InsecureSkipVerify: true}
|
||||
} else {
|
||||
|
||||
@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
|
||||
mc.sequence++
|
||||
|
||||
// packets with length 0 terminate a previous packet which is a
|
||||
// multiple of (2^24)−1 bytes long
|
||||
// multiple of (2^24)-1 bytes long
|
||||
if pktLen == 0 {
|
||||
// there was no previous packet
|
||||
if prevData == nil {
|
||||
@ -194,7 +194,11 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
|
||||
return nil, "", ErrOldProtocol
|
||||
}
|
||||
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
|
||||
return nil, "", ErrNoTLS
|
||||
if mc.cfg.TLSConfig == "preferred" {
|
||||
mc.cfg.tls = nil
|
||||
} else {
|
||||
return nil, "", ErrNoTLS
|
||||
}
|
||||
}
|
||||
pos += 2
|
||||
|
||||
@ -243,7 +247,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
|
||||
|
||||
// Client Authentication Packet
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
||||
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error {
|
||||
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
|
||||
// Adjust client flags based on server support
|
||||
clientFlags := clientProtocol41 |
|
||||
clientSecureConn |
|
||||
@ -269,7 +273,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
|
||||
|
||||
// encode length of the auth plugin data
|
||||
var authRespLEIBuf [9]byte
|
||||
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp)))
|
||||
authRespLen := len(authResp)
|
||||
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
|
||||
if len(authRespLEI) > 1 {
|
||||
// if the length can not be written in 1 byte, it must be written as a
|
||||
// length encoded integer
|
||||
@ -277,9 +282,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
|
||||
}
|
||||
|
||||
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
|
||||
if addNUL {
|
||||
pktLen++
|
||||
}
|
||||
|
||||
// To specify a db name
|
||||
if n := len(mc.cfg.DBName); n > 0 {
|
||||
@ -288,10 +290,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
|
||||
}
|
||||
|
||||
// Calculate packet length and get buffer with that size
|
||||
data := mc.buf.takeSmallBuffer(pktLen + 4)
|
||||
if data == nil {
|
||||
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
|
||||
if err != nil {
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
errLog.Print(err)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
@ -350,10 +352,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
|
||||
// Auth Data [length encoded integer]
|
||||
pos += copy(data[pos:], authRespLEI)
|
||||
pos += copy(data[pos:], authResp)
|
||||
if addNUL {
|
||||
data[pos] = 0x00
|
||||
pos++
|
||||
}
|
||||
|
||||
// Databasename [null terminated string]
|
||||
if len(mc.cfg.DBName) > 0 {
|
||||
@ -364,30 +362,24 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool,
|
||||
|
||||
pos += copy(data[pos:], plugin)
|
||||
data[pos] = 0x00
|
||||
pos++
|
||||
|
||||
// Send Auth packet
|
||||
return mc.writePacket(data)
|
||||
return mc.writePacket(data[:pos])
|
||||
}
|
||||
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
||||
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error {
|
||||
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
|
||||
pktLen := 4 + len(authData)
|
||||
if addNUL {
|
||||
pktLen++
|
||||
}
|
||||
data := mc.buf.takeSmallBuffer(pktLen)
|
||||
if data == nil {
|
||||
data, err := mc.buf.takeSmallBuffer(pktLen)
|
||||
if err != nil {
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
errLog.Print(err)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
// Add the auth data [EOF]
|
||||
copy(data[4:], authData)
|
||||
if addNUL {
|
||||
data[pktLen-1] = 0x00
|
||||
}
|
||||
|
||||
return mc.writePacket(data)
|
||||
}
|
||||
|
||||
@ -399,10 +391,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
|
||||
// Reset Packet Sequence
|
||||
mc.sequence = 0
|
||||
|
||||
data := mc.buf.takeSmallBuffer(4 + 1)
|
||||
if data == nil {
|
||||
data, err := mc.buf.takeSmallBuffer(4 + 1)
|
||||
if err != nil {
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
errLog.Print(err)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
@ -418,10 +410,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
|
||||
mc.sequence = 0
|
||||
|
||||
pktLen := 1 + len(arg)
|
||||
data := mc.buf.takeBuffer(pktLen + 4)
|
||||
if data == nil {
|
||||
data, err := mc.buf.takeBuffer(pktLen + 4)
|
||||
if err != nil {
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
errLog.Print(err)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
@ -439,10 +431,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
|
||||
// Reset Packet Sequence
|
||||
mc.sequence = 0
|
||||
|
||||
data := mc.buf.takeSmallBuffer(4 + 1 + 4)
|
||||
if data == nil {
|
||||
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
|
||||
if err != nil {
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
errLog.Print(err)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
@ -479,7 +471,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
|
||||
return data[1:], "", err
|
||||
|
||||
case iEOF:
|
||||
if len(data) < 1 {
|
||||
if len(data) == 1 {
|
||||
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
|
||||
return nil, "mysql_old_password", nil
|
||||
}
|
||||
@ -895,7 +887,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||
const minPktLen = 4 + 1 + 4 + 1 + 4
|
||||
mc := stmt.mc
|
||||
|
||||
// Determine threshould dynamically to avoid packet size shortage.
|
||||
// Determine threshold dynamically to avoid packet size shortage.
|
||||
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
|
||||
if longDataSize < 64 {
|
||||
longDataSize = 64
|
||||
@ -905,15 +897,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||
mc.sequence = 0
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
|
||||
if len(args) == 0 {
|
||||
data = mc.buf.takeBuffer(minPktLen)
|
||||
data, err = mc.buf.takeBuffer(minPktLen)
|
||||
} else {
|
||||
data = mc.buf.takeCompleteBuffer()
|
||||
data, err = mc.buf.takeCompleteBuffer()
|
||||
// In this case the len(data) == cap(data) which is used to optimise the flow below.
|
||||
}
|
||||
if data == nil {
|
||||
if err != nil {
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
errLog.Print(err)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
@ -939,7 +933,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||
pos := minPktLen
|
||||
|
||||
var nullMask []byte
|
||||
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
|
||||
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
|
||||
// buffer has to be extended but we don't know by how much so
|
||||
// we depend on append after all data with known sizes fit.
|
||||
// We stop at that because we deal with a lot of columns here
|
||||
@ -948,10 +942,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||
copy(tmp[:pos], data[:pos])
|
||||
data = tmp
|
||||
nullMask = data[pos : pos+maskLen]
|
||||
// No need to clean nullMask as make ensures that.
|
||||
pos += maskLen
|
||||
} else {
|
||||
nullMask = data[pos : pos+maskLen]
|
||||
for i := 0; i < maskLen; i++ {
|
||||
for i := range nullMask {
|
||||
nullMask[i] = 0
|
||||
}
|
||||
pos += maskLen
|
||||
@ -1088,7 +1083,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
||||
// In that case we must build the data packet with the new values buffer
|
||||
if valuesCap != cap(paramValues) {
|
||||
data = append(data[:pos], paramValues...)
|
||||
mc.buf.buf = data
|
||||
if err = mc.buf.store(data); err != nil {
|
||||
errLog.Print(err)
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
}
|
||||
|
||||
pos += len(paramValues)
|
||||
|
||||
Reference in New Issue
Block a user