diff --git a/TODO.MD b/TODO.MD index 2c1f997f6..3d7f9cdb1 100644 --- a/TODO.MD +++ b/TODO.MD @@ -42,6 +42,8 @@ 1. gtcp提供简便的包发送/接收方法(SendPkg/RecvPkg)以解决常见的TCP通信粘包问题,并完善文档(参考:https://www.cnblogs.com/kex1n/p/6502002.html); 1. gfile对于文件的读写强行使用了gfpool,在某些场景下不合适,需要考虑剥离开,并为开发者提供单独的指针池文件操作特性; 1. 路由增加不区分大小写得匹配方式; +1. str_ireplace: http://php.net/manual/en/function.str-ireplace.php +1. strpos/stripos/strrpos/strripos: http://php.net/manual/en/function.stripos.php diff --git a/g/container/gvar/gvar.go b/g/container/gvar/gvar.go index 2f65f7a38..48be7e36c 100644 --- a/g/container/gvar/gvar.go +++ b/g/container/gvar/gvar.go @@ -10,6 +10,7 @@ package gvar import ( "gitee.com/johng/gf/g/container/gtype" + "gitee.com/johng/gf/g/os/gtime" "gitee.com/johng/gf/g/util/gconv" "time" ) @@ -92,6 +93,8 @@ func (v *Var) Interfaces() []interface{} { return gconv.Interfaces(v.Val() func (v *Var) Time(format...string) time.Time { return gconv.Time(v.Val(), format...) } func (v *Var) TimeDuration() time.Duration { return gconv.TimeDuration(v.Val()) } +func (v *Var) GTime(format...string) *gtime.Time { return gconv.GTime(v.Val(), format...) } + // 将变量转换为对象,注意 objPointer 参数必须为struct指针 func (v *Var) Struct(objPointer interface{}, attrMapping...map[string]string) error { return gconv.Struct(v.Val(), objPointer, attrMapping...) diff --git a/g/container/gvar/gvar_read.go b/g/container/gvar/gvar_read.go index bd4b7f8ef..4b75f250b 100644 --- a/g/container/gvar/gvar_read.go +++ b/g/container/gvar/gvar_read.go @@ -6,7 +6,10 @@ package gvar -import "time" +import ( + "gitee.com/johng/gf/g/os/gtime" + "time" +) // 只读变量接口 type VarRead interface { @@ -34,5 +37,6 @@ type VarRead interface { Interfaces() []interface{} Time(format ...string) time.Time TimeDuration() time.Duration + GTime(format...string) *gtime.Time Struct(objPointer interface{}, attrMapping ...map[string]string) error } \ No newline at end of file diff --git a/g/database/gdb/gdb.go b/g/database/gdb/gdb.go index e66f23931..65b97f494 100644 --- a/g/database/gdb/gdb.go +++ b/g/database/gdb/gdb.go @@ -4,6 +4,7 @@ // If a copy of the MIT was not distributed with this file, // You can obtain one at https://gitee.com/johng/gf. +// Package gdb provides ORM features for popular relationship databases. // 数据库ORM. // 默认内置支持MySQL, 其他数据库需要手动import对应的数据库引擎第三方包. package gdb @@ -12,7 +13,6 @@ import ( "database/sql" "errors" "fmt" - "gitee.com/johng/gf/g/container/gmap" "gitee.com/johng/gf/g/container/gring" "gitee.com/johng/gf/g/container/gtype" "gitee.com/johng/gf/g/container/gvar" @@ -30,31 +30,40 @@ const ( ) // 数据库操作接口 -type Link interface { - // 打开数据库连接,建立数据库操作对象 - Open(c *ConfigNode) (*sql.DB, error) +type DB interface { + // 建立数据库连接方法(开发者一般不需要直接调用) + Open(config *ConfigNode) (*sql.DB, error) // SQL操作方法 - Query(q string, args ...interface{}) (*sql.Rows, error) - Exec(q string, args ...interface{}) (sql.Result, error) - Prepare(q string) (*sql.Stmt, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + Exec(sql string, args ...interface{}) (sql.Result, error) + Prepare(sql string, execOnMaster...bool) (*sql.Stmt, error) + + doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) + doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) + doPrepare(link dbLink, query string) (*sql.Stmt, error) + doInsert(link dbLink, table string, data Map, option int) (result sql.Result, err error) + doBatchInsert(link dbLink, table string, list List, batch int, option int) (result sql.Result, err error) + doUpdate(link dbLink, table string, data interface{}, condition interface{}, args ...interface{}) (result sql.Result, err error) + doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error) // 数据库查询 - GetAll(q string, args ...interface{}) (Result, error) - GetOne(q string, args ...interface{}) (Record, error) - GetValue(q string, args ...interface{}) (Value, error) + GetAll(query string, args ...interface{}) (Result, error) + GetOne(query string, args ...interface{}) (Record, error) + GetValue(query string, args ...interface{}) (Value, error) + GetCount(query string, args ...interface{}) (int, error) + GetStruct(obj interface{}, query string, args ...interface{}) error - // Ping + // 创建底层数据库master/slave链接对象 + Master() (*sql.DB, error) + Slave() (*sql.DB, error) + + // Ping PingMaster() error PingSlave() error - // 连接属性设置 - SetMaxIdleConns(n int) - SetMaxOpenConns(n int) - SetConnMaxLifetime(n int) - // 开启事务操作 - Begin() (*Tx, error) + Begin() (*TX, error) // 数据表插入/更新/保存操作 Insert(table string, data Map) (sql.Result, error) @@ -74,21 +83,34 @@ type Link interface { Table(tables string) *Model From(tables string) *Model - // 内部方法 - insert(table string, data Map, option uint8) (sql.Result, error) - batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) + // 设置管理 + SetDebug(debug bool) + GetQueriedSqls() []*Sql + PrintQueriedSqls() + SetMaxIdleConns(n int) + SetMaxOpenConns(n int) + SetConnMaxLifetime(n int) - getQuoteCharLeft() string - getQuoteCharRight() string - handleSqlBeforeExec(q *string) *string + // 内部方法接口 + getCache() (*gcache.Cache) + getChars() (charLeft string, charRight string) + getDebug() bool + filterFields(table string, data map[string]interface{}) map[string]interface{} + getTableFields(table string) (map[string]string, error) + handleSqlBeforeExec(sql string) string +} + +// 执行底层数据库操作的核心接口 +type dbLink interface { + Query(query string, args ...interface{}) (*sql.Rows, error) + Exec(sql string, args ...interface{}) (sql.Result, error) + Prepare(sql string) (*sql.Stmt, error) } // 数据库链接对象 -type Db struct { - link Link // 底层数据库类型管理对象 +type dbBase struct { + db DB // 数据库对象 group string // 配置分组名称 - charl string // SQL安全符号(左) - charr string // SQL安全符号(右) debug *gtype.Bool // (默认关闭)是否开启调试模式,当开启时会启用一些调试特性 sqls *gring.Ring // (debug=true时有效)已执行的SQL列表 cache *gcache.Cache // 数据库缓存,包括底层连接池对象缓存及查询缓存;需要注意的是,事务查询不支持查询缓存 @@ -104,7 +126,7 @@ type Sql struct { Error error // 执行结果(nil为成功) Start int64 // 执行开始时间(毫秒) End int64 // 执行结束时间(毫秒) - Func string // 执行方法名称 + Func string // 执行方法 } // 返回数据表记录值 @@ -117,27 +139,13 @@ type Record map[string]Value type Result []Record // 关联数组,绑定一条数据表记录(使用别名) -type Map = map[string]interface{} +type Map = map[string]interface{} // 关联数组列表(索引从0开始的数组),绑定多条记录(使用别名) type List = []Map -var ( - // 支持的数据库类型map - driverMap = make(map[string]interface{}) - // 数据库查询缓存对象map,使用数据库连接名称作为键名,键值为查询缓存对象 - dbCaches = gmap.NewStringInterfaceMap() -) -func init() { - driverMap["mysql"] = linkMysql - driverMap["oracle"] = linkOracle - driverMap["sqlite"] = linkSqlite - driverMap["pgsql"] = linkPgsql - driverMap["mssql"] = linkMssql -} - // 使用默认/指定分组配置进行连接,数据库集群配置项:default -func New(groupName ...string) (*Db, error) { +func New(groupName ...string) (db DB, err error) { group := config.d if len(groupName) > 0 { group = groupName[0] @@ -150,24 +158,29 @@ func New(groupName ...string) (*Db, error) { } if _, ok := config.c[group]; ok { if node, err := getConfigNodeByGroup(group, true); err == nil { - link, err := getLinkByType(node.Type) - if err != nil { - return nil, err - } - db := &Db { - link : link, + base := &dbBase { group : group, - charl : link.getQuoteCharLeft(), - charr : link.getQuoteCharRight(), debug : gtype.NewBool(), + cache : gcache.New(), maxIdleConnCount : gtype.NewInt(), maxOpenConnCount : gtype.NewInt(), maxConnLifetime : gtype.NewInt(), } - db.cache = dbCaches.GetOrSetFuncLock(group, func() interface{} { - return gcache.New() - }).(*gcache.Cache) - return db, nil + switch node.Type { + case "mysql": + base.db = &dbMysql{dbBase : base} + case "pgsql": + base.db = &dbPgsql{dbBase : base} + case "mssql": + base.db = &dbMssql{dbBase : base} + case "sqlite": + base.db = &dbSqlite{dbBase : base} + case "oracle": + base.db = &dbOracle{dbBase : base} + default: + return nil, errors.New(fmt.Sprintf(`unsupported database type "%s"`, node.Type)) + } + return base.db, nil } else { return nil, err } @@ -238,51 +251,37 @@ func getConfigNodeByPriority(cg ConfigGroup) *ConfigNode { return nil } -// 根据配置的数据库;类型获得Link接口对象 -func getLinkByType(dbType string) (Link, error) { - if dblink, ok := driverMap[dbType]; ok == false { - return nil, errors.New(fmt.Sprintf("unsupported db type '%s'", dbType)) - } else { - return dblink.(Link), nil - } -} - // 获得底层数据库链接对象 -func (db *Db) getSqlDb(master bool) (sqlDb *sql.DB, err error) { +func (bs *dbBase) getSqlDb(master bool) (sqlDb *sql.DB, err error) { // 负载均衡 - node, err := getConfigNodeByGroup(db.group, master) - if err != nil { - return nil, err - } - // 类型对象 - link, err := getLinkByType(node.Type) + node, err := getConfigNodeByGroup(bs.group, master) if err != nil { return nil, err } // 检查缓存连接池对象 cacheKey := node.String() - if v := db.cache.Get(cacheKey); v != nil { + if v := bs.cache.Get(cacheKey); v != nil { return v.(*sql.DB), nil } - v := db.cache.GetOrSetFuncLock(node.String(), func() interface{} { - sqlDb, err = link.Open(node) + v := bs.cache.GetOrSetFuncLock(node.String(), func() interface{} { + sqlDb, err = bs.db.Open(node) if err != nil { return nil } - if n := db.maxIdleConnCount.Val(); n > 0 { + if n := bs.maxIdleConnCount.Val(); n > 0 { sqlDb.SetMaxIdleConns(n) } else if node.MaxIdleConnCount > 0 { sqlDb.SetMaxIdleConns(node.MaxIdleConnCount) } - if n := db.maxOpenConnCount.Val(); n > 0 { + if n := bs.maxOpenConnCount.Val(); n > 0 { sqlDb.SetMaxOpenConns(n) } else if node.MaxOpenConnCount > 0 { sqlDb.SetMaxOpenConns(node.MaxOpenConnCount) } - if n := db.maxConnLifetime.Val(); n > 0 { + if n := bs.maxConnLifetime.Val(); n > 0 { sqlDb.SetConnMaxLifetime(time.Duration(n) * time.Second) } else if node.MaxConnLifetime > 0 { sqlDb.SetConnMaxLifetime(time.Duration(node.MaxConnLifetime) * time.Second) @@ -296,11 +295,11 @@ func (db *Db) getSqlDb(master bool) (sqlDb *sql.DB, err error) { } // 创建底层数据库master链接对象 -func (db *Db) Master() (*sql.DB, error) { - return db.getSqlDb(true) +func (bs *dbBase) Master() (*sql.DB, error) { + return bs.getSqlDb(true) } // 创建底层数据库slave链接对象 -func (db *Db) Slave() (*sql.DB, error) { - return db.getSqlDb(false) +func (bs *dbBase) Slave() (*sql.DB, error) { + return bs.getSqlDb(false) } diff --git a/g/database/gdb/gdb_base.go b/g/database/gdb/gdb_base.go index ade922df1..ad774aff4 100644 --- a/g/database/gdb/gdb_base.go +++ b/g/database/gdb/gdb_base.go @@ -8,39 +8,29 @@ package gdb import ( - "fmt" - "errors" - "strings" - "reflect" "database/sql" - "gitee.com/johng/gf/g/util/gstr" - "gitee.com/johng/gf/g/util/gconv" - "gitee.com/johng/gf/g/container/gring" + "errors" + "fmt" + "gitee.com/johng/gf/g/os/gcache" "gitee.com/johng/gf/g/os/gtime" - "gitee.com/johng/gf/g/os/glog" - "gitee.com/johng/gf/g/container/gvar" + "gitee.com/johng/gf/g/util/gconv" + "gitee.com/johng/gf/g/util/gregex" + "reflect" + "strings" ) const ( gDEFAULT_DEBUG_SQL_LENGTH = 1000 // 默认调试模式下记录的SQL条数 ) -// 是否开启调试服务 -func (db *Db) SetDebug(debug bool) { - db.debug.Set(debug) - if debug && db.sqls == nil { - db.sqls = gring.New(gDEFAULT_DEBUG_SQL_LENGTH) - } -} - // 获取已经执行的SQL列表(仅在debug=true时有效) -func (db *Db) GetQueriedSqls() []*Sql { - if db.sqls == nil { +func (bs *dbBase) GetQueriedSqls() []*Sql { + if bs.sqls == nil { return nil } sqls := make([]*Sql, 0) - db.sqls.Prev() - db.sqls.RLockIteratorPrev(func(value interface{}) bool { + bs.sqls.Prev() + bs.sqls.RLockIteratorPrev(func(value interface{}) bool { if value == nil { return false } @@ -51,8 +41,8 @@ func (db *Db) GetQueriedSqls() []*Sql { } // 打印已经执行的SQL列表(仅在debug=true时有效) -func (db *Db) PrintQueriedSqls() { - sqls := db.GetQueriedSqls() +func (bs *dbBase) PrintQueriedSqls() { + sqls := bs.GetQueriedSqls() for k, v := range sqls { fmt.Println(len(sqls) - k, ":") fmt.Println(" Sql :", v.Sql) @@ -61,143 +51,110 @@ func (db *Db) PrintQueriedSqls() { fmt.Println(" Start:", gtime.NewFromTimeStamp(v.Start).Format("Y-m-d H:i:s.u")) fmt.Println(" End :", gtime.NewFromTimeStamp(v.End).Format("Y-m-d H:i:s.u")) fmt.Println(" Cost :", v.End - v.Start, "ms") - fmt.Println(" Func :", v.Func) - } -} - -// 打印SQL对象(仅在debug=true时有效) -func (db *Db) printSql(v *Sql) { - s := fmt.Sprintf("%s, %v, %s, %s, %d ms, %s", v.Sql, v.Args, - gtime.NewFromTimeStamp(v.Start).Format("Y-m-d H:i:s.u"), - gtime.NewFromTimeStamp(v.End).Format("Y-m-d H:i:s.u"), - v.End - v.Start, v.Func, - ) - if v.Error != nil { - s += "\nError: " + v.Error.Error() - glog.Backtrace(true, 2).Error(s) - } else { - glog.Debug(s) } } // 数据库sql查询操作,主要执行查询 -func (db *Db) Query(query string, args ...interface{}) (*sql.Rows, error) { - var err error - var rows *sql.Rows - var slave *sql.DB - slave, err = db.Slave(); +func (bs *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err error) { + link, err := bs.db.Slave() if err != nil { return nil,err } - p := db.link.handleSqlBeforeExec(&query) - if db.debug.Val() { - militime1 := gtime.Millisecond() - rows, err = slave.Query(*p, args ...) - militime2 := gtime.Millisecond() - s := &Sql{ - Sql : *p, + return bs.db.doQuery(link, query, args...) +} + +// 数据库sql查询操作,主要执行查询 +func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) { + query = bs.db.handleSqlBeforeExec(query) + if bs.db.getDebug() { + mTime1 := gtime.Millisecond() + rows, err = link.Query(query, args...) + mTime2 := gtime.Millisecond() + s := &Sql { + Sql : query, Args : args, Error : err, - Start : militime1, - End : militime2, - Func : "DB:Query", + Start : mTime1, + End : mTime2, } - db.sqls.Put(s) - db.printSql(s) + bs.sqls.Put(s) + printSql(s) } else { - rows, err = slave.Query(*p, args ...) + rows, err = link.Query(query, args ...) } if err == nil { return rows, nil } else { - err = db.formatError(err, p, args...) + err = formatError(err, query, args...) } return nil, err } // 执行一条sql,并返回执行情况,主要用于非查询操作 -func (db *Db) Exec(query string, args ...interface{}) (sql.Result, error) { - var err error - var result sql.Result - var master *sql.DB - master, err = db.Master(); +func (bs *dbBase) Exec(query string, args ...interface{}) (result sql.Result, err error) { + link, err := bs.db.Master() if err != nil { return nil,err } - p := db.link.handleSqlBeforeExec(&query) - if db.debug.Val() { - militime1 := gtime.Millisecond() - result, err = master.Exec(*p, args ...) - militime2 := gtime.Millisecond() - s := &Sql{ - Sql : *p, + return bs.db.doExec(link, query, args...) +} + +// 执行一条sql,并返回执行情况,主要用于非查询操作 +func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) { + query = bs.db.handleSqlBeforeExec(query) + if bs.db.getDebug() { + mTime1 := gtime.Millisecond() + result, err = link.Exec(query, args ...) + mTime2 := gtime.Millisecond() + s := &Sql{ + Sql : query, Args : args, Error : err, - Start : militime1, - End : militime2, - Func : "DB:Exec", + Start : mTime1, + End : mTime2, } - db.sqls.Put(s) - db.printSql(s) + bs.sqls.Put(s) + printSql(s) } else { - result, err = master.Exec(*p, args ...) + result, err = link.Exec(query, args ...) } - return result, db.formatError(err, p, args...) + return result, formatError(err, query, args...) } -// 格式化错误信息 -func (db *Db) formatError(err error, query *string, args ...interface{}) error { - if err != nil { - errstr := fmt.Sprintf("DB ERROR: %s\n", err.Error()) - errstr += fmt.Sprintf("DB QUERY: %s\n", *query) - if len(args) > 0 { - errstr += fmt.Sprintf("DB PARAM: %v\n", args) +// SQL预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作; 默认执行在Slave上, 通过第二个参数指定执行在Master上 +func (bs *dbBase) Prepare(query string, execOnMaster...bool) (*sql.Stmt, error) { + err := (error)(nil) + link := (dbLink)(nil) + if len(execOnMaster) > 0 && execOnMaster[0] { + if link, err = bs.db.Master(); err != nil { + return nil, err + } + } else { + if link, err = bs.db.Slave(); err != nil { + return nil, err } - err = errors.New(errstr) } - return err + return bs.db.doPrepare(link, query) } +// SQL预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作 +func (bs *dbBase) doPrepare(link dbLink, query string) (*sql.Stmt, error) { + return link.Prepare(query) +} // 数据库查询,获取查询结果集,以列表结构返回 -func (db *Db) GetAll(query string, args ...interface{}) (Result, error) { - // 执行sql - rows, err := db.Query(query, args ...) +func (bs *dbBase) GetAll(query string, args ...interface{}) (Result, error) { + rows, err := bs.Query(query, args ...) if err != nil || rows == nil { return nil, err } - // 列名称列表 - columns, err := rows.Columns() - if err != nil { - return nil, err - } - // 返回结构组装 - values := make([]sql.RawBytes, len(columns)) - scanArgs := make([]interface{}, len(values)) - records := make(Result, 0) - for i := range values { - scanArgs[i] = &values[i] - } - for rows.Next() { - err = rows.Scan(scanArgs...) - if err != nil { - return records, err - } - row := make(Record) - // 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址 - for i, col := range values { - v := make([]byte, len(col)) - copy(v, col) - row[columns[i]] = gvar.New(v, false) - } - records = append(records, row) - } - return records, nil + defer rows.Close() + return rowsToResult(rows) } // 数据库查询,获取查询结果记录,以关联数组结构返回 -func (db *Db) GetOne(query string, args ...interface{}) (Record, error) { - list, err := db.GetAll(query, args ...) +func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) { + list, err := bs.GetAll(query, args ...) if err != nil { return nil, err } @@ -208,18 +165,17 @@ func (db *Db) GetOne(query string, args ...interface{}) (Record, error) { } // 数据库查询,获取查询结果记录,自动映射数据到给定的struct对象中 -func (db *Db) GetStruct(obj interface{}, query string, args ...interface{}) error { - one, err := db.GetOne(query, args...) +func (bs *dbBase) GetStruct(obj interface{}, query string, args ...interface{}) error { + one, err := bs.GetOne(query, args...) if err != nil { return err } return one.ToStruct(obj) } - // 数据库查询,获取查询字段值 -func (db *Db) GetValue(query string, args ...interface{}) (Value, error) { - one, err := db.GetOne(query, args ...) +func (bs *dbBase) GetValue(query string, args ...interface{}) (Value, error) { + one, err := bs.GetOne(query, args ...) if err != nil { return nil, err } @@ -230,44 +186,20 @@ func (db *Db) GetValue(query string, args ...interface{}) (Value, error) { } // 数据库查询,获取查询数量 -func (db *Db) GetCount(query string, args ...interface{}) (int, error) { - val, err := db.GetValue(query, args ...) +func (bs *dbBase) GetCount(query string, args ...interface{}) (int, error) { + if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) { + query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query) + } + value, err := bs.GetValue(query, args ...) if err != nil { return 0, err } - return gconv.Int(val), nil -} - -// 数据表查询,其中tables可以是多个联表查询语句,这种查询方式较复杂,建议使用链式操作 -func (db *Db) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) { - s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables) - if condition != nil { - s += fmt.Sprintf("WHERE %s ", db.formatCondition(condition)) - } - if len(groupBy) > 0 { - s += fmt.Sprintf("GROUP BY %s ", groupBy) - } - if len(orderBy) > 0 { - s += fmt.Sprintf("ORDER BY %s ", orderBy) - } - if limit > 0 { - s += fmt.Sprintf("LIMIT %d,%d ", first, limit) - } - return db.GetAll(s, args ... ) -} - -// sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作 -func (db *Db) Prepare(query string) (*sql.Stmt, error) { - if master, err := db.Master(); err != nil { - return nil, err - } else { - return master.Prepare(query) - } + return value.Int(), nil } // ping一下,判断或保持数据库链接(master) -func (db *Db) PingMaster() error { - if master, err := db.Master(); err != nil { +func (bs *dbBase) PingMaster() error { + if master, err := bs.db.Master(); err != nil { return err } else { return master.Ping() @@ -275,8 +207,8 @@ func (db *Db) PingMaster() error { } // ping一下,判断或保持数据库链接(slave) -func (db *Db) PingSlave() error { - if slave, err := db.Slave(); err != nil { +func (bs *dbBase) PingSlave() error { + if slave, err := bs.db.Slave(); err != nil { return err } else { return slave.Ping() @@ -285,13 +217,13 @@ func (db *Db) PingSlave() error { // 事务操作,开启,会返回一个底层的事务操作对象链接如需要嵌套事务,那么可以使用该对象,否则请忽略 // 只有在tx.Commit/tx.Rollback时,链接会自动Close -func (db *Db) Begin() (*Tx, error) { - if master, err := db.Master(); err != nil { +func (bs *dbBase) Begin() (*TX, error) { + if master, err := bs.db.Master(); err != nil { return nil, err } else { if tx, err := master.Begin(); err == nil { - return &Tx { - db : db, + return &TX { + db : bs.db, tx : tx, master : master, }, nil @@ -301,17 +233,19 @@ func (db *Db) Begin() (*Tx, error) { } } -// 根据insert选项获得操作名称 -func (db *Db) getInsertOperationByOption(option uint8) string { - oper := "INSERT" - switch option { - case OPTION_REPLACE: - oper = "REPLACE" - case OPTION_SAVE: - case OPTION_IGNORE: - oper = "INSERT IGNORE" - } - return oper +// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 +func (bs *dbBase) Insert(table string, data Map) (sql.Result, error) { + return bs.db.doInsert(nil, table, data, OPTION_INSERT) +} + +// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 +func (bs *dbBase) Replace(table string, data Map) (sql.Result, error) { + return bs.db.doInsert(nil, table, data, OPTION_REPLACE) +} + +// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 +func (bs *dbBase) Save(table string, data Map) (sql.Result, error) { + return bs.db.doInsert(nil, table, data, OPTION_SAVE) } // insert、replace, save, ignore操作 @@ -319,95 +253,102 @@ func (db *Db) getInsertOperationByOption(option uint8) string { // 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 // 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 // 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做 -func (db *Db) insert(table string, data Map, option uint8) (sql.Result, error) { +func (bs *dbBase) doInsert(link dbLink, table string, data Map, option int) (result sql.Result, err error) { var fields []string var values []string var params []interface{} + charl, charr := bs.db.getChars() for k, v := range data { - fields = append(fields, db.charl + k + db.charr) + fields = append(fields, charl + k + charr) values = append(values, "?") params = append(params, v) } - operation := db.getInsertOperationByOption(option) + operation := getInsertOperationByOption(option) updatestr := "" if option == OPTION_SAVE { var updates []string for k, _ := range data { updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s%s%s)", - db.charl, k, db.charr, - db.charl, k, db.charr, + charl, k, charr, + charl, k, charr, ), ) } updatestr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ",")) } - return db.Exec( - fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s", - operation, table, strings.Join(fields, ","), - strings.Join(values, ","), - updatestr), - params... - ) + if link == nil { + if link, err = bs.db.Master(); err != nil { + return nil, err + } + } + return bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s", + operation, table, strings.Join(fields, ","), + strings.Join(values, ","), updatestr), + params...) } -// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 -func (db *Db) Insert(table string, data Map) (sql.Result, error) { - return db.insert(table, data, OPTION_INSERT) +// CURD操作:批量数据指定批次量写入 +func (bs *dbBase) BatchInsert(table string, list List, batch int) (sql.Result, error) { + return bs.db.doBatchInsert(nil, table, list, batch, OPTION_INSERT) } -// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 -func (db *Db) Replace(table string, data Map) (sql.Result, error) { - return db.insert(table, data, OPTION_REPLACE) +// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 +func (bs *dbBase) BatchReplace(table string, list List, batch int) (sql.Result, error) { + return bs.db.doBatchInsert(nil, table, list, batch, OPTION_REPLACE) } -// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 -func (db *Db) Save(table string, data Map) (sql.Result, error) { - return db.insert(table, data, OPTION_SAVE) +// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 +func (bs *dbBase) BatchSave(table string, list List, batch int) (sql.Result, error) { + return bs.db.doBatchInsert(nil, table, list, batch, OPTION_SAVE) } // 批量写入数据 -func (db *Db) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) { +func (bs *dbBase) doBatchInsert(link dbLink, table string, list List, batch int, option int) (result sql.Result, err error) { var keys []string var values []string var bvalues []string var params []interface{} - var result sql.Result - var size = len(list) // 判断长度 - if size < 1 { + if len(list) < 1 { return result, errors.New("empty data list") } + if link == nil { + if link, err = bs.db.Master(); err != nil { + return + } + } // 首先获取字段名称及记录长度 for k, _ := range list[0] { keys = append(keys, k) values = append(values, "?") } - keyStr := db.charl + strings.Join(keys, db.charl + "," + db.charr) + db.charr + charl, charr := bs.db.getChars() + keyStr := charl + strings.Join(keys, charl + "," + charr) + charr valueHolderStr := "(" + strings.Join(values, ",") + ")" // 操作判断 - operation := db.getInsertOperationByOption(option) + operation := getInsertOperationByOption(option) updatestr := "" if option == OPTION_SAVE { var updates []string for _, k := range keys { updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s%s%s)", - db.charl, k, db.charr, - db.charl, k, db.charr, + charl, k, charr, + charl, k, charr, ), ) } updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ",")) } // 构造批量写入数据格式(注意map的遍历是无序的) - for i := 0; i < size; i++ { + for i := 0; i < len(list); i++ { for _, k := range keys { params = append(params, list[i][k]) } bvalues = append(bvalues, valueHolderStr) if len(bvalues) == batch { - r, err := db.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s", + r, err := bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES%s %s", operation, table, keyStr, strings.Join(bvalues, ","), updatestr), params...) @@ -421,7 +362,7 @@ func (db *Db) batchInsert(table string, list List, batch int, option uint8) (sql } // 处理最后不构成指定批量的数据 if len(bvalues) > 0 { - r, err := db.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s", + r, err := bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES%s %s", operation, table, keyStr, strings.Join(bvalues, ","), updatestr), params...) @@ -433,32 +374,28 @@ func (db *Db) batchInsert(table string, list List, batch int, option uint8) (sql return result, nil } -// CURD操作:批量数据指定批次量写入 -func (db *Db) BatchInsert(table string, list List, batch int) (sql.Result, error) { - return db.batchInsert(table, list, batch, OPTION_INSERT) -} - -// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 -func (db *Db) BatchReplace(table string, list List, batch int) (sql.Result, error) { - return db.batchInsert(table, list, batch, OPTION_REPLACE) -} - -// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 -func (db *Db) BatchSave(table string, list List, batch int) (sql.Result, error) { - return db.batchInsert(table, list, batch, OPTION_SAVE) +// CURD操作:数据更新,统一采用sql预处理 +// data参数支持字符串或者关联数组类型,内部会自行做判断处理 +func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { + link, err := bs.db.Master() + if err != nil { + return nil, err + } + return bs.db.doUpdate(link, table, data, condition, args ...) } // CURD操作:数据更新,统一采用sql预处理 // data参数支持字符串或者关联数组类型,内部会自行做判断处理 -func (db *Db) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { - var params []interface{} - var updates string - refValue := reflect.ValueOf(data) +func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition interface{}, args ...interface{}) (result sql.Result, err error) { + params := ([]interface{})(nil) + updates := "" + charl, charr := bs.db.getChars() + refValue := reflect.ValueOf(data) if refValue.Kind() == reflect.Map { var fields []string keys := refValue.MapKeys() for _, k := range keys { - fields = append(fields, fmt.Sprintf("%s%s%s=?", db.charl, k, db.charr)) + fields = append(fields, fmt.Sprintf("%s%s%s=?", charl, k, charr)) params = append(params, gconv.String(refValue.MapIndex(k).Interface())) } updates = strings.Join(fields, ",") @@ -468,34 +405,62 @@ func (db *Db) Update(table string, data interface{}, condition interface{}, args for _, v := range args { params = append(params, gconv.String(v)) } - return db.Exec(fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, db.formatCondition(condition)), params...) + if link == nil { + if link, err = bs.db.Master(); err != nil { + return nil, err + } + } + return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, formatCondition(condition)), params...) } // CURD操作:删除数据 -func (db *Db) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) { - return db.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", table, db.formatCondition(condition)), args...) +func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) { + link, err := bs.db.Master() + if err != nil { + return nil, err + } + return bs.db.doDelete(link, table, condition, args ...) } -// 格式化SQL查询条件 -func (db *Db) formatCondition(condition interface{}) (where string) { - if reflect.ValueOf(condition).Kind() == reflect.Map { - ks := reflect.ValueOf(condition).MapKeys() - vs := reflect.ValueOf(condition) - for _, k := range ks { - key := gconv.String(k.Interface()) - value := gconv.String(vs.MapIndex(k).Interface()) - isNum := gstr.IsNumeric(value) - if len(where) > 0 { - where += " AND " - } - if isNum || value == "?" { - where += key + "=" + value - } else { - where += key + "='" + value + "'" +// CURD操作:删除数据 +func (bs *dbBase) doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error) { + return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, formatCondition(condition)), args...) +} + +// 获得缓存对象 +func (bs *dbBase) getCache() *gcache.Cache { + return bs.cache +} + +// 将map的数据按照fields进行过滤,只保留与表字段同名的数据 +func (bs *dbBase) filterFields(table string, data map[string]interface{}) map[string]interface{} { + if fields, err := bs.db.getTableFields(table); err == nil { + for k, _ := range data { + if _, ok := fields[k]; !ok { + delete(data, k) } } - } else { - where += gconv.String(condition) + } + return data +} + +// 获得指定表表的数据结构map +func (bs *dbBase) getTableFields(table string) (fields map[string]string, err error) { + v := bs.cache.GetOrSetFunc("table_fields_" + table, func() interface{} { + result := (Result)(nil) + charl, charr := bs.db.getChars() + result, err = bs.GetAll(fmt.Sprintf(`SHOW COLUMNS FROM %s%s%s`, charl, table, charr)) + if err != nil { + return nil + } + fields = make(map[string]string) + for _, m := range result { + fields[m["Field"].String()] = m["Type"].String() + } + return fields + }, 0) + if err == nil { + fields = v.(map[string]string) } return } \ No newline at end of file diff --git a/g/database/gdb/gdb_config.go b/g/database/gdb/gdb_config.go index f3a6f490c..10ca88dec 100644 --- a/g/database/gdb/gdb_config.go +++ b/g/database/gdb/gdb_config.go @@ -9,6 +9,7 @@ package gdb import ( "fmt" + "gitee.com/johng/gf/g/container/gring" "sync" ) @@ -122,19 +123,19 @@ func SetDefaultGroup (groupName string) { } // 设置数据库连接池中空闲链接的大小 -func (db *Db) SetMaxIdleConns(n int) { - db.maxIdleConnCount.Set(n) +func (bs *dbBase) SetMaxIdleConns(n int) { + bs.maxIdleConnCount.Set(n) } // 设置数据库连接池最大打开的链接数量 -func (db *Db) SetMaxOpenConns(n int) { - db.maxOpenConnCount.Set(n) +func (bs *dbBase) SetMaxOpenConns(n int) { + bs.maxOpenConnCount.Set(n) } // 设置数据库连接可重复利用的时间,超过该时间则被关闭废弃 // 如果 d <= 0 表示该链接会一直重复利用 -func (db *Db) SetConnMaxLifetime(n int) { - db.maxConnLifetime.Set(n) +func (bs *dbBase) SetConnMaxLifetime(n int) { + bs.maxConnLifetime.Set(n) } // 节点配置转换为字符串 @@ -146,4 +147,17 @@ func (node *ConfigNode) String() string { node.Name, node.Type, node.Role, node.Charset, node.MaxIdleConnCount, node.MaxOpenConnCount, node.MaxConnLifetime, ) +} + +// 是否开启调试服务 +func (bs *dbBase) SetDebug(debug bool) { + bs.debug.Set(debug) + if debug && bs.sqls == nil { + bs.sqls = gring.New(gDEFAULT_DEBUG_SQL_LENGTH) + } +} + +// 获取是否开启调试服务 +func (bs *dbBase) getDebug() bool { + return bs.debug.Val() } \ No newline at end of file diff --git a/g/database/gdb/gdb_func.go b/g/database/gdb/gdb_func.go new file mode 100644 index 000000000..42210e115 --- /dev/null +++ b/g/database/gdb/gdb_func.go @@ -0,0 +1,124 @@ +// Copyright 2017-2018 gf Author(https://gitee.com/johng/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://gitee.com/johng/gf. + +package gdb + +import ( + "database/sql" + "errors" + "fmt" + "gitee.com/johng/gf/g/container/gvar" + "gitee.com/johng/gf/g/os/glog" + "gitee.com/johng/gf/g/os/gtime" + "gitee.com/johng/gf/g/util/gconv" + "gitee.com/johng/gf/g/util/gstr" + _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql" + "reflect" +) + +// 将数据查询的列表数据*sql.Rows转换为Result类型 +func rowsToResult(rows *sql.Rows) (Result, error) { + // 列名称列表 + columns, err := rows.Columns() + if err != nil { + return nil, err + } + // 返回结构组装 + values := make([]sql.RawBytes, len(columns)) + scanArgs := make([]interface{}, len(values)) + records := make(Result, 0) + for i := range values { + scanArgs[i] = &values[i] + } + for rows.Next() { + err = rows.Scan(scanArgs...) + if err != nil { + return records, err + } + row := make(Record) + // 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址 + for i, col := range values { + if col == nil { + row[columns[i]] = gvar.New(nil, false) + } else { + v := make([]byte, len(col)) + copy(v, col) + row[columns[i]] = gvar.New(v, false) + } + } + records = append(records, row) + } + return records, nil +} + +// 格式化SQL查询条件 +func formatCondition(condition interface{}) (where string) { + if reflect.ValueOf(condition).Kind() == reflect.Map { + ks := reflect.ValueOf(condition).MapKeys() + vs := reflect.ValueOf(condition) + for _, k := range ks { + key := gconv.String(k.Interface()) + value := gconv.String(vs.MapIndex(k).Interface()) + isNum := gstr.IsNumeric(value) + if len(where) > 0 { + where += " AND " + } + if isNum || value == "?" { + where += key + "=" + value + } else { + where += key + "='" + value + "'" + } + } + } else { + where += gconv.String(condition) + } + if len(where) == 0 { + where = "1" + } + return +} + +// 打印SQL对象(仅在debug=true时有效) +func printSql(v *Sql) { + s := fmt.Sprintf("%s, %v, %s, %s, %d ms, %s", v.Sql, v.Args, + gtime.NewFromTimeStamp(v.Start).Format("Y-m-d H:i:s.u"), + gtime.NewFromTimeStamp(v.End).Format("Y-m-d H:i:s.u"), + v.End - v.Start, + v.Func, + ) + if v.Error != nil { + s += "\nError: " + v.Error.Error() + glog.Backtrace(true, 2).Error(s) + } else { + glog.Debug(s) + } +} + +// 格式化错误信息 +func formatError(err error, query string, args ...interface{}) error { + if err != nil { + errstr := fmt.Sprintf("DB ERROR: %s\n", err.Error()) + errstr += fmt.Sprintf("DB QUERY: %s\n", query) + if len(args) > 0 { + errstr += fmt.Sprintf("DB PARAM: %v\n", args) + } + err = errors.New(errstr) + } + return err +} + +// 根据insert选项获得操作名称 +func getInsertOperationByOption(option int) string { + oper := "INSERT" + switch option { + case OPTION_REPLACE: + oper = "REPLACE" + case OPTION_SAVE: + case OPTION_IGNORE: + oper = "INSERT IGNORE" + } + return oper +} diff --git a/g/database/gdb/gdb_model.go b/g/database/gdb/gdb_model.go index f9347f2c3..95780f6d1 100644 --- a/g/database/gdb/gdb_model.go +++ b/g/database/gdb/gdb_model.go @@ -12,12 +12,15 @@ import ( "database/sql" "gitee.com/johng/gf/g/util/gconv" _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql" + "reflect" + "strings" ) // 数据库链式操作模型对象 type Model struct { - tx *Tx // 数据库事务对象 - db *Db // 数据库操作对象 + db DB // 数据库操作对象 + tx *TX // 数据库事务对象 + tablesInit string // 初始化Model时的表名称(可以是多个) tables string // 数据库操作表 fields string // 操作字段 where string // 操作条件 @@ -28,39 +31,51 @@ type Model struct { limit int // 分页条数 data interface{} // 操作记录(支持Map/List/string类型) batch int // 批量操作条数 + filter bool // 是否按照表字段过滤data参数 cacheEnabled bool // 当前SQL操作是否开启查询缓存功能 cacheTime int // 查询缓存时间 cacheName string // 查询缓存名称 } // 链式操作,数据表字段,可支持多个表,以半角逗号连接 -func (db *Db) Table(tables string) (*Model) { - return &Model{ - db: db, - tables: tables, - fields: "*", +func (bs *dbBase) Table(tables string) (*Model) { + return &Model { + db : bs.db, + tablesInit : tables, + tables : tables, + fields : "*", } } // 链式操作,数据表字段,可支持多个表,以半角逗号连接 -func (db *Db) From(tables string) (*Model) { - return db.Table(tables) +func (bs *dbBase) From(tables string) (*Model) { + return bs.db.Table(tables) } // (事务)链式操作,数据表字段,可支持多个表,以半角逗号连接 -func (tx *Tx) Table(tables string) (*Model) { +func (tx *TX) Table(tables string) (*Model) { return &Model{ - db: tx.db, - tx: tx, - tables: tables, + db : tx.db, + tx : tx, + tablesInit : tables, + tables : tables, } } // (事务)链式操作,数据表字段,可支持多个表,以半角逗号连接 -func (tx *Tx) From(tables string) (*Model) { +func (tx *TX) From(tables string) (*Model) { return tx.Table(tables) } +// 清空链式操作数据,以便改model可以重复使用 +func (md *Model) clear() { + if md.tx != nil { + *md = *md.tx.Table(md.tablesInit) + } else { + *md = *md.db.Table(md.tablesInit) + } +} + // 链式操作,左联表 func (md *Model) LeftJoin(joinTable string, on string) (*Model) { md.tables += fmt.Sprintf(" LEFT JOIN %s ON (%s)", joinTable, on) @@ -85,23 +100,33 @@ func (md *Model) Fields(fields string) (*Model) { return md } +// 链式操作,过滤字段 +func (md *Model) Filter() (*Model) { + md.filter = true + return md +} + // 链式操作,condition,支持string & gdb.Map func (md *Model) Where(where interface{}, args ...interface{}) (*Model) { - md.where = md.db.formatCondition(where) + md.where = formatCondition(where) md.whereArgs = append(md.whereArgs, args...) + // 支持 Where("uid", 1)这种格式 + if len(args) == 1 && strings.Index(md.where , "?") < 0 { + md.where += "=?" + } return md } // 链式操作,添加AND条件到Where中 func (md *Model) And(where interface{}, args ...interface{}) (*Model) { - md.where += " AND " + md.db.formatCondition(where) + md.where += " AND " + formatCondition(where) md.whereArgs = append(md.whereArgs, args...) return md } // 链式操作,添加OR条件到Where中 func (md *Model) Or(where interface{}, args ...interface{}) (*Model) { - md.where += " OR " + md.db.formatCondition(where) + md.where += " OR " + formatCondition(where) md.whereArgs = append(md.whereArgs, args...) return md } @@ -142,7 +167,32 @@ func (md *Model) Data(data ...interface{}) (*Model) { } md.data = m } else { - md.data = data[0] + switch data[0].(type) { + case List: + md.data = data[0] + case Map: + md.data = data[0] + default: + rv := reflect.ValueOf(data[0]) + kind := rv.Kind() + if kind == reflect.Ptr { + rv = rv.Elem() + kind = rv.Kind() + } + switch kind { + case reflect.Slice: fallthrough + case reflect.Array: + list := make(List, rv.Len()) + for i := 0; i < rv.Len(); i++ { + list[i] = gconv.Map(rv.Index(i).Interface()) + } + md.data = list + case reflect.Map: + md.data = gconv.Map(data[0]) + default: + md.data = data[0] + } + } } return md } @@ -153,6 +203,7 @@ func (md *Model) Insert() (result sql.Result, err error) { if err == nil { md.checkAndRemoveCache() } + md.clear() }() if md.data == nil { return nil, errors.New("inserting into table with empty data") @@ -163,16 +214,24 @@ func (md *Model) Insert() (result sql.Result, err error) { if md.batch > 0 { batch = md.batch } + if md.filter { + for k, m := range list { + list[k] = md.db.filterFields(md.tables, m) + } + } if md.tx == nil { return md.db.BatchInsert(md.tables, list, batch) } else { return md.tx.BatchInsert(md.tables, list, batch) } - } else if dataMap, ok := md.data.(Map); ok { + } else if data, ok := md.data.(Map); ok { + if md.filter { + data = md.db.filterFields(md.tables, data) + } if md.tx == nil { - return md.db.Insert(md.tables, dataMap) + return md.db.Insert(md.tables, data) } else { - return md.tx.Insert(md.tables, dataMap) + return md.tx.Insert(md.tables, data) } } return nil, errors.New("inserting into table with invalid data type") @@ -184,6 +243,7 @@ func (md *Model) Replace() (result sql.Result, err error) { if err == nil { md.checkAndRemoveCache() } + md.clear() }() if md.data == nil { return nil, errors.New("replacing into table with empty data") @@ -194,16 +254,24 @@ func (md *Model) Replace() (result sql.Result, err error) { if md.batch > 0 { batch = md.batch } + if md.filter { + for k, m := range list { + list[k] = md.db.filterFields(md.tables, m) + } + } if md.tx == nil { return md.db.BatchReplace(md.tables, list, batch) } else { return md.tx.BatchReplace(md.tables, list, batch) } - } else if dataMap, ok := md.data.(Map); ok { + } else if data, ok := md.data.(Map); ok { + if md.filter { + data = md.db.filterFields(md.tables, data) + } if md.tx == nil { - return md.db.Insert(md.tables, dataMap) + return md.db.Replace(md.tables, data) } else { - return md.tx.Insert(md.tables, dataMap) + return md.tx.Replace(md.tables, data) } } return nil, errors.New("replacing into table with invalid data type") @@ -215,6 +283,7 @@ func (md *Model) Save() (result sql.Result, err error) { if err == nil { md.checkAndRemoveCache() } + md.clear() }() if md.data == nil { return nil, errors.New("replacing into table with empty data") @@ -225,16 +294,24 @@ func (md *Model) Save() (result sql.Result, err error) { if md.batch > 0 { batch = md.batch } + if md.filter { + for k, m := range list { + list[k] = md.db.filterFields(md.tables, m) + } + } if md.tx == nil { return md.db.BatchSave(md.tables, list, batch) } else { return md.tx.BatchSave(md.tables, list, batch) } - } else if dataMap, ok := md.data.(Map); ok { + } else if data, ok := md.data.(Map); ok { + if md.filter { + data = md.db.filterFields(md.tables, data) + } if md.tx == nil { - return md.db.Save(md.tables, dataMap) + return md.db.Save(md.tables, data) } else { - return md.tx.Save(md.tables, dataMap) + return md.tx.Save(md.tables, data) } } return nil, errors.New("saving into table with invalid data type") @@ -246,10 +323,18 @@ func (md *Model) Update() (result sql.Result, err error) { if err == nil { md.checkAndRemoveCache() } + md.clear() }() if md.data == nil { return nil, errors.New("updating table with empty data") } + if md.filter { + if data, ok := md.data.(Map); ok { + if md.filter { + md.data = md.db.filterFields(md.tables, data) + } + } + } if md.tx == nil { return md.db.Update(md.tables, md.data, md.where, md.whereArgs ...) } else { @@ -263,10 +348,8 @@ func (md *Model) Delete() (result sql.Result, err error) { if err == nil { md.checkAndRemoveCache() } + md.clear() }() - if md.where == "" { - return nil, errors.New("where is required while deleting") - } if md.tx == nil { return md.db.Delete(md.tables, md.where, md.whereArgs...) } else { @@ -298,12 +381,13 @@ func (md *Model) Cache(time int, name ... string) *Model { // 链式操作,select func (md *Model) Select() (Result, error) { - return md.getAll(md.getFormattedSql(), md.whereArgs...) + return md.All() } // 链式操作,查询所有记录 func (md *Model) All() (Result, error) { - return md.Select() + defer md.clear() + return md.getAll(md.getFormattedSql(), md.whereArgs...) } // 链式操作,查询单条记录 @@ -342,6 +426,7 @@ func (md *Model) Struct(obj interface{}) error { // 链式操作,查询数量,fields可以为空,也可以自定义查询字段, // 当给定自定义查询字段时,该字段必须为数量结果,否则会引起歧义,使用如:md.Fields("COUNT(id)") func (md *Model) Count() (int, error) { + defer md.clear() if md.fields == "" || md.fields == "*" { md.fields = "COUNT(1)" } else { @@ -372,7 +457,7 @@ func (md *Model) getAll(sql string, args ...interface{}) (result Result, err err if len(cacheKey) == 0 { cacheKey = sql + "/" + gconv.String(args) } - if v := md.db.cache.Get(cacheKey); v != nil { + if v := md.db.getCache().Get(cacheKey); v != nil { return v.(Result), nil } } @@ -384,9 +469,9 @@ func (md *Model) getAll(sql string, args ...interface{}) (result Result, err err // 查询缓存保存处理 if len(cacheKey) > 0 && err == nil { if md.cacheTime < 0 { - md.db.cache.Remove(cacheKey) + md.db.getCache().Remove(cacheKey) } else { - md.db.cache.Set(cacheKey, result, md.cacheTime*1000) + md.db.getCache().Set(cacheKey, result, md.cacheTime*1000) } } return result, err @@ -395,7 +480,7 @@ func (md *Model) getAll(sql string, args ...interface{}) (result Result, err err // 检查是否需要查询查询缓存 func (md *Model) checkAndRemoveCache() { if md.cacheEnabled && md.cacheTime < 0 && len(md.cacheName) > 0 { - md.db.cache.Remove(md.cacheName) + md.db.getCache().Remove(md.cacheName) } } @@ -424,11 +509,11 @@ func (md *Model) getFormattedSql() string { // @author ymrjqyy // @author 2018-08-15 func (md *Model) Chunk(limit int, callback func(result Result, err error) bool) { - var page = 1 + defer md.clear() + page := 1 for { md.ForPage(page, limit) - sqls := md.getFormattedSql() - data, err := md.getAll(sqls, md.whereArgs...) + data, err := md.getAll(md.getFormattedSql(), md.whereArgs...) if err != nil { callback(nil, err) break diff --git a/g/database/gdb/gdb_mssql.go b/g/database/gdb/gdb_mssql.go index 4ffbecc9a..ba1f7a8ab 100644 --- a/g/database/gdb/gdb_mssql.go +++ b/g/database/gdb/gdb_mssql.go @@ -22,20 +22,19 @@ import ( ) -var linkMssql = &dbmssql{} - // 数据库链接对象 -type dbmssql struct { - Db +type dbMssql struct { + *dbBase } // 创建SQL操作对象 -func (db *dbmssql) Open(c *ConfigNode) (*sql.DB, error) { - var source string - if c.Linkinfo != "" { - source = c.Linkinfo +func (db *dbMssql) Open(config *ConfigNode) (*sql.DB, error) { + source := "" + if config.Linkinfo != "" { + source = config.Linkinfo } else { - source = fmt.Sprintf("user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable", c.User, c.Pass, c.Host, c.Port, c.Name) + source = fmt.Sprintf("user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable", + config.User, config.Pass, config.Host, config.Port, config.Name) } if db, err := sql.Open("sqlserver", source); err == nil { return db, nil @@ -44,43 +43,38 @@ func (db *dbmssql) Open(c *ConfigNode) (*sql.DB, error) { } } -// 获得关键字操作符 - 左 -func (db *dbmssql) getQuoteCharLeft() string { - return "\"" -} - -// 获得关键字操作符 - 右 -func (db *dbmssql) getQuoteCharRight() string { - return "\"" +// 获得关键字操作符 +func (db *dbMssql) getChars () (charLeft string, charRight string) { + return "\"", "\"" } // 在执行sql之前对sql进行进一步处理 -func (db *dbmssql) handleSqlBeforeExec(q *string) *string { +func (db *dbMssql) handleSqlBeforeExec(query string) string { index := 0 - str, _ := gregex.ReplaceStringFunc("\\?", *q, func(s string) string { + str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string { index++ return fmt.Sprintf("@p%d", index) }) str, _ = gregex.ReplaceString("\"", "", str) - return db.parseSql(&str) + return db.parseSql(str) } //将MYSQL的SQL语法转换为MSSQL的语法 //1.由于mssql不支持limit写法所以需要对mysql中的limit用法做转换 -func (db *dbmssql) parseSql(sql *string) *string { +func (db *dbMssql) parseSql(sql string) string { //下面的正则表达式匹配出SELECT和INSERT的关键字后分别做不同的处理,如有LIMIT则将LIMIT的关键字也匹配出 patten := `^\s*(?i)(SELECT)|(LIMIT\s*(\d+)\s*,\s*(\d+))` - if gregex.IsMatchString(patten, *sql) == false { + if gregex.IsMatchString(patten, sql) == false { fmt.Println("not matched..") return sql } - res, err := gregex.MatchAllString(patten, *sql) + res, err := gregex.MatchAllString(patten, sql) if err != nil { fmt.Println("MatchString error.", err) - return nil + return "" } index := 0 @@ -96,17 +90,17 @@ func (db *dbmssql) parseSql(sql *string) *string { } //不含LIMIT则不处理 - if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql) == false { + if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) == false { break } //判断SQL中是否含有order by selectStr := "" orderbyStr := "" - haveOrderby := gregex.IsMatchString("((?i)SELECT)(.+)((?i)ORDER BY)", *sql) + haveOrderby := gregex.IsMatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql) if haveOrderby { //取order by 前面的字符串 - queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)ORDER BY)", *sql) + queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql) if len(queryExpr) != 4 || strings.EqualFold(queryExpr[1], "SELECT") == false || strings.EqualFold(queryExpr[3], "ORDER BY") == false{ break @@ -114,13 +108,13 @@ func (db *dbmssql) parseSql(sql *string) *string { selectStr = queryExpr[2] //取order by表达式的值 - orderbyExpr, _ := gregex.MatchString("((?i)ORDER BY)(.+)((?i)LIMIT)", *sql) + orderbyExpr, _ := gregex.MatchString("((?i)ORDER BY)(.+)((?i)LIMIT)", sql) if len(orderbyExpr) != 4 || strings.EqualFold(orderbyExpr[1], "ORDER BY") == false || strings.EqualFold(orderbyExpr[3], "LIMIT") == false{ break } orderbyStr = orderbyExpr[2] } else { - queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql) + queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) if len(queryExpr) != 4 || strings.EqualFold(queryExpr[1], "SELECT") == false || strings.EqualFold(queryExpr[3], "LIMIT") == false{ break } @@ -144,14 +138,14 @@ func (db *dbmssql) parseSql(sql *string) *string { } if haveOrderby { - *sql = fmt.Sprintf("SELECT * FROM (SELECT ROW_NUMBER() OVER (ORDER BY %s) as ROWNUMBER_, %s ) as TMP_ WHERE TMP_.ROWNUMBER_ > %d AND TMP_.ROWNUMBER_ <= %d", orderbyStr, selectStr, first, limit) + sql = fmt.Sprintf("SELECT * FROM (SELECT ROW_NUMBER() OVER (ORDER BY %s) as ROWNUMBER_, %s ) as TMP_ WHERE TMP_.ROWNUMBER_ > %d AND TMP_.ROWNUMBER_ <= %d", orderbyStr, selectStr, first, limit) } else { if first == 0 { first = limit } else { first = limit - first } - *sql = fmt.Sprintf("SELECT * FROM (SELECT TOP %d * FROM (SELECT TOP %d %s) as TMP1_ ) as TMP2_ ", first, limit, selectStr) + sql = fmt.Sprintf("SELECT * FROM (SELECT TOP %d * FROM (SELECT TOP %d %s) as TMP1_ ) as TMP2_ ", first, limit, selectStr) } default: } diff --git a/g/database/gdb/gdb_mysql.go b/g/database/gdb/gdb_mysql.go index 76775312b..5ddfcc3c7 100644 --- a/g/database/gdb/gdb_mysql.go +++ b/g/database/gdb/gdb_mysql.go @@ -12,22 +12,18 @@ import ( "database/sql" ) -// MySQL接口对象 -var linkMysql = &dbmysql{} - - // 数据库链接对象 -type dbmysql struct { - Db +type dbMysql struct { + *dbBase } // 创建SQL操作对象,内部采用了lazy link处理 -func (db *dbmysql) Open (c *ConfigNode) (*sql.DB, error) { +func (db *dbMysql) Open (config *ConfigNode) (*sql.DB, error) { var source string - if c.Linkinfo != "" { - source = c.Linkinfo + if config.Linkinfo != "" { + source = config.Linkinfo } else { - source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", c.User, c.Pass, c.Host, c.Port, c.Name) + source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?multiStatements=true", config.User, config.Pass, config.Host, config.Port, config.Name) } if db, err := sql.Open("mysql", source); err == nil { return db, nil @@ -36,17 +32,12 @@ func (db *dbmysql) Open (c *ConfigNode) (*sql.DB, error) { } } -// 获得关键字操作符 - 左 -func (db *dbmysql) getQuoteCharLeft () string { - return "`" -} - -// 获得关键字操作符 - 右 -func (db *dbmysql) getQuoteCharRight () string { - return "`" +// 获得关键字操作符 +func (db *dbMysql) getChars () (charLeft string, charRight string) { + return "`", "`" } // 在执行sql之前对sql进行进一步处理 -func (db *dbmysql) handleSqlBeforeExec(q *string) *string { - return q +func (db *dbMysql) handleSqlBeforeExec(query string) string { + return query } \ No newline at end of file diff --git a/g/database/gdb/gdb_oracle.go b/g/database/gdb/gdb_oracle.go index 93ad1ac56..965a4f5c1 100644 --- a/g/database/gdb/gdb_oracle.go +++ b/g/database/gdb/gdb_oracle.go @@ -21,20 +21,18 @@ import ( "strings" ) -var linkOracle = &dboracle{} - // 数据库链接对象 -type dboracle struct { - Db +type dbOracle struct { + *dbBase } // 创建SQL操作对象 -func (db *dboracle) Open(c *ConfigNode) (*sql.DB, error) { +func (db *dbOracle) Open(config *ConfigNode) (*sql.DB, error) { var source string - if c.Linkinfo != "" { - source = c.Linkinfo + if config.Linkinfo != "" { + source = config.Linkinfo } else { - source = fmt.Sprintf("%s/%s@%s", c.User, c.Pass, c.Name) + source = fmt.Sprintf("%s/%s@%s", config.User, config.Pass, config.Name) } if db, err := sql.Open("oci8", source); err == nil { return db, nil @@ -43,42 +41,37 @@ func (db *dboracle) Open(c *ConfigNode) (*sql.DB, error) { } } -// 获得关键字操作符 - 左 -func (db *dboracle) getQuoteCharLeft() string { - return "\"" -} - -// 获得关键字操作符 - 右 -func (db *dboracle) getQuoteCharRight() string { - return "\"" +// 获得关键字操作符 +func (db *dbOracle) getChars () (charLeft string, charRight string) { + return "\"", "\"" } // 在执行sql之前对sql进行进一步处理 -func (db *dboracle) handleSqlBeforeExec(q *string) *string { +func (db *dbOracle) handleSqlBeforeExec(query string) string { index := 0 - str, _ := gregex.ReplaceStringFunc("\\?", *q, func(s string) string { + str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string { index++ return fmt.Sprintf(":%d", index) }) str, _ = gregex.ReplaceString("\"", "", str) - return db.parseSql(&str) + return db.parseSql(str) } //由于ORACLE中对LIMIT和批量插入的语法与MYSQL不一致,所以这里需要对LIMIT和批量插入做语法上的转换 -func (db *dboracle) parseSql(sql *string) *string { +func (db *dbOracle) parseSql(sql string) string { //下面的正则表达式匹配出SELECT和INSERT的关键字后分别做不同的处理,如有LIMIT则将LIMIT的关键字也匹配出 patten := `^\s*(?i)(SELECT)|(INSERT)|(LIMIT\s*(\d+)\s*,\s*(\d+))` - if gregex.IsMatchString(patten, *sql) == false { + if gregex.IsMatchString(patten, sql) == false { fmt.Println("not matched..") return sql } - res, err := gregex.MatchAllString(patten, *sql) + res, err := gregex.MatchAllString(patten, sql) if err != nil { fmt.Println("MatchString error.", err) - return nil + return "" } index := 0 @@ -94,11 +87,11 @@ func (db *dboracle) parseSql(sql *string) *string { } //取limit前面的字符串 - if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql) == false { + if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) == false { break } - queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql) + queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) if len(queryExpr) != 4 || strings.EqualFold(queryExpr[1], "SELECT") == false || strings.EqualFold(queryExpr[3], "LIMIT") == false{ break } @@ -118,10 +111,10 @@ func (db *dboracle) parseSql(sql *string) *string { } //也可以使用between,据说这种写法的性能会比between好点,里层SQL中的ROWNUM_ >= limit可以缩小查询后的数据集规模 - *sql = fmt.Sprintf("SELECT * FROM (SELECT GFORM.*, ROWNUM ROWNUM_ FROM (%s %s) GFORM WHERE ROWNUM <= %d) WHERE ROWNUM_ >= %d", queryExpr[1], queryExpr[2], limit, first) + sql = fmt.Sprintf("SELECT * FROM (SELECT GFORM.*, ROWNUM ROWNUM_ FROM (%s %s) GFORM WHERE ROWNUM <= %d) WHERE ROWNUM_ >= %d", queryExpr[1], queryExpr[2], limit, first) case "INSERT": //获取VALUE的值,匹配所有带括号的值,会将INSERT INTO后的值匹配到,所以下面的判断语句会判断数组长度是否小于3 - valueExpr, err := gregex.MatchAllString(`(\s*\(([^\(\)]*)\))`, *sql) + valueExpr, err := gregex.MatchAllString(`(\s*\(([^\(\)]*)\))`, sql) if err != nil { return sql } @@ -132,17 +125,17 @@ func (db *dboracle) parseSql(sql *string) *string { } //获取INTO后面的值 - tableExpr, err := gregex.MatchString(`(?i)\s*(INTO\s+\w+\(([^\(\)]*)\))`, *sql) + tableExpr, err := gregex.MatchString(`(?i)\s*(INTO\s+\w+\(([^\(\)]*)\))`, sql) if err != nil { return sql } tableExpr[0] = strings.TrimSpace(tableExpr[0]) - *sql = "INSERT ALL" + sql = "INSERT ALL" for i := 1; i < len(valueExpr); i++ { - *sql += fmt.Sprintf(" %s VALUES%s", tableExpr[0], strings.TrimSpace(valueExpr[i][0])) + sql += fmt.Sprintf(" %s VALUES%s", tableExpr[0], strings.TrimSpace(valueExpr[i][0])) } - *sql += " SELECT 1 FROM DUAL" + sql += " SELECT 1 FROM DUAL" default: } diff --git a/g/database/gdb/gdb_pgsql.go b/g/database/gdb/gdb_pgsql.go index a16a1efe4..13fc77cad 100644 --- a/g/database/gdb/gdb_pgsql.go +++ b/g/database/gdb/gdb_pgsql.go @@ -18,22 +18,18 @@ import ( // _ "gitee.com/johng/gf/third/github.com/lib/pq" // @todo 需要完善replace和save的操作覆盖 -// PostgreSQL接口对象 -var linkPgsql = &dbpgsql{} - - // 数据库链接对象 -type dbpgsql struct { - Db +type dbPgsql struct { + *dbBase } // 创建SQL操作对象,内部采用了lazy link处理 -func (db *dbpgsql) Open (c *ConfigNode) (*sql.DB, error) { +func (db *dbPgsql) Open (config *ConfigNode) (*sql.DB, error) { var source string - if c.Linkinfo != "" { - source = c.Linkinfo + if config.Linkinfo != "" { + source = config.Linkinfo } else { - source = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s", c.User, c.Pass, c.Host, c.Port, c.Name) + source = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s", config.User, config.Pass, config.Host, config.Port, config.Name) } if db, err := sql.Open("postgres", source); err == nil { return db, nil @@ -42,23 +38,18 @@ func (db *dbpgsql) Open (c *ConfigNode) (*sql.DB, error) { } } -// 获得关键字操作符 - 左 -func (db *dbpgsql) getQuoteCharLeft () string { - return "\"" -} - -// 获得关键字操作符 - 右 -func (db *dbpgsql) getQuoteCharRight () string { - return "\"" +// 获得关键字操作符 +func (db *dbPgsql) getChars () (charLeft string, charRight string) { + return "\"", "\"" } // 在执行sql之前对sql进行进一步处理 -func (db *dbpgsql) handleSqlBeforeExec(q *string) *string { +func (db *dbPgsql) handleSqlBeforeExec(query string) string { reg := regexp.MustCompile("\\?") index := 0 - str := reg.ReplaceAllStringFunc(*q, func (s string) string { + str := reg.ReplaceAllStringFunc(query, func (s string) string { index ++ return fmt.Sprintf("$%d", index) }) - return &str + return str } \ No newline at end of file diff --git a/g/database/gdb/gdb_sqlite.go b/g/database/gdb/gdb_sqlite.go index 0a92972bb..70029bdc8 100644 --- a/g/database/gdb/gdb_sqlite.go +++ b/g/database/gdb/gdb_sqlite.go @@ -16,20 +16,18 @@ import ( // Sqlite接口对象 // @author wxkj -var linkSqlite = &dbsqlite{} - // 数据库链接对象 -type dbsqlite struct { - Db +type dbSqlite struct { + *dbBase } -func (db *dbsqlite) Open(c *ConfigNode) (*sql.DB, error) { +func (db *dbSqlite) Open(config *ConfigNode) (*sql.DB, error) { var source string - if c.Linkinfo != "" { - source = c.Linkinfo + if config.Linkinfo != "" { + source = config.Linkinfo } else { - source = c.Name + source = config.Name } if db, err := sql.Open("sqlite3", source); err == nil { return db, nil @@ -38,20 +36,14 @@ func (db *dbsqlite) Open(c *ConfigNode) (*sql.DB, error) { } } -// 获得关键字操作符 - 左 -func (db *dbsqlite) getQuoteCharLeft() string { - return "`" -} - -// 获得关键字操作符 - 右 -func (db *dbsqlite) getQuoteCharRight() string { - return "`" +// 获得关键字操作符 +func (db *dbSqlite) getChars () (charLeft string, charRight string) { + return "`", "`" } // 在执行sql之前对sql进行进一步处理 // @todo 需要增加对Save方法的支持,可使用正则来实现替换, // @todo 将ON DUPLICATE KEY UPDATE触发器修改为两条SQL语句(INSERT OR IGNORE & UPDATE) -func (db *dbsqlite) handleSqlBeforeExec(q *string) *string { - - return q +func (db *dbSqlite) handleSqlBeforeExec(query string) string { + return query } \ No newline at end of file diff --git a/g/database/gdb/gdb_transaction.go b/g/database/gdb/gdb_transaction.go index 692b3bb5b..b513189db 100644 --- a/g/database/gdb/gdb_transaction.go +++ b/g/database/gdb/gdb_transaction.go @@ -7,128 +7,55 @@ package gdb import ( - "fmt" - "errors" - "strings" - "reflect" "database/sql" - "gitee.com/johng/gf/g/os/gtime" - "gitee.com/johng/gf/g/util/gconv" + "gitee.com/johng/gf/g/util/gregex" _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql" - "gitee.com/johng/gf/g/container/gvar" ) // 数据库事务对象 -type Tx struct { - db *Db +type TX struct { + db DB tx *sql.Tx master *sql.DB } // 事务操作,提交 -func (tx *Tx) Commit() error { +func (tx *TX) Commit() error { return tx.tx.Commit() } // 事务操作,回滚 -func (tx *Tx) Rollback() error { +func (tx *TX) Rollback() error { return tx.tx.Rollback() } // (事务)数据库sql查询操作,主要执行查询 -func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { - var err error - var rows *sql.Rows - p := tx.db.link.handleSqlBeforeExec(&query) - if tx.db.debug.Val() { - militime1 := gtime.Millisecond() - rows, err = tx.tx.Query(*p, args ...) - militime2 := gtime.Millisecond() - s := &Sql{ - Sql : *p, - Args : args, - Error : err, - Start : militime1, - End : militime2, - Func : "TX:Query", - } - tx.db.sqls.Put(s) - tx.db.printSql(s) - } else { - rows, err = tx.tx.Query(*p, args ...) - } - if err == nil { - return rows, nil - } else { - err = tx.db.formatError(err, p, args...) - } - return nil, err +func (tx *TX) Query(query string, args ...interface{}) (rows *sql.Rows, err error) { + return tx.db.doQuery(tx.tx, query, args...) } // (事务)执行一条sql,并返回执行情况,主要用于非查询操作 -func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { - var err error - var result sql.Result - p := tx.db.link.handleSqlBeforeExec(&query) - if tx.db.debug.Val() { - militime1 := gtime.Millisecond() - result, err = tx.tx.Exec(*p, args ...) - militime2 := gtime.Millisecond() - s := &Sql{ - Sql : *p, - Args : args, - Error : err, - Start : militime1, - End : militime2, - Func : "TX:Exec", - } - tx.db.sqls.Put(s) - tx.db.printSql(s) - } else { - result, err = tx.tx.Exec(*p, args ...) - } - return result, tx.db.formatError(err, p, args...) +func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) { + return tx.db.doExec(tx.tx, query, args...) +} + +// sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作 +func (tx *TX) Prepare(query string) (*sql.Stmt, error) { + return tx.db.doPrepare(tx.tx, query) } // 数据库查询,获取查询结果集,以列表结构返回 -func (tx *Tx) GetAll(query string, args ...interface{}) (Result, error) { - // 执行sql +func (tx *TX) GetAll(query string, args ...interface{}) (Result, error) { rows, err := tx.Query(query, args ...) if err != nil || rows == nil { return nil, err } - // 列名称列表 - columns, err := rows.Columns() - if err != nil { - return nil, err - } - // 返回结构组装 - values := make([]sql.RawBytes, len(columns)) - scanArgs := make([]interface{}, len(values)) - records := make(Result, 0) - for i := range values { - scanArgs[i] = &values[i] - } - for rows.Next() { - err = rows.Scan(scanArgs...) - if err != nil { - return records, err - } - row := make(Record) - // 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址 - for i, col := range values { - v := make([]byte, len(col)) - copy(v, col) - row[columns[i]] = gvar.New(v, false) - } - //fmt.Printf("%p\n", row["typeid"]) - records = append(records, row) - } - return records, nil + defer rows.Close() + return rowsToResult(rows) } // 数据库查询,获取查询结果记录,以关联数组结构返回 -func (tx *Tx) GetOne(query string, args ...interface{}) (Record, error) { +func (tx *TX) GetOne(query string, args ...interface{}) (Record, error) { list, err := tx.GetAll(query, args ...) if err != nil { return nil, err @@ -140,7 +67,7 @@ func (tx *Tx) GetOne(query string, args ...interface{}) (Record, error) { } // 数据库查询,获取查询结果记录,自动映射数据到给定的struct对象中 -func (tx *Tx) GetStruct(obj interface{}, query string, args ...interface{}) error { +func (tx *TX) GetStruct(obj interface{}, query string, args ...interface{}) error { one, err := tx.GetOne(query, args...) if err != nil { return err @@ -148,9 +75,8 @@ func (tx *Tx) GetStruct(obj interface{}, query string, args ...interface{}) erro return one.ToStruct(obj) } - // 数据库查询,获取查询字段值 -func (tx *Tx) GetValue(query string, args ...interface{}) (Value, error) { +func (tx *TX) GetValue(query string, args ...interface{}) (Value, error) { one, err := tx.GetOne(query, args ...) if err != nil { return nil, err @@ -162,186 +88,55 @@ func (tx *Tx) GetValue(query string, args ...interface{}) (Value, error) { } // 数据库查询,获取查询数量 -func (tx *Tx) GetCount(query string, args ...interface{}) (int, error) { - val, err := tx.GetValue(query, args ...) +func (tx *TX) GetCount(query string, args ...interface{}) (int, error) { + if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) { + query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query) + } + value, err := tx.GetValue(query, args ...) if err != nil { return 0, err } - return gconv.Int(val), nil -} - -// 数据表查询,其中tables可以是多个联表查询语句,这种查询方式较复杂,建议使用链式操作 -func (tx *Tx) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) { - s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables) - if condition != nil { - s += fmt.Sprintf("WHERE %s ", tx.db.formatCondition(condition)) - } - if len(groupBy) > 0 { - s += fmt.Sprintf("GROUP BY %s ", groupBy) - } - if len(orderBy) > 0 { - s += fmt.Sprintf("ORDER BY %s ", orderBy) - } - if limit > 0 { - s += fmt.Sprintf("LIMIT %d,%d ", first, limit) - } - return tx.GetAll(s, args ... ) -} - -// sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作 -func (tx *Tx) Prepare(query string) (*sql.Stmt, error) { - return tx.tx.Prepare(query) -} - -// insert、replace, save, ignore操作 -// 0: insert: 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 -// 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 -// 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 -// 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做 -func (tx *Tx) insert(table string, data Map, option uint8) (sql.Result, error) { - var keys []string - var values []string - var params []interface{} - for k, v := range data { - keys = append(keys, tx.db.charl + k + tx.db.charr) - values = append(values, "?") - params = append(params, v) - } - operation := tx.db.getInsertOperationByOption(option) - updatestr := "" - if option == OPTION_SAVE { - var updates []string - for k, _ := range data { - updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s)", tx.db.charl, k, tx.db.charr, k)) - } - updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ",")) - } - return tx.Exec( - fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s", - operation, table, strings.Join(keys, ","), - strings.Join(values, ","), - updatestr), - params... - ) + return value.Int(), nil } // CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 -func (tx *Tx) Insert(table string, data Map) (sql.Result, error) { - return tx.insert(table, data, OPTION_INSERT) +func (tx *TX) Insert(table string, data Map) (sql.Result, error) { + return tx.db.doInsert(tx.tx, table, data, OPTION_INSERT) } // CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 -func (tx *Tx) Replace(table string, data Map) (sql.Result, error) { - return tx.insert(table, data, OPTION_REPLACE) +func (tx *TX) Replace(table string, data Map) (sql.Result, error) { + return tx.db.doInsert(tx.tx, table, data, OPTION_REPLACE) } // CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 -func (tx *Tx) Save(table string, data Map) (sql.Result, error) { - return tx.insert(table, data, OPTION_SAVE) -} - -// 批量写入数据 -func (tx *Tx) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) { - var keys []string - var values []string - var bvalues []string - var params []interface{} - var result sql.Result - var size = len(list) - // 判断长度 - if size < 1 { - return result, errors.New("empty data list") - } - // 首先获取字段名称及记录长度 - for k, _ := range list[0] { - keys = append(keys, k) - values = append(values, "?") - } - keyStr := tx.db.charl + strings.Join(keys, tx.db.charl + "," + tx.db.charr) + tx.db.charr - valueHolderStr := "(" + strings.Join(values, ",") + ")" - // 操作判断 - operation := tx.db.getInsertOperationByOption(option) - updatestr := "" - if option == OPTION_SAVE { - var updates []string - for _, k := range keys { - updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s)", tx.db.charl, k, tx.db.charr, k)) - } - updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ",")) - } - // 构造批量写入数据格式(注意map的遍历是无序的) - for i := 0; i < size; i++ { - for _, k := range keys { - params = append(params, list[i][k]) - } - bvalues = append(bvalues, valueHolderStr) - if len(bvalues) == batch { - r, err := tx.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s", - operation, table, keyStr, strings.Join(bvalues, ","), - updatestr), - params...) - if err != nil { - return result, err - } - result = r - params = params[:0] - bvalues = bvalues[:0] - } - } - // 处理最后不构成指定批量的数据 - if len(bvalues) > 0 { - r, err := tx.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s", - operation, table, keyStr, strings.Join(bvalues, ","), - updatestr), - params...) - if err != nil { - return result, err - } - result = r - } - return result, nil +func (tx *TX) Save(table string, data Map) (sql.Result, error) { + return tx.db.doInsert(tx.tx, table, data, OPTION_SAVE) } // CURD操作:批量数据指定批次量写入 -func (tx *Tx) BatchInsert(table string, list List, batch int) (sql.Result, error) { - return tx.batchInsert(table, list, batch, OPTION_INSERT) +func (tx *TX) BatchInsert(table string, list List, batch int) (sql.Result, error) { + return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_INSERT) } // CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 -func (tx *Tx) BatchReplace(table string, list List, batch int) (sql.Result, error) { - return tx.batchInsert(table, list, batch, OPTION_REPLACE) +func (tx *TX) BatchReplace(table string, list List, batch int) (sql.Result, error) { + return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_REPLACE) } // CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 -func (tx *Tx) BatchSave(table string, list List, batch int) (sql.Result, error) { - return tx.batchInsert(table, list, batch, OPTION_SAVE) +func (tx *TX) BatchSave(table string, list List, batch int) (sql.Result, error) { + return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_SAVE) } // CURD操作:数据更新,统一采用sql预处理 // data参数支持字符串或者关联数组类型,内部会自行做判断处理 -func (tx *Tx) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { - var params []interface{} - var updates string - refValue := reflect.ValueOf(data) - if refValue.Kind() == reflect.Map { - var fields []string - keys := refValue.MapKeys() - for _, k := range keys { - fields = append(fields, fmt.Sprintf("%s%s%s=?", tx.db.charl, k, tx.db.charr)) - params = append(params, gconv.String(refValue.MapIndex(k).Interface())) - updates = strings.Join(fields, ",") - } - } else { - updates = gconv.String(data) - } - for _, v := range args { - params = append(params, gconv.String(v)) - } - return tx.Exec(fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, tx.db.formatCondition(condition)), params...) +func (tx *TX) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { + return tx.db.doUpdate(tx.tx, table, data, condition, args ...) } // CURD操作:删除数据 -func (tx *Tx) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) { - return tx.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", table, tx.db.formatCondition(condition)), args...) +func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) { + return tx.db.doDelete(tx.tx, table, condition, args ...) } diff --git a/g/database/gdb/gdb_unit_1_test.go b/g/database/gdb/gdb_unit_1_test.go new file mode 100644 index 000000000..1755bf63b --- /dev/null +++ b/g/database/gdb/gdb_unit_1_test.go @@ -0,0 +1,215 @@ +package gdb_test + +import ( + "gitee.com/johng/gf/g" + "gitee.com/johng/gf/g/database/gdb" + "gitee.com/johng/gf/g/os/gtime" + "gitee.com/johng/gf/g/util/gtest" + "testing" +) + +var ( + // 数据库对象/接口 + db gdb.DB +) + +// 初始化连接参数。 +// 测试前需要修改连接参数。 +func init() { + gdb.AddDefaultConfigNode(gdb.ConfigNode{ + Host: "127.0.0.1", + Port: "3306", + User: "root", + Pass: "12345678", + Name: "test", + Type: "mysql", + Role: "master", + Charset: "utf8", + Priority: 1, + }) + if r, err := gdb.New(); err != nil { + gtest.Fatal(err) + } else { + db = r + } + // 准备测试数据结构 + if _, err := db.Exec("DROP TABLE `user`"); err != nil { + gtest.Fatal(err) + } + if _, err := db.Exec(` + CREATE TABLE user ( + id int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT '用户ID', + passport varchar(45) NOT NULL COMMENT '账号', + password char(32) NOT NULL COMMENT '密码', + nickname varchar(45) NOT NULL COMMENT '昵称', + create_time timestamp NOT NULL COMMENT '创建时间/注册时间', + PRIMARY KEY (id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `); err != nil { + gtest.Fatal(err) + } +} + +func TestDbBase_Query(t *testing.T) { + if _, err := db.Query("SELECT ?", 1); err != nil { + gtest.Fatal(err) + } + if _, err := db.Query("ERROR"); err == nil { + gtest.Fatal("FAIL") + } +} + +func TestDbBase_Exec(t *testing.T) { + if _, err := db.Exec("SELECT ?", 1); err != nil { + gtest.Fatal(err) + } + if _, err := db.Exec("ERROR"); err == nil { + gtest.Fatal("FAIL") + } +} + +func TestDbBase_Prepare(t *testing.T) { + st, err := db.Prepare("SELECT 100") + if err != nil { + gtest.Fatal(err) + } + rows, err := st.Query() + if err != nil { + gtest.Fatal(err) + } + array, err := rows.Columns() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(array[0], "100") + if err := rows.Close(); err != nil { + gtest.Fatal(err) + } +} + +func TestDbBase_Insert(t *testing.T) { + if _, err := db.Insert("user", g.Map{ + "id" : 1, + "passport" : "t1", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T1", + "create_time" : gtime.Now().String(), + }); err != nil { + gtest.Fatal(err) + } +} + +func TestDbBase_BatchInsert(t *testing.T) { + if _, err := db.BatchInsert("user", g.List { + { + "id" : 2, + "passport" : "t2", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T2", + "create_time" : gtime.Now().String(), + }, + { + "id" : 3, + "passport" : "t3", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T3", + "create_time" : gtime.Now().String(), + }, + }, 10); err != nil { + gtest.Fatal(err) + } +} + +func TestDbBase_Save(t *testing.T) { + if _, err := db.Save("user", g.Map{ + "id" : 1, + "passport" : "t1", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T11", + "create_time" : gtime.Now().String(), + }); err != nil { + gtest.Fatal(err) + } +} + +func TestDbBase_Replace(t *testing.T) { + if _, err := db.Save("user", g.Map{ + "id" : 1, + "passport" : "t1", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T111", + "create_time" : gtime.Now().String(), + }); err != nil { + gtest.Fatal(err) + } +} + +func TestDbBase_Update(t *testing.T) { + if result, err := db.Update("user", "create_time='2010-10-10 00:00:01'", "id=3"); err != nil { + gtest.Fatal(err) + } else { + n, _ := result.RowsAffected() + gtest.Assert(n, 1) + } +} + +func TestDbBase_GetAll(t *testing.T) { + if result, err := db.GetAll("SELECT * FROM user WHERE id=?", 1); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(len(result), 1) + } +} + +func TestDbBase_GetOne(t *testing.T) { + if record, err := db.GetOne("SELECT * FROM user WHERE passport=?", "t1"); err != nil { + gtest.Fatal(err) + } else { + if record == nil { + gtest.Fatal("FAIL") + } + gtest.Assert(record["nickname"].String(), "T111") + } +} + +func TestDbBase_GetValue(t *testing.T) { + if value, err := db.GetValue("SELECT id FROM user WHERE passport=?", "t3"); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(value.Int(), 3) + } +} + +func TestDbBase_GetCount(t *testing.T) { + if count, err := db.GetCount("SELECT * FROM user"); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(count, 3) + } +} + +func TestDbBase_GetStruct(t *testing.T) { + type User struct { + Id int + Passport string + Password string + NickName string + CreateTime gtime.Time + } + user := new(User) + if err := db.GetStruct(user, "SELECT * FROM user WHERE id=?", 3); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(user.CreateTime.String(), "2010-10-10 00:00:01") + } +} + +func TestDbBase_Delete(t *testing.T) { + if result, err := db.Delete("user", nil); err != nil { + gtest.Fatal(err) + } else { + n, _ := result.RowsAffected() + gtest.Assert(n, 3) + } +} + diff --git a/g/database/gdb/gdb_unit_2_test.go b/g/database/gdb/gdb_unit_2_test.go new file mode 100644 index 000000000..469db382f --- /dev/null +++ b/g/database/gdb/gdb_unit_2_test.go @@ -0,0 +1,180 @@ +package gdb_test + +import ( + "gitee.com/johng/gf/g" + "gitee.com/johng/gf/g/os/gtime" + "gitee.com/johng/gf/g/util/gtest" + "testing" +) + +func TestModel_Insert(t *testing.T) { + result, err := db.Table("user").Filter().Data(g.Map{ + "id" : 1, + "uid" : 1, + "passport" : "t1", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T1", + "create_time" : gtime.Now().String(), + }).Insert() + if err != nil { + gtest.Fatal(err) + } + n, _ := result.LastInsertId() + gtest.Assert(n, 1) +} + +func TestModel_Batch(t *testing.T) { + result, err := db.Table("user").Filter().Data(g.List{ + { + "id" : 2, + "uid" : 2, + "passport" : "t2", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T2", + "create_time" : gtime.Now().String(), + }, + { + "id" : 3, + "uid" : 3, + "passport" : "t3", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T3", + "create_time" : gtime.Now().String(), + }, + }).Batch(10).Insert() + if err != nil { + gtest.Fatal(err) + } + n, _ := result.RowsAffected() + gtest.Assert(n, 2) +} + +func TestModel_Replace(t *testing.T) { + result, err := db.Table("user").Data(g.Map{ + "id" : 1, + "passport" : "t11", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T11", + "create_time" : gtime.Now().String(), + }).Replace() + if err != nil { + gtest.Fatal(err) + } + n, _ := result.RowsAffected() + gtest.Assert(n, 2) +} + +func TestModel_Save(t *testing.T) { + result, err := db.Table("user").Data(g.Map{ + "id" : 1, + "passport" : "t111", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T111", + "create_time" : gtime.Now().String(), + }).Save() + if err != nil { + gtest.Fatal(err) + } + n, _ := result.RowsAffected() + gtest.Assert(n, 2) +} + +func TestModel_Update(t *testing.T) { + result, err := db.Table("user").Data("passport", "t22").Where("passport=?", "t2").Update() + if err != nil { + gtest.Fatal(err) + } + n, _ := result.RowsAffected() + gtest.Assert(n, 1) +} + +func TestModel_All(t *testing.T) { + result, err := db.Table("user").All() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(len(result), 3) +} + +func TestModel_One(t *testing.T) { + record, err := db.Table("user").Where("id", 1).One() + if err != nil { + gtest.Fatal(err) + } + if record == nil { + gtest.Fatal("FAIL") + } + gtest.Assert(record["nickname"].String(), "T111") +} + +func TestModel_Value(t *testing.T) { + value, err := db.Table("user").Fields("nickname").Where("id", 1).Value() + if err != nil { + gtest.Fatal(err) + } + if value == nil { + gtest.Fatal("FAIL") + } + gtest.Assert(value.String(), "T111") +} + +func TestModel_Count(t *testing.T) { + count, err := db.Table("user").Count() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(count, 3) +} + +func TestModel_Select(t *testing.T) { + result, err := db.Table("user").Select() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(len(result), 3) +} + +func TestModel_Struct(t *testing.T) { + type User struct { + Id int + Passport string + Password string + NickName string + CreateTime gtime.Time + } + user := new(User) + err := db.Table("user").Where("id=1").Struct(user) + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(user.NickName, "T111") +} + +func TestModel_OrderBy(t *testing.T) { + result, err := db.Table("user").OrderBy("id DESC").Select() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(len(result), 3) + gtest.Assert(result[0]["nickname"].String(), "T3") +} + +func TestModel_GroupBy(t *testing.T) { + result, err := db.Table("user").GroupBy("id").Select() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(len(result), 3) + gtest.Assert(result[0]["nickname"].String(), "T111") +} + +func TestModel_Delete(t *testing.T) { + result, err := db.Table("user").Delete() + if err != nil { + gtest.Fatal(err) + } + n, _ := result.RowsAffected() + gtest.Assert(n, 3) +} + + diff --git a/g/database/gdb/gdb_unit_3_test.go b/g/database/gdb/gdb_unit_3_test.go new file mode 100644 index 000000000..3180c906d --- /dev/null +++ b/g/database/gdb/gdb_unit_3_test.go @@ -0,0 +1,372 @@ +package gdb_test + +import ( + "gitee.com/johng/gf/g" + "gitee.com/johng/gf/g/os/gtime" + "gitee.com/johng/gf/g/util/gtest" + "testing" +) + +func TestTX_Query(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if rows, err := tx.Query("SELECT ?", 1); err != nil { + gtest.Fatal(err) + } else { + rows.Close() + } + if _, err := tx.Query("ERROR"); err == nil { + gtest.Fatal("FAIL") + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_Exec(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.Exec("SELECT ?", 1); err != nil { + gtest.Fatal(err) + } + if _, err := tx.Exec("ERROR"); err == nil { + gtest.Fatal("FAIL") + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_Commit(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_Rollback(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if err := tx.Rollback(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_Prepare(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + st, err := tx.Prepare("SELECT 100") + if err != nil { + gtest.Fatal(err) + } + rows, err := st.Query() + if err != nil { + gtest.Fatal(err) + } + array, err := rows.Columns() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(array[0], "100") + if err := rows.Close(); err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_Insert(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.Insert("user", g.Map { + "id" : 1, + "passport" : "t1", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T1", + "create_time" : gtime.Now().String(), + }); err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } + if n, err := db.Table("user").Count(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(n, 1) + } +} + +func TestTX_BatchInsert(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.BatchInsert("user", g.List { + { + "id" : 2, + "passport" : "t", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T2", + "create_time" : gtime.Now().String(), + }, + { + "id" : 3, + "passport" : "t3", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T3", + "create_time" : gtime.Now().String(), + }, + }, 10); err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } + if n, err := db.Table("user").Count(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(n, 3) + } +} + +func TestTX_BatchReplace(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.BatchReplace("user", g.List { + { + "id" : 2, + "passport" : "t2", + "password" : "p2", + "nickname" : "T2", + "create_time" : gtime.Now().String(), + }, + { + "id" : 4, + "passport" : "t4", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T4", + "create_time" : gtime.Now().String(), + }, + }, 10); err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } + // 数据数量 + if n, err := db.Table("user").Count(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(n, 4) + } + // 检查replace后的数值 + if value, err := db.Table("user").Fields("password").Where("id", 2).Value(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(value.String(), "p2") + } +} + +func TestTX_BatchSave(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.BatchSave("user", g.List { + { + "id" : 4, + "passport" : "t4", + "password" : "p4", + "nickname" : "T4", + "create_time" : gtime.Now().String(), + }, + }, 10); err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } + // 数据数量 + if n, err := db.Table("user").Count(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(n, 4) + } + // 检查replace后的数值 + if value, err := db.Table("user").Fields("password").Where("id", 4).Value(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(value.String(), "p4") + } +} + +func TestTX_Replace(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.Replace("user", g.Map { + "id" : 1, + "passport" : "t11", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T11", + "create_time" : gtime.Now().String(), + }); err != nil { + gtest.Fatal(err) + } + if err := tx.Rollback(); err != nil { + gtest.Fatal(err) + } + if value, err := db.Table("user").Fields("nickname").Where("id", 1).Value(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(value.String(), "T1") + } +} + +func TestTX_Save(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.Save("user", g.Map { + "id" : 1, + "passport" : "t11", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T11", + "create_time" : gtime.Now().String(), + }); err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } + if value, err := db.Table("user").Fields("nickname").Where("id", 1).Value(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(value.String(), "T11") + } +} + +func TestTX_GetAll(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if result, err := tx.GetAll("SELECT * FROM user WHERE id=?", 1); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(len(result), 1) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_GetOne(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if record, err := tx.GetOne("SELECT * FROM user WHERE passport=?", "t2"); err != nil { + gtest.Fatal(err) + } else { + if record == nil { + gtest.Fatal("FAIL") + } + gtest.Assert(record["nickname"].String(), "T2") + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_GetValue(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if value, err := tx.GetValue("SELECT id FROM user WHERE passport=?", "t3"); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(value.Int(), 3) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_GetCount(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if count, err := tx.GetCount("SELECT * FROM user"); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(count, 4) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_GetStruct(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + type User struct { + Id int + Passport string + Password string + NickName string + CreateTime gtime.Time + } + user := new(User) + if err := tx.GetStruct(user, "SELECT * FROM user WHERE id=?", 1); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(user.NickName, "T11") + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_Delete(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.Delete("user", nil); err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } + if n, err := db.Table("user").Count(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(n, 0) + } +} + + diff --git a/g/frame/gins/gins.go b/g/frame/gins/gins.go index 0c8ba867c..66598e940 100644 --- a/g/frame/gins/gins.go +++ b/g/frame/gins/gins.go @@ -120,7 +120,7 @@ func Config(file...string) *gcfg.Config { } // 数据库操作对象,使用了连接池 -func Database(name...string) *gdb.Db { +func Database(name...string) gdb.DB { config := Config() group := gdb.DEFAULT_GROUP_NAME if len(name) > 0 { @@ -195,7 +195,7 @@ func Database(name...string) *gdb.Db { return nil }) if db != nil { - return db.(*gdb.Db) + return db.(gdb.DB) } return nil } diff --git a/g/g.go b/g/g.go index 94a4d5601..e73b91b6e 100644 --- a/g/g.go +++ b/g/g.go @@ -10,14 +10,27 @@ package g import "gitee.com/johng/gf/g/container/gvar" // 框架动态变量,可以用该类型替代interface{}类型 -type Var = gvar.Var +type Var = gvar.Var // 常用map数据结构(使用别名) -type Map = map[string]interface{} +type Map = map[string]interface{} +type MapStrStr = map[string]string +type MapStrInt = map[string]int +type MapIntStr = map[int]string +type MapIntInt = map[int]int // 常用list数据结构(使用别名) -type List = []Map +type List = []Map +type ListStrStr = []map[string]string +type ListStrInt = []map[string]int +type ListIntStr = []map[int]string +type ListIntInt = []map[int]int + // 常用slice数据结构(使用别名) -type Slice = []interface{} -type Array = Slice +type Slice = []interface{} +type SliceStr = []string +type SliceInt = []int +type Array = Slice +type ArrayStr = SliceStr +type ArrayInt = SliceInt diff --git a/g/g_object.go b/g/g_object.go index ff298fdb4..012cf1a16 100644 --- a/g/g_object.go +++ b/g/g_object.go @@ -44,12 +44,12 @@ func Config(file...string) *gcfg.Config { } // 数据库操作对象,使用了连接池 -func Database(name...string) *gdb.Db { +func Database(name...string) gdb.DB { return gins.Database(name...) } // (别名)Database -func DB(name...string) *gdb.Db { +func DB(name...string) gdb.DB { return gins.Database(name...) } diff --git a/g/net/ghttp/ghttp_server.go b/g/net/ghttp/ghttp_server.go index d79125abe..40f340382 100644 --- a/g/net/ghttp/ghttp_server.go +++ b/g/net/ghttp/ghttp_server.go @@ -15,6 +15,7 @@ import ( "gitee.com/johng/gf/g/container/gtype" "gitee.com/johng/gf/g/os/gcache" "gitee.com/johng/gf/g/os/genv" + "gitee.com/johng/gf/g/os/gfile" "gitee.com/johng/gf/g/os/glog" "gitee.com/johng/gf/g/os/gproc" "gitee.com/johng/gf/g/os/gtime" @@ -254,6 +255,10 @@ func (s *Server) Start() error { } }) } + // 是否处于开发环境 + if gfile.MainPkgPath() != "" { + glog.Backtrace(false, 0).Notice("GF notices that you're in develop environment, so error logs are auto enabled to stdout.") + } // 打印展示路由表 s.DumpRoutesMap() diff --git a/g/net/ghttp/ghttp_server_log.go b/g/net/ghttp/ghttp_server_log.go index a4cd12f0c..62ace6fda 100644 --- a/g/net/ghttp/ghttp_server_log.go +++ b/g/net/ghttp/ghttp_server_log.go @@ -9,6 +9,7 @@ package ghttp import ( "fmt" + "gitee.com/johng/gf/g/os/gfile" "net/http" ) @@ -36,7 +37,7 @@ func (s *Server) handleErrorLog(error interface{}, r *Request) { r.Response.WriteStatus(http.StatusInternalServerError) // 错误输出默认是开启的 - if !s.IsErrorLogEnabled() { + if !s.IsErrorLogEnabled() && gfile.MainPkgPath() == "" { return } @@ -56,5 +57,9 @@ func (s *Server) handleErrorLog(error interface{}, r *Request) { s.logger.Cat("error").Backtrace(true, 2).StdPrint(true).Error(content) } else { s.logger.Cat("error").Backtrace(true, 2).Error(content) + // 开发环境下(MainPkgPath)自动输出错误信息到标准输出 + if gfile.MainPkgPath() != "" { + s.logger.Cat("error").Backtrace(true, 2).StdPrint(true).Error(content) + } } } diff --git a/g/os/gfile/gfile.go b/g/os/gfile/gfile.go index 420c8f10a..23f82c3d3 100644 --- a/g/os/gfile/gfile.go +++ b/g/os/gfile/gfile.go @@ -370,7 +370,7 @@ func homeWindows() (string, error) { return home, nil } -// 获取入口函数文件所在目录(main包文件目录), +// 获取入口函数文件所在目录(main包文件目录), // **仅对源码开发环境有效(即仅对生成该可执行文件的系统下有效)** func MainPkgPath() string { path := mainPkgPath.Val() @@ -401,6 +401,7 @@ func MainPkgPath() string { if p == f { break } + // 会自动扫描源码,寻找main包 if paths, err := ScanDir(p, "*.go"); err == nil && len(paths) > 0 { for _, path := range paths { if gregex.IsMatchString(`package\s+main`, GetContents(path)) { diff --git a/g/os/glog/glog_logger.go b/g/os/glog/glog_logger.go index 92959f2b1..4b022cf48 100644 --- a/g/os/glog/glog_logger.go +++ b/g/os/glog/glog_logger.go @@ -31,7 +31,7 @@ type Logger struct { file *gtype.String // 日志文件名称格式 level *gtype.Int // 日志输出等级 btSkip *gtype.Int // 错误产生时的backtrace回调信息skip条数 - btEnabled *gtype.Bool // 是否当打印错误时同时开启backtrace打印 + btStatus *gtype.Int // 是否当打印错误时同时开启backtrace打印(默认-1,表示默认打印逻辑 - 错误才打印) printHeader *gtype.Bool // 是否不打印前缀信息(时间,级别等) alsoStdPrint *gtype.Bool // 控制台打印开关,当输出到文件/自定义输出时也同时打印到终端 } @@ -65,7 +65,7 @@ func New() *Logger { file : gtype.NewString(gDEFAULT_FILE_FORMAT), level : gtype.NewInt(defaultLevel.Val()), btSkip : gtype.NewInt(), - btEnabled : gtype.NewBool(true), + btStatus : gtype.NewInt(-1), printHeader : gtype.NewBool(true), alsoStdPrint : gtype.NewBool(true), } @@ -80,7 +80,7 @@ func (l *Logger) Clone() *Logger { file : l.file.Clone(), level : l.level.Clone(), btSkip : l.btSkip.Clone(), - btEnabled : l.btEnabled.Clone(), + btStatus : l.btStatus.Clone(), printHeader : l.printHeader.Clone(), alsoStdPrint : l.alsoStdPrint.Clone(), } @@ -106,7 +106,12 @@ func (l *Logger) SetDebug(debug bool) { } func (l *Logger) SetBacktrace(enabled bool) { - l.btEnabled.Set(enabled) + if enabled { + l.btStatus.Set(1) + } else { + l.btStatus.Set(0) + } + } // 设置BacktraceSkip @@ -214,27 +219,37 @@ func (l *Logger) doStdLockPrint(std io.Writer, s string) { // 核心打印数据方法(标准输出) func (l *Logger) stdPrint(s string) { + if l.btStatus.Val() == 1 { + s = l.appendBacktrace(s) + } l.print(os.Stdout, s) } // 核心打印数据方法(标准错误) func (l *Logger) errPrint(s string) { // 记录调用回溯信息 - if l.btEnabled.Val() { - tracestr := l.GetBacktrace() - if tracestr != "" { - backtrace := "Backtrace:" + ln + tracestr - if s[len(s) - 1] == byte('\n') { - s = s + backtrace + ln - } else { - s = s + ln + backtrace + ln - } - } + status := l.btStatus.Val() + if status == -1 || status == 1 { + s = l.appendBacktrace(s) } // 防止串日志情况,这里不使用stderr,而是使用stdout l.print(os.Stdout, s) } +// 输出内容中添加回溯信息 +func (l *Logger) appendBacktrace(s string) string { + trace := l.GetBacktrace() + if trace != "" { + backtrace := "Backtrace:" + ln + trace + if s[len(s) - 1] == byte('\n') { + s = s + backtrace + ln + } else { + s = s + ln + backtrace + ln + } + } + return s +} + // 直接打印回溯信息,参数skip表示调用端往上多少级开始回溯 func (l *Logger) PrintBacktrace(skip...int) { l.Println(l.GetBacktrace(skip...)) diff --git a/g/util/gconv/gconv_map.go b/g/util/gconv/gconv_map.go index 2bb3f573c..5f046ffc9 100644 --- a/g/util/gconv/gconv_map.go +++ b/g/util/gconv/gconv_map.go @@ -99,8 +99,11 @@ func Map(i interface{}, noTagCheck...bool) map[string]interface{} { rt := rv.Type() name := "" for i := 0; i < rv.NumField(); i++ { - if name = rt.Field(i).Tag.Get("json"); name == "" { - name = rt.Field(i).Name + // 检查json tag + if len(noTagCheck) == 0 || !noTagCheck[0] { + if name = rt.Field(i).Tag.Get("json"); name == "" { + name = rt.Field(i).Name + } } m[name] = rv.Field(i).Interface() } diff --git a/g/util/gconv/gconv_struct.go b/g/util/gconv/gconv_struct.go index 1261065f1..d0cd15dbf 100644 --- a/g/util/gconv/gconv_struct.go +++ b/g/util/gconv/gconv_struct.go @@ -157,12 +157,10 @@ func bindVarToStruct(elem reflect.Value, name string, value interface{}) (err er structFieldValue := elem.FieldByName(name) // 键名与对象属性匹配检测,map中如果有struct不存在的属性,那么不做处理,直接return if !structFieldValue.IsValid() { - //return errors.New(fmt.Sprintf(`invalid struct attribute of name "%s"`, name)) return nil } // CanSet的属性必须为公开属性(首字母大写) if !structFieldValue.CanSet() { - //return errors.New(fmt.Sprintf(`struct attribute of name "%s" cannot be set`, name)) return nil } // 必须将value转换为struct属性的数据类型,这里必须用到gconv包 @@ -181,12 +179,10 @@ func bindVarToStructByIndex(elem reflect.Value, index int, value interface{}) (e structFieldValue := elem.FieldByIndex([]int{index}) // 键名与对象属性匹配检测 if !structFieldValue.IsValid() { - //return errors.New(fmt.Sprintf("invalid struct attribute at index %d", index)) return nil } // CanSet的属性必须为公开属性(首字母大写) if !structFieldValue.CanSet() { - //return errors.New(fmt.Sprintf("struct attribute cannot be set at index %d", index)) return nil } // 必须将value转换为struct属性的数据类型,这里必须用到gconv包 diff --git a/g/util/gconv/gconv_time.go b/g/util/gconv/gconv_time.go index fefc07b83..42a13712b 100644 --- a/g/util/gconv/gconv_time.go +++ b/g/util/gconv/gconv_time.go @@ -14,21 +14,26 @@ import ( // 将变量i转换为time.Time类型 func Time(i interface{}, format...string) time.Time { - s := String(i) - // 优先使用用户输入日期格式进行转换 - if len(format) > 0 { - t, _ := gtime.StrToTimeFormat(s, format[0]) - return t.Time - } - if gstr.IsNumeric(s) { - return gtime.NewFromTimeStamp(Int64(s)).Time - } else { - t, _ := gtime.StrToTime(s) - return t.Time - } + return GTime(i, format...).Time } // 将变量i转换为time.Time类型 func TimeDuration(i interface{}) time.Duration { return time.Duration(Int64(i)) +} + +// 将变量i转换为time.Time类型 +func GTime(i interface{}, format...string) *gtime.Time { + s := String(i) + // 优先使用用户输入日期格式进行转换 + if len(format) > 0 { + t, _ := gtime.StrToTimeFormat(s, format[0]) + return t + } + if gstr.IsNumeric(s) { + return gtime.NewFromTimeStamp(Int64(s)) + } else { + t, _ := gtime.StrToTime(s) + return t + } } \ No newline at end of file diff --git a/g/util/gstr/gstr.go b/g/util/gstr/gstr.go index 8decd74ba..1758de12a 100644 --- a/g/util/gstr/gstr.go +++ b/g/util/gstr/gstr.go @@ -13,7 +13,7 @@ import ( "strings" ) -// 字符串替换 +// 字符串替换(大小写敏感) func Replace(origin, search, replace string, count...int) string { n := -1 if len(count) > 0 { @@ -22,7 +22,7 @@ func Replace(origin, search, replace string, count...int) string { return strings.Replace(origin, search, replace, n) } -// 使用map进行字符串替换 +// 使用map进行字符串替换(大小写敏感) func ReplaceByMap(origin string, replaces map[string]string) string { result := origin for k, v := range replaces { diff --git a/g/util/gtest/gtest.go b/g/util/gtest/gtest.go new file mode 100644 index 000000000..2b054afa1 --- /dev/null +++ b/g/util/gtest/gtest.go @@ -0,0 +1,30 @@ +// Copyright 2018 gf Author(https://gitee.com/johng/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://gitee.com/johng/gf. + +// Package gtest provides useful test utils. +// 测试模块. +package gtest + +import ( + "fmt" + "gitee.com/johng/gf/g/os/glog" + "gitee.com/johng/gf/g/util/gconv" + "os" +) + +// 断言判断 +func Assert(value, expect interface{}) { + if gconv.String(value) != gconv.String(expect) { + glog.Backtrace(true, 1).Printfln(`[ASSERT] VALUE: %v, EXPECT: %v`, value, expect) + os.Exit(1) + } +} + +// 提示错误并退出 +func Fatal(message...interface{}) { + glog.Backtrace(true, 1).Println(`[FATAL] `, fmt.Sprint(message...)) + os.Exit(1) +} \ No newline at end of file diff --git a/g/util/gutil/gutil.go b/g/util/gutil/gutil.go index ff315f1ed..5d1ed1288 100644 --- a/g/util/gutil/gutil.go +++ b/g/util/gutil/gutil.go @@ -4,7 +4,7 @@ // If a copy of the MIT was not distributed with this file, // You can obtain one at https://gitee.com/johng/gf. -// 其他工具包 +// 工具包 package gutil import ( diff --git a/g/util/gvalid/gvalid_check_map.go b/g/util/gvalid/gvalid_check_map.go index 3f26bbd4a..5ced2acc1 100644 --- a/g/util/gvalid/gvalid_check_map.go +++ b/g/util/gvalid/gvalid_check_map.go @@ -14,7 +14,12 @@ import ( // 检测键值对参数Map, // rules参数支持 []string / map[string]string 类型,前面一种类型支持返回校验结果顺序(具体格式参考struct tag),后一种不支持; // rules参数中得 map[string]string 是一个2维的关联数组,第一维键名为参数键名,第二维为带有错误的校验规则名称,值为错误信息。 -func CheckMap(params map[string]interface{}, rules interface{}, msgs...CustomMsg) *Error { +func CheckMap(params interface{}, rules interface{}, msgs...CustomMsg) *Error { + // 将参数转换为 map[string]interface{}类型 + data := gconv.Map(params) + if data == nil { + return newErrorStr("invalid_params", "invalid params type: convert to map[string]interface{} failed") + } // 真实校验规则数据结构 checkRules := make(map[string]string) // 真实自定义错误信息数据结构 @@ -73,11 +78,15 @@ func CheckMap(params map[string]interface{}, rules interface{}, msgs...CustomMsg value := (interface{})(nil) // 这里的rule变量为多条校验规则,不包含名字或者错误信息定义 for key, rule := range checkRules { + // 如果规则为空,那么不执行校验 + if len(rule) == 0 { + continue + } value = nil - if v, ok := params[key]; ok { + if v, ok := data[key]; ok { value = v } - if e := Check(value, rule, customMsgs[key], params); e != nil { + if e := Check(value, rule, customMsgs[key], data); e != nil { _, item := e.FirstItem() // 如果值为nil|"",并且不需要require*验证时,其他验证失效 if value == nil || gconv.String(value) == "" { diff --git a/g/util/gvalid/gvalid_error.go b/g/util/gvalid/gvalid_error.go index 8db62bba7..0f495bc64 100644 --- a/g/util/gvalid/gvalid_error.go +++ b/g/util/gvalid/gvalid_error.go @@ -22,7 +22,7 @@ type Error struct { type ErrorMap map[string]map[string]string -// 创建一个校验错误对象指针 +// 创建一个校验错误对象指针(校验错误) func newError(rules []string, errors map[string]map[string]string) *Error { return &Error { rules : rules, @@ -30,6 +30,18 @@ func newError(rules []string, errors map[string]map[string]string) *Error { } } +// 创建一个校验错误对象指针(内部错误) +func newErrorStr(key, err string) *Error { + return &Error { + rules : nil, + errors : map[string]map[string]string{ + "__gvalid__" : { + key: err, + }, + }, + } +} + // 获得规则与错误信息的map; 当校验结果为多条数据校验时,返回第一条错误map(此时类似FirstItem) func (e *Error) Map() map[string]string { _, m := e.FirstItem() diff --git a/geg/database/orm/mysql/gdb.go b/geg/database/orm/mysql/gdb.go index d951d1821..8c276459d 100644 --- a/geg/database/orm/mysql/gdb.go +++ b/geg/database/orm/mysql/gdb.go @@ -9,7 +9,7 @@ import ( // 本文件用于gf框架的mysql数据库操作示例,不作为单元测试使用 -var db *gdb.Db +var db gdb.DB // 初始化配置及创建数据库 func init () { @@ -17,7 +17,7 @@ func init () { Host : "127.0.0.1", Port : "3306", User : "root", - Pass : "8692651", + Pass : "12345678", Name : "test", Type : "mysql", Role : "master", diff --git a/geg/database/orm/mysql/gdb_pool.go b/geg/database/orm/mysql/gdb_pool.go index 8f4c7dabf..ea99415d1 100644 --- a/geg/database/orm/mysql/gdb_pool.go +++ b/geg/database/orm/mysql/gdb_pool.go @@ -1,28 +1,16 @@ package main import ( - "gitee.com/johng/gf/g/database/gdb" + "gitee.com/johng/gf/g" "time" ) func main() { - gdb.AddDefaultConfigNode(gdb.ConfigNode { - Host : "127.0.0.1", - Port : "3306", - User : "root", - Pass : "12345678", - Name : "test", - Type : "mysql", - Role : "master", - Charset : "utf8", - MaxIdleConnCount : 10, - MaxOpenConnCount : 10, - MaxConnLifetime : 10, - }) - db, err := gdb.New() - if err != nil { - panic(err) - } + db := g.DB() + db.SetMaxIdleConns(10) + db.SetMaxOpenConns(10) + db.SetConnMaxLifetime(10) + // 开启调试模式,以便于记录所有执行的SQL db.SetDebug(true) diff --git a/geg/database/orm/mysql/gdb_value.go b/geg/database/orm/mysql/gdb_value.go index 7851f6dd6..31747b2ec 100644 --- a/geg/database/orm/mysql/gdb_value.go +++ b/geg/database/orm/mysql/gdb_value.go @@ -2,28 +2,15 @@ package main import ( "fmt" - "gitee.com/johng/gf/g/database/gdb" + "gitee.com/johng/gf/g" ) func main() { - gdb.AddDefaultConfigNode(gdb.ConfigNode { - Host : "192.168.1.11", - Port : "3306", - User : "root", - Pass : "8692651", - Name : "test", - Type : "mysql", - Role : "master", - Charset : "utf8", - }) - db, err := gdb.New() - if err != nil { - panic(err) - } + db := g.DB() // 开启调试模式,以便于记录所有执行的SQL db.SetDebug(true) - r, _ := db.Table("user").All() + r, _ := db.Table("test").Where("id IN (?,?)", 1,2).All() if r != nil { fmt.Println(r.ToList()) } diff --git a/geg/other/test/test_test.go b/geg/other/test/test_test.go deleted file mode 100644 index 30da58b5a..000000000 --- a/geg/other/test/test_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package test - -import ( - "strings" - "testing" -) - -var s = "/name/john///." -var c = "./" - -func t1(s string) string { - if len(s) == 0 { - return s - } - for _, cut := range c { - for s[len(s) - 1] == uint8(cut) { - s = s[:len(s) - 1] - if len(s) == 0 { - return s - } - } - } - return s -} - -func t2(s string) string { - return strings.TrimRight(s, c) -} - -func Benchmark_t1(b *testing.B) { - for i := 0; i < b.N; i++ { - t1(s) - } -} - -func Benchmark_t2(b *testing.B) { - for i := 0; i < b.N; i++ { - t2(s) - } -} - diff --git a/geg/other/test2.go b/geg/other/test2.go index 153822499..f6a556927 100644 --- a/geg/other/test2.go +++ b/geg/other/test2.go @@ -1,35 +1,15 @@ package main import ( - "gitee.com/johng/gf/g" - "gitee.com/johng/gf/g/util/gvalid" + "fmt" + "gitee.com/johng/gf/g/util/gregex" ) -type User struct { - Uid int `gvalid:"uid @integer|min:1"` - Name string `gvalid:"name @required|length:6,30#请输入用户名称|用户名称长度非法"` - Pass1 string `gvalid:"password1@required|password3"` - Pass2 string `gvalid:"password2@required|password3|same:password1#||两次密码不一致,请重新输入"` -} + func main() { - user := &User{ - Name : "john", - Pass1: "Abc123!@#", - Pass2: "123", - } - - // 使用结构体定义的校验规则和错误提示进行校验 - g.Dump(gvalid.CheckStruct(user, nil).Maps()) - - // 自定义校验规则和错误提示,对定义的特定校验规则和错误提示进行覆盖 - rules := map[string]string { - "Uid" : "required", - } - msgs := map[string]interface{} { - "Pass2" : map[string]string { - "password3" : "名称不能为空", - }, - } - g.Dump(gvalid.CheckStruct(user, rules, msgs).Maps()) + query := "select * from user" + q, err := gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query) + fmt.Println(err) + fmt.Println(q) } \ No newline at end of file diff --git a/third/github.com/go-sql-driver/mysql/.travis.yml b/third/github.com/go-sql-driver/mysql/.travis.yml index 47dd289a0..75505f144 100644 --- a/third/github.com/go-sql-driver/mysql/.travis.yml +++ b/third/github.com/go-sql-driver/mysql/.travis.yml @@ -4,6 +4,7 @@ go: - 1.8.x - 1.9.x - 1.10.x + - 1.11.x - master before_install: diff --git a/third/github.com/go-sql-driver/mysql/AUTHORS b/third/github.com/go-sql-driver/mysql/AUTHORS index fbe4ec442..5ce4f7eca 100644 --- a/third/github.com/go-sql-driver/mysql/AUTHORS +++ b/third/github.com/go-sql-driver/mysql/AUTHORS @@ -35,6 +35,7 @@ Hanno Braun Henri Yandell Hirotaka Yamamoto ICHINOSE Shogo +Ilia Cimpoes INADA Naoki Jacek Szwec James Harr @@ -72,7 +73,9 @@ Shuode Li Soroush Pour Stan Putrya Stanley Gunawan +Steven Hartland Thomas Wodarek +Tom Jenkinson Xiangyu Hu Xiaobing Jiang Xiuming Chen @@ -88,3 +91,4 @@ Keybase Inc. Percona LLC Pivotal Inc. Stripe Inc. +Multiplay Ltd. diff --git a/third/github.com/go-sql-driver/mysql/README.md b/third/github.com/go-sql-driver/mysql/README.md index bda11031e..6e4816e63 100644 --- a/third/github.com/go-sql-driver/mysql/README.md +++ b/third/github.com/go-sql-driver/mysql/README.md @@ -58,7 +58,7 @@ _Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`: ```go import "database/sql" -import _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql" +import _ "github.com/go-sql-driver/mysql" db, err := sql.Open("mysql", "user:password@/dbname") ``` @@ -328,11 +328,11 @@ Timeout for establishing connections, aka dial timeout. The value must be a deci ``` Type: bool / string -Valid Values: true, false, skip-verify, +Valid Values: true, false, skip-verify, preferred, Default: false ``` -`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). +`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use `preferred` to use TLS only when advertised by the server, this is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). ##### `writeTimeout` @@ -431,7 +431,7 @@ See [context support in the database/sql package](https://golang.org/doc/go1.8#d ### `LOAD DATA LOCAL INFILE` support For this feature you need direct access to the package. Therefore you must change the import path (no `_`): ```go -import "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql" +import "github.com/go-sql-driver/mysql" ``` Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)). diff --git a/third/github.com/go-sql-driver/mysql/auth.go b/third/github.com/go-sql-driver/mysql/auth.go index 2f61ecd4f..fec7040d4 100644 --- a/third/github.com/go-sql-driver/mysql/auth.go +++ b/third/github.com/go-sql-driver/mysql/auth.go @@ -234,64 +234,64 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro if err != nil { return err } - return mc.writeAuthSwitchPacket(enc, false) + return mc.writeAuthSwitchPacket(enc) } -func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) { +func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { switch plugin { case "caching_sha2_password": authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) - return authResp, false, nil + return authResp, nil case "mysql_old_password": if !mc.cfg.AllowOldPasswords { - return nil, false, ErrOldPassword + return nil, ErrOldPassword } // Note: there are edge cases where this should work but doesn't; // this is currently "wontfix": // https://github.com/go-sql-driver/mysql/issues/184 - authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd) - return authResp, true, nil + authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0) + return authResp, nil case "mysql_clear_password": if !mc.cfg.AllowCleartextPasswords { - return nil, false, ErrCleartextPassword + return nil, ErrCleartextPassword } // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - return []byte(mc.cfg.Passwd), true, nil + return append([]byte(mc.cfg.Passwd), 0), nil case "mysql_native_password": if !mc.cfg.AllowNativePasswords { - return nil, false, ErrNativePassword + return nil, ErrNativePassword } // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html // Native password authentication only need and will need 20-byte challenge. authResp := scramblePassword(authData[:20], mc.cfg.Passwd) - return authResp, false, nil + return authResp, nil case "sha256_password": if len(mc.cfg.Passwd) == 0 { - return nil, true, nil + return []byte{0}, nil } if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - return []byte(mc.cfg.Passwd), true, nil + return append([]byte(mc.cfg.Passwd), 0), nil } pubKey := mc.cfg.pubKey if pubKey == nil { // request public key from server - return []byte{1}, false, nil + return []byte{1}, nil } // encrypted password enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) - return enc, false, err + return enc, err default: errLog.Print("unknown auth plugin:", plugin) - return nil, false, ErrUnknownPlugin + return nil, ErrUnknownPlugin } } @@ -315,11 +315,11 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { plugin = newPlugin - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { return err } - if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil { + if err = mc.writeAuthSwitchPacket(authResp); err != nil { return err } @@ -352,7 +352,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { case cachingSha2PasswordPerformFullAuthentication: if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true) + err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) if err != nil { return err } @@ -360,13 +360,15 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { pubKey := mc.cfg.pubKey if pubKey == nil { // request public key from server - data := mc.buf.takeSmallBuffer(4 + 1) + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { + return err + } data[4] = cachingSha2PasswordRequestPublicKey mc.writePacket(data) // parse public key - data, err := mc.readPacket() - if err != nil { + if data, err = mc.readPacket(); err != nil { return err } diff --git a/third/github.com/go-sql-driver/mysql/auth_test.go b/third/github.com/go-sql-driver/mysql/auth_test.go index bd0e2189c..1920ef39f 100644 --- a/third/github.com/go-sql-driver/mysql/auth_test.go +++ b/third/github.com/go-sql-driver/mysql/auth_test.go @@ -85,11 +85,11 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -130,11 +130,11 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -172,11 +172,11 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -228,11 +228,11 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -280,11 +280,11 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -336,7 +336,7 @@ func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - _, _, err := mc.auth(authData, plugin) + _, err := mc.auth(authData, plugin) if err != ErrCleartextPassword { t.Errorf("expected ErrCleartextPassword, got %v", err) } @@ -353,11 +353,11 @@ func TestAuthFastCleartextPassword(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -367,8 +367,8 @@ func TestAuthFastCleartextPassword(t *testing.T) { authRespEnd := authRespStart + 1 + len(authResp) writtenAuthRespLen := conn.written[authRespStart] writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte{115, 101, 99, 114, 101, 116} - if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) } conn.written = nil @@ -396,11 +396,11 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -410,9 +410,9 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { authRespEnd := authRespStart + 1 + len(authResp) writtenAuthRespLen := conn.written[authRespStart] writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - if writtenAuthRespLen != 0 { - t.Fatalf("unexpected written auth response (%d bytes): %v", - writtenAuthRespLen, writtenAuthResp) + expectedAuthResp := []byte{0} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) } conn.written = nil @@ -439,7 +439,7 @@ func TestAuthFastNativePasswordNotAllowed(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - _, _, err := mc.auth(authData, plugin) + _, err := mc.auth(authData, plugin) if err != ErrNativePassword { t.Errorf("expected ErrNativePassword, got %v", err) } @@ -455,11 +455,11 @@ func TestAuthFastNativePassword(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -498,11 +498,11 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -540,11 +540,11 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -554,7 +554,8 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { authRespEnd := authRespStart + 1 + len(authResp) writtenAuthRespLen := conn.written[authRespStart] writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - if writtenAuthRespLen != 0 { + expectedAuthResp := []byte{0} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) } conn.written = nil @@ -587,11 +588,11 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -636,11 +637,11 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -669,7 +670,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { plugin := "sha256_password" // send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } @@ -677,18 +678,18 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { // unset TLS config to prevent the actual establishment of a TLS wrapper mc.cfg.tls = nil - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } // check written auth response authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) + 1 + authRespEnd := authRespStart + 1 + len(authResp) writtenAuthRespLen := conn.written[authRespStart] writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} - if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) } conn.written = nil @@ -1064,6 +1065,22 @@ func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) { } } +// Same to TestAuthSwitchOldPasswordNotAllowed, but use OldAuthSwitch request. +func TestOldAuthSwitchNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + // OldAuthSwitch request + conn.data = []byte{1, 0, 0, 2, 0xfe} + conn.maxReads = 1 + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrOldPassword { + t.Errorf("expected ErrOldPassword, got %v", err) + } +} + func TestAuthSwitchOldPassword(t *testing.T) { conn, mc := newRWMockConn(2) mc.cfg.AllowOldPasswords = true @@ -1092,6 +1109,32 @@ func TestAuthSwitchOldPassword(t *testing.T) { } } +// Same to TestAuthSwitchOldPassword, but use OldAuthSwitch request. +func TestOldAuthSwitch(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "secret" + + // OldAuthSwitch request + conn.data = []byte{1, 0, 0, 2, 0xfe} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} func TestAuthSwitchOldPasswordEmpty(t *testing.T) { conn, mc := newRWMockConn(2) mc.cfg.AllowOldPasswords = true @@ -1120,6 +1163,33 @@ func TestAuthSwitchOldPasswordEmpty(t *testing.T) { } } +// Same to TestAuthSwitchOldPasswordEmpty, but use OldAuthSwitch request. +func TestOldAuthSwitchPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "" + + // OldAuthSwitch request. + conn.data = []byte{1, 0, 0, 2, 0xfe} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{1, 0, 0, 3, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) { conn, mc := newRWMockConn(2) mc.cfg.Passwd = "" diff --git a/third/github.com/go-sql-driver/mysql/buffer.go b/third/github.com/go-sql-driver/mysql/buffer.go index eb4748bf4..19486bd6f 100644 --- a/third/github.com/go-sql-driver/mysql/buffer.go +++ b/third/github.com/go-sql-driver/mysql/buffer.go @@ -22,17 +22,17 @@ const defaultBufSize = 4096 // The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. type buffer struct { - buf []byte + buf []byte // buf is a byte buffer who's length and capacity are equal. nc net.Conn idx int length int timeout time.Duration } +// newBuffer allocates and returns a new buffer. func newBuffer(nc net.Conn) buffer { - var b [defaultBufSize]byte return buffer{ - buf: b[:], + buf: make([]byte, defaultBufSize), nc: nc, } } @@ -105,43 +105,56 @@ func (b *buffer) readNext(need int) ([]byte, error) { return b.buf[offset:b.idx], nil } -// returns a buffer with the requested size. +// takeBuffer returns a buffer with the requested size. // If possible, a slice from the existing buffer is returned. // Otherwise a bigger buffer is made. // Only one buffer (total) can be used at a time. -func (b *buffer) takeBuffer(length int) []byte { +func (b *buffer) takeBuffer(length int) ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrBusyBuffer } // test (cheap) general case first - if length <= defaultBufSize || length <= cap(b.buf) { - return b.buf[:length] + if length <= cap(b.buf) { + return b.buf[:length], nil } if length < maxPacketSize { b.buf = make([]byte, length) - return b.buf + return b.buf, nil } - return make([]byte, length) + + // buffer is larger than we want to store. + return make([]byte, length), nil } -// shortcut which can be used if the requested buffer is guaranteed to be -// smaller than defaultBufSize +// takeSmallBuffer is shortcut which can be used if length is +// known to be smaller than defaultBufSize. // Only one buffer (total) can be used at a time. -func (b *buffer) takeSmallBuffer(length int) []byte { +func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrBusyBuffer } - return b.buf[:length] + return b.buf[:length], nil } // takeCompleteBuffer returns the complete existing buffer. // This can be used if the necessary buffer size is unknown. +// cap and len of the returned buffer will be equal. // Only one buffer (total) can be used at a time. -func (b *buffer) takeCompleteBuffer() []byte { +func (b *buffer) takeCompleteBuffer() ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrBusyBuffer } - return b.buf + return b.buf, nil +} + +// store stores buf, an updated buffer, if its suitable to do so. +func (b *buffer) store(buf []byte) error { + if b.length > 0 { + return ErrBusyBuffer + } else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) { + b.buf = buf[:cap(buf)] + } + return nil } diff --git a/third/github.com/go-sql-driver/mysql/connection.go b/third/github.com/go-sql-driver/mysql/connection.go index 911be2060..fc4ec7597 100644 --- a/third/github.com/go-sql-driver/mysql/connection.go +++ b/third/github.com/go-sql-driver/mysql/connection.go @@ -19,16 +19,6 @@ import ( "time" ) -// a copy of context.Context for Go 1.7 and earlier -type mysqlContext interface { - Done() <-chan struct{} - Err() error - - // defined in context.Context, but not used in this driver: - // Deadline() (deadline time.Time, ok bool) - // Value(key interface{}) interface{} -} - type mysqlConn struct { buf buffer netConn net.Conn @@ -45,7 +35,7 @@ type mysqlConn struct { // for context support (Go 1.8+) watching bool - watcher chan<- mysqlContext + watcher chan<- context.Context closech chan struct{} finished chan<- struct{} canceled atomicError // set non-nil if conn is canceled @@ -192,10 +182,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin return "", driver.ErrSkip } - buf := mc.buf.takeCompleteBuffer() - if buf == nil { + buf, err := mc.buf.takeCompleteBuffer() + if err != nil { // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return "", ErrInvalidConn } buf = buf[:0] @@ -475,7 +465,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) { defer mc.finish() if err = mc.writeCommandPacket(comPing); err != nil { - return + return mc.markBadConn(err) } return mc.readResultOK() @@ -595,33 +585,32 @@ func (mc *mysqlConn) watchCancel(ctx context.Context) error { mc.cleanup() return nil } + // When ctx is already cancelled, don't watch it. + if err := ctx.Err(); err != nil { + return err + } + // When ctx is not cancellable, don't watch it. if ctx.Done() == nil { return nil } - - mc.watching = true - select { - default: - case <-ctx.Done(): - return ctx.Err() - } + // When watcher is not alive, can't watch it. if mc.watcher == nil { return nil } + mc.watching = true mc.watcher <- ctx - return nil } func (mc *mysqlConn) startWatcher() { - watcher := make(chan mysqlContext, 1) + watcher := make(chan context.Context, 1) mc.watcher = watcher finished := make(chan struct{}) mc.finished = finished go func() { for { - var ctx mysqlContext + var ctx context.Context select { case ctx = <-watcher: case <-mc.closech: diff --git a/third/github.com/go-sql-driver/mysql/connection_test.go b/third/github.com/go-sql-driver/mysql/connection_test.go index dec376117..2a1c8e888 100644 --- a/third/github.com/go-sql-driver/mysql/connection_test.go +++ b/third/github.com/go-sql-driver/mysql/connection_test.go @@ -9,7 +9,10 @@ package mysql import ( + "context" "database/sql/driver" + "errors" + "net" "testing" ) @@ -79,3 +82,76 @@ func TestCheckNamedValue(t *testing.T) { t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value) } } + +// TestCleanCancel tests passed context is cancelled at start. +// No packet should be sent. Connection should keep current status. +func TestCleanCancel(t *testing.T) { + mc := &mysqlConn{ + closech: make(chan struct{}), + } + mc.startWatcher() + defer mc.cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + for i := 0; i < 3; i++ { // Repeat same behavior + err := mc.Ping(ctx) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %#v", err) + } + + if mc.closed.IsSet() { + t.Error("expected mc is not closed, closed actually") + } + + if mc.watching { + t.Error("expected watching is false, but true") + } + } +} + +func TestPingMarkBadConnection(t *testing.T) { + nc := badConnection{err: errors.New("boom")} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + } + + err := ms.Ping(context.Background()) + + if err != driver.ErrBadConn { + t.Errorf("expected driver.ErrBadConn, got %#v", err) + } +} + +func TestPingErrInvalidConn(t *testing.T) { + nc := badConnection{err: errors.New("failed to write"), n: 10} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + closech: make(chan struct{}), + } + + err := ms.Ping(context.Background()) + + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %#v", err) + } +} + +type badConnection struct { + n int + err error + net.Conn +} + +func (bc badConnection) Write(b []byte) (n int, err error) { + return bc.n, bc.err +} + +func (bc badConnection) Close() error { + return nil +} diff --git a/third/github.com/go-sql-driver/mysql/driver.go b/third/github.com/go-sql-driver/mysql/driver.go index 53782ed24..9f4967087 100644 --- a/third/github.com/go-sql-driver/mysql/driver.go +++ b/third/github.com/go-sql-driver/mysql/driver.go @@ -9,7 +9,7 @@ // The driver should be used via the database/sql package: // // import "database/sql" -// import _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql" +// import _ "github.com/go-sql-driver/mysql" // // db, err := sql.Open("mysql", "user:password@/dbname") // @@ -50,7 +50,7 @@ func RegisterDial(net string, dial DialFunc) { // Open new Connection. // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how -// the DSN string is formated +// the DSN string is formatted func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { var err error @@ -77,6 +77,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) } if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + errLog.Print("net.Error from Dial()': ", nerr.Error()) + return nil, driver.ErrBadConn + } return nil, err } @@ -110,18 +114,18 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { } // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { // try the default auth plugin, if using the requested plugin failed errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) plugin = defaultAuthPlugin - authResp, addNUL, err = mc.auth(authData, plugin) + authResp, err = mc.auth(authData, plugin) if err != nil { mc.cleanup() return nil, err } } - if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { + if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { mc.cleanup() return nil, err } diff --git a/third/github.com/go-sql-driver/mysql/driver_test.go b/third/github.com/go-sql-driver/mysql/driver_test.go index f2bf344e5..46d1f7ff4 100644 --- a/third/github.com/go-sql-driver/mysql/driver_test.go +++ b/third/github.com/go-sql-driver/mysql/driver_test.go @@ -85,6 +85,23 @@ type DBTest struct { db *sql.DB } +type netErrorMock struct { + temporary bool + timeout bool +} + +func (e netErrorMock) Temporary() bool { + return e.temporary +} + +func (e netErrorMock) Timeout() bool { + return e.timeout +} + +func (e netErrorMock) Error() string { + return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout) +} + func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { if !available { t.Skipf("MySQL server not running on %s", netAddr) @@ -1287,7 +1304,7 @@ func TestFoundRows(t *testing.T) { } func TestTLS(t *testing.T) { - tlsTest := func(dbt *DBTest) { + tlsTestReq := func(dbt *DBTest) { if err := dbt.db.Ping(); err != nil { if err == ErrNoTLS { dbt.Skip("server does not support TLS") @@ -1304,19 +1321,27 @@ func TestTLS(t *testing.T) { dbt.Fatal(err.Error()) } - if value == nil { - dbt.Fatal("no Cipher") + if (*value == nil) || (len(*value) == 0) { + dbt.Fatalf("no Cipher") + } else { + dbt.Logf("Cipher: %s", *value) } } } + tlsTestOpt := func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.Fatalf("error on Ping: %s", err.Error()) + } + } - runTests(t, dsn+"&tls=skip-verify", tlsTest) + runTests(t, dsn+"&tls=preferred", tlsTestOpt) + runTests(t, dsn+"&tls=skip-verify", tlsTestReq) // Verify that registering / using a custom cfg works RegisterTLSConfig("custom-skip-verify", &tls.Config{ InsecureSkipVerify: true, }) - runTests(t, dsn+"&tls=custom-skip-verify", tlsTest) + runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq) } func TestReuseClosedConnection(t *testing.T) { @@ -1801,6 +1826,38 @@ func TestConcurrent(t *testing.T) { }) } +func testDialError(t *testing.T, dialErr error, expectErr error) { + RegisterDial("mydial", func(addr string) (net.Conn, error) { + return nil, dialErr + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + _, err = db.Exec("DO 1") + if err != expectErr { + t.Fatalf("was expecting %s. Got: %s", dialErr, err) + } +} + +func TestDialUnknownError(t *testing.T) { + testErr := fmt.Errorf("test") + testDialError(t, testErr, testErr) +} + +func TestDialNonRetryableNetErr(t *testing.T) { + testErr := netErrorMock{} + testDialError(t, testErr, testErr) +} + +func TestDialTemporaryNetErr(t *testing.T) { + testErr := netErrorMock{temporary: true} + testDialError(t, testErr, driver.ErrBadConn) +} + // Tests custom dial functions func TestCustomDial(t *testing.T) { if !available { diff --git a/third/github.com/go-sql-driver/mysql/dsn.go b/third/github.com/go-sql-driver/mysql/dsn.go index be014babe..b9134722e 100644 --- a/third/github.com/go-sql-driver/mysql/dsn.go +++ b/third/github.com/go-sql-driver/mysql/dsn.go @@ -560,7 +560,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { } else { cfg.TLSConfig = "false" } - } else if vl := strings.ToLower(value); vl == "skip-verify" { + } else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" { cfg.TLSConfig = vl cfg.tls = &tls.Config{InsecureSkipVerify: true} } else { diff --git a/third/github.com/go-sql-driver/mysql/packets.go b/third/github.com/go-sql-driver/mysql/packets.go index 170aaa02b..5e0853767 100644 --- a/third/github.com/go-sql-driver/mysql/packets.go +++ b/third/github.com/go-sql-driver/mysql/packets.go @@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { mc.sequence++ // packets with length 0 terminate a previous packet which is a - // multiple of (2^24)−1 bytes long + // multiple of (2^24)-1 bytes long if pktLen == 0 { // there was no previous packet if prevData == nil { @@ -194,7 +194,11 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro return nil, "", ErrOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - return nil, "", ErrNoTLS + if mc.cfg.TLSConfig == "preferred" { + mc.cfg.tls = nil + } else { + return nil, "", ErrNoTLS + } } pos += 2 @@ -243,7 +247,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -269,7 +273,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, // encode length of the auth plugin data var authRespLEIBuf [9]byte - authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp))) + authRespLen := len(authResp) + authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) if len(authRespLEI) > 1 { // if the length can not be written in 1 byte, it must be written as a // length encoded integer @@ -277,9 +282,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, } pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 - if addNUL { - pktLen++ - } // To specify a db name if n := len(mc.cfg.DBName); n > 0 { @@ -288,10 +290,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, } // Calculate packet length and get buffer with that size - data := mc.buf.takeSmallBuffer(pktLen + 4) - if data == nil { + data, err := mc.buf.takeSmallBuffer(pktLen + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -350,10 +352,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, // Auth Data [length encoded integer] pos += copy(data[pos:], authRespLEI) pos += copy(data[pos:], authResp) - if addNUL { - data[pos] = 0x00 - pos++ - } // Databasename [null terminated string] if len(mc.cfg.DBName) > 0 { @@ -364,30 +362,24 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, pos += copy(data[pos:], plugin) data[pos] = 0x00 + pos++ // Send Auth packet - return mc.writePacket(data) + return mc.writePacket(data[:pos]) } // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error { +func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { pktLen := 4 + len(authData) - if addNUL { - pktLen++ - } - data := mc.buf.takeSmallBuffer(pktLen) - if data == nil { + data, err := mc.buf.takeSmallBuffer(pktLen) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } // Add the auth data [EOF] copy(data[4:], authData) - if addNUL { - data[pktLen-1] = 0x00 - } - return mc.writePacket(data) } @@ -399,10 +391,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1) - if data == nil { + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -418,10 +410,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { mc.sequence = 0 pktLen := 1 + len(arg) - data := mc.buf.takeBuffer(pktLen + 4) - if data == nil { + data, err := mc.buf.takeBuffer(pktLen + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -439,10 +431,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1 + 4) - if data == nil { + data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -479,7 +471,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { return data[1:], "", err case iEOF: - if len(data) < 1 { + if len(data) == 1 { // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest return nil, "mysql_old_password", nil } @@ -895,7 +887,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { const minPktLen = 4 + 1 + 4 + 1 + 4 mc := stmt.mc - // Determine threshould dynamically to avoid packet size shortage. + // Determine threshold dynamically to avoid packet size shortage. longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) if longDataSize < 64 { longDataSize = 64 @@ -905,15 +897,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { mc.sequence = 0 var data []byte + var err error if len(args) == 0 { - data = mc.buf.takeBuffer(minPktLen) + data, err = mc.buf.takeBuffer(minPktLen) } else { - data = mc.buf.takeCompleteBuffer() + data, err = mc.buf.takeCompleteBuffer() + // In this case the len(data) == cap(data) which is used to optimise the flow below. } - if data == nil { + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -939,7 +933,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { pos := minPktLen var nullMask []byte - if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { + if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) { // buffer has to be extended but we don't know by how much so // we depend on append after all data with known sizes fit. // We stop at that because we deal with a lot of columns here @@ -948,10 +942,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { copy(tmp[:pos], data[:pos]) data = tmp nullMask = data[pos : pos+maskLen] + // No need to clean nullMask as make ensures that. pos += maskLen } else { nullMask = data[pos : pos+maskLen] - for i := 0; i < maskLen; i++ { + for i := range nullMask { nullMask[i] = 0 } pos += maskLen @@ -1088,7 +1083,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) - mc.buf.buf = data + if err = mc.buf.store(data); err != nil { + errLog.Print(err) + return errBadConnNoWrite + } } pos += len(paramValues)