From e67aa63a508469869fe5e0dabcc02b118fc3ea48 Mon Sep 17 00:00:00 2001 From: John Date: Sat, 15 Dec 2018 15:50:39 +0800 Subject: [PATCH] refract gdb package, add complete unit test cases, almost there --- g/database/gdb/gdb.go | 65 +-- g/database/gdb/gdb_base.go | 299 ++++++++------ g/database/gdb/gdb_config.go | 24 +- g/database/gdb/gdb_func.go | 110 ++---- g/database/gdb/gdb_model.go | 35 +- g/database/gdb/gdb_mssql.go | 11 +- g/database/gdb/gdb_mysql.go | 8 +- g/database/gdb/gdb_oracle.go | 8 +- g/database/gdb/gdb_pgsql.go | 8 +- g/database/gdb/gdb_sqlite.go | 8 +- g/database/gdb/gdb_transaction.go | 206 ++-------- g/database/gdb/gdb_unit_1_test.go | 215 ++++++++++ g/database/gdb/gdb_unit_2_test.go | 177 +++++++++ g/database/gdb/gdb_unit_3_test.go | 372 ++++++++++++++++++ g/frame/gins/gins.go | 4 +- g/g_object.go | 4 +- g/os/glog/glog_logger.go | 43 +- g/util/gconv/gconv_struct.go | 4 - g/util/gtest/gtest.go | 30 ++ geg/database/orm/mysql/gdb.go | 4 +- geg/database/orm/mysql/gdb_pool.go | 24 +- geg/database/orm/mysql/gdb_value.go | 17 +- geg/other/test/test_test.go | 41 -- geg/other/test2.go | 26 +- .../go-sql-driver/mysql/.travis.yml | 1 + third/github.com/go-sql-driver/mysql/AUTHORS | 4 + .../github.com/go-sql-driver/mysql/README.md | 8 +- third/github.com/go-sql-driver/mysql/auth.go | 44 ++- .../go-sql-driver/mysql/auth_test.go | 142 +++++-- .../github.com/go-sql-driver/mysql/buffer.go | 49 ++- .../go-sql-driver/mysql/connection.go | 39 +- .../go-sql-driver/mysql/connection_test.go | 76 ++++ .../github.com/go-sql-driver/mysql/driver.go | 14 +- .../go-sql-driver/mysql/driver_test.go | 67 +++- third/github.com/go-sql-driver/mysql/dsn.go | 2 +- .../github.com/go-sql-driver/mysql/packets.go | 86 ++-- 36 files changed, 1530 insertions(+), 745 deletions(-) create mode 100644 g/database/gdb/gdb_unit_1_test.go create mode 100644 g/database/gdb/gdb_unit_2_test.go create mode 100644 g/database/gdb/gdb_unit_3_test.go create mode 100644 g/util/gtest/gtest.go delete mode 100644 geg/other/test/test_test.go diff --git a/g/database/gdb/gdb.go b/g/database/gdb/gdb.go index 5dd863618..cae550e8d 100644 --- a/g/database/gdb/gdb.go +++ b/g/database/gdb/gdb.go @@ -30,27 +30,39 @@ const ( // 数据库操作接口 type DB interface { + // 建立数据库连接方法(开发者一般不需要直接调用) + Open(config *ConfigNode) (*sql.DB, error) + // SQL操作方法 Query(query string, args ...interface{}) (*sql.Rows, error) Exec(sql string, args ...interface{}) (sql.Result, error) Prepare(sql string, execOnMaster...bool) (*sql.Stmt, error) + doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) + doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) + doPrepare(link dbLink, query string) (*sql.Stmt, error) + doInsert(link dbLink, table string, data Map, option int) (result sql.Result, err error) + doBatchInsert(link dbLink, table string, list List, batch int, option int) (result sql.Result, err error) + doUpdate(link dbLink, table string, data interface{}, condition interface{}, args ...interface{}) (result sql.Result, err error) + doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error) + // 数据库查询 GetAll(query string, args ...interface{}) (Result, error) GetOne(query string, args ...interface{}) (Record, error) GetValue(query string, args ...interface{}) (Value, error) + GetCount(query string, args ...interface{}) (int, error) + GetStruct(obj interface{}, query string, args ...interface{}) error - // Ping + // 创建底层数据库master/slave链接对象 + Master() (*sql.DB, error) + Slave() (*sql.DB, error) + + // Ping PingMaster() error PingSlave() error - // 连接属性设置 - SetMaxIdleConns(n int) - SetMaxOpenConns(n int) - SetConnMaxLifetime(n int) - // 开启事务操作 - Begin() (*Tx, error) + Begin() (*TX, error) // 数据表插入/更新/保存操作 Insert(table string, data Map) (sql.Result, error) @@ -72,17 +84,26 @@ type DB interface { // 设置管理 SetDebug(debug bool) + GetQueriedSqls() []*Sql + PrintQueriedSqls() + SetMaxIdleConns(n int) + SetMaxOpenConns(n int) + SetConnMaxLifetime(n int) // 内部方法接口 - open(c *ConfigNode) (*sql.DB, error) getCache() (*gcache.Cache) getChars() (charLeft string, charRight string) getDebug() bool - putSql(s *Sql) - formatCondition(condition interface{}) (where string) handleSqlBeforeExec(sql string) string } +// 执行底层数据库操作的核心接口 +type dbLink interface { + Query(query string, args ...interface{}) (*sql.Rows, error) + Exec(sql string, args ...interface{}) (sql.Result, error) + Prepare(sql string) (*sql.Stmt, error) +} + // 数据库链接对象 type dbBase struct { db DB // 数据库对象 @@ -228,36 +249,36 @@ func getConfigNodeByPriority(cg ConfigGroup) *ConfigNode { } // 获得底层数据库链接对象 -func (db *dbBase) getSqlDb(master bool) (sqlDb *sql.DB, err error) { +func (bs *dbBase) getSqlDb(master bool) (sqlDb *sql.DB, err error) { // 负载均衡 - node, err := getConfigNodeByGroup(db.group, master) + node, err := getConfigNodeByGroup(bs.group, master) if err != nil { return nil, err } // 检查缓存连接池对象 cacheKey := node.String() - if v := db.cache.Get(cacheKey); v != nil { + if v := bs.cache.Get(cacheKey); v != nil { return v.(*sql.DB), nil } - v := db.cache.GetOrSetFuncLock(node.String(), func() interface{} { - sqlDb, err = db.db.open(node) + v := bs.cache.GetOrSetFuncLock(node.String(), func() interface{} { + sqlDb, err = bs.db.Open(node) if err != nil { return nil } - if n := db.maxIdleConnCount.Val(); n > 0 { + if n := bs.maxIdleConnCount.Val(); n > 0 { sqlDb.SetMaxIdleConns(n) } else if node.MaxIdleConnCount > 0 { sqlDb.SetMaxIdleConns(node.MaxIdleConnCount) } - if n := db.maxOpenConnCount.Val(); n > 0 { + if n := bs.maxOpenConnCount.Val(); n > 0 { sqlDb.SetMaxOpenConns(n) } else if node.MaxOpenConnCount > 0 { sqlDb.SetMaxOpenConns(node.MaxOpenConnCount) } - if n := db.maxConnLifetime.Val(); n > 0 { + if n := bs.maxConnLifetime.Val(); n > 0 { sqlDb.SetConnMaxLifetime(time.Duration(n) * time.Second) } else if node.MaxConnLifetime > 0 { sqlDb.SetConnMaxLifetime(time.Duration(node.MaxConnLifetime) * time.Second) @@ -271,11 +292,11 @@ func (db *dbBase) getSqlDb(master bool) (sqlDb *sql.DB, err error) { } // 创建底层数据库master链接对象 -func (db *dbBase) Master() (*sql.DB, error) { - return db.getSqlDb(true) +func (bs *dbBase) Master() (*sql.DB, error) { + return bs.getSqlDb(true) } // 创建底层数据库slave链接对象 -func (db *dbBase) Slave() (*sql.DB, error) { - return db.getSqlDb(false) +func (bs *dbBase) Slave() (*sql.DB, error) { + return bs.getSqlDb(false) } diff --git a/g/database/gdb/gdb_base.go b/g/database/gdb/gdb_base.go index e11618f33..07ab10b63 100644 --- a/g/database/gdb/gdb_base.go +++ b/g/database/gdb/gdb_base.go @@ -14,7 +14,7 @@ import ( "gitee.com/johng/gf/g/os/gcache" "gitee.com/johng/gf/g/os/gtime" "gitee.com/johng/gf/g/util/gconv" - "gitee.com/johng/gf/g/util/gstr" + "gitee.com/johng/gf/g/util/gregex" "reflect" "strings" ) @@ -24,13 +24,13 @@ const ( ) // 获取已经执行的SQL列表(仅在debug=true时有效) -func (db *dbBase) GetQueriedSqls() []*Sql { - if db.sqls == nil { +func (bs *dbBase) GetQueriedSqls() []*Sql { + if bs.sqls == nil { return nil } sqls := make([]*Sql, 0) - db.sqls.Prev() - db.sqls.RLockIteratorPrev(func(value interface{}) bool { + bs.sqls.Prev() + bs.sqls.RLockIteratorPrev(func(value interface{}) bool { if value == nil { return false } @@ -41,8 +41,8 @@ func (db *dbBase) GetQueriedSqls() []*Sql { } // 打印已经执行的SQL列表(仅在debug=true时有效) -func (db *dbBase) PrintQueriedSqls() { - sqls := db.GetQueriedSqls() +func (bs *dbBase) PrintQueriedSqls() { + sqls := bs.GetQueriedSqls() for k, v := range sqls { fmt.Println(len(sqls) - k, ":") fmt.Println(" Sql :", v.Sql) @@ -55,51 +55,106 @@ func (db *dbBase) PrintQueriedSqls() { } // 数据库sql查询操作,主要执行查询 -func (db *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err error) { - link, err := db.Slave() +func (bs *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err error) { + link, err := bs.db.Slave() if err != nil { return nil,err } - return doQuery(db.db, link, query, args...) + return bs.db.doQuery(link, query, args...) +} + +// 数据库sql查询操作,主要执行查询 +func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) { + query = bs.db.handleSqlBeforeExec(query) + if bs.db.getDebug() { + mTime1 := gtime.Millisecond() + rows, err = link.Query(query, args...) + mTime2 := gtime.Millisecond() + s := &Sql { + Sql : query, + Args : args, + Error : err, + Start : mTime1, + End : mTime2, + } + bs.sqls.Put(s) + printSql(s) + } else { + rows, err = link.Query(query, args ...) + } + if err == nil { + return rows, nil + } else { + err = formatError(err, query, args...) + } + return nil, err } // 执行一条sql,并返回执行情况,主要用于非查询操作 -func (db *dbBase) Exec(query string, args ...interface{}) (result sql.Result, err error) { - link, err := db.Master() +func (bs *dbBase) Exec(query string, args ...interface{}) (result sql.Result, err error) { + link, err := bs.db.Master() if err != nil { return nil,err } - return doExec(db.db, link, query, args...) + return bs.db.doExec(link, query, args...) +} + +// 执行一条sql,并返回执行情况,主要用于非查询操作 +func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) { + query = bs.db.handleSqlBeforeExec(query) + if bs.db.getDebug() { + mTime1 := gtime.Millisecond() + result, err = link.Exec(query, args ...) + mTime2 := gtime.Millisecond() + s := &Sql{ + Sql : query, + Args : args, + Error : err, + Start : mTime1, + End : mTime2, + } + bs.sqls.Put(s) + printSql(s) + } else { + result, err = link.Exec(query, args ...) + } + return result, formatError(err, query, args...) } // SQL预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作; 默认执行在Slave上, 通过第二个参数指定执行在Master上 -func (db *dbBase) Prepare(query string, execOnMaster...bool) (*sql.Stmt, error) { +func (bs *dbBase) Prepare(query string, execOnMaster...bool) (*sql.Stmt, error) { err := (error)(nil) - sqldb := (*sql.DB)(nil) + link := (dbLink)(nil) if len(execOnMaster) > 0 && execOnMaster[0] { - if sqldb, err = db.Master(); err != nil { + if link, err = bs.db.Master(); err != nil { return nil, err } } else { - if sqldb, err = db.Slave(); err != nil { + if link, err = bs.db.Slave(); err != nil { return nil, err } } - return sqldb.Prepare(query) + return bs.db.doPrepare(link, query) +} + +// SQL预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作 +func (bs *dbBase) doPrepare(link dbLink, query string) (*sql.Stmt, error) { + return link.Prepare(query) } // 数据库查询,获取查询结果集,以列表结构返回 -func (db *dbBase) GetAll(query string, args ...interface{}) (Result, error) { - rows, err := db.Query(query, args ...) +func (bs *dbBase) GetAll(query string, args ...interface{}) (Result, error) { + rows, err := bs.Query(query, args ...) if err != nil || rows == nil { return nil, err } + defer rows.Close() return rowsToResult(rows) } // 数据库查询,获取查询结果记录,以关联数组结构返回 -func (db *dbBase) GetOne(query string, args ...interface{}) (Record, error) { - list, err := db.GetAll(query, args ...) +func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) { + list, err := bs.GetAll(query, args ...) if err != nil { return nil, err } @@ -110,18 +165,17 @@ func (db *dbBase) GetOne(query string, args ...interface{}) (Record, error) { } // 数据库查询,获取查询结果记录,自动映射数据到给定的struct对象中 -func (db *dbBase) GetStruct(obj interface{}, query string, args ...interface{}) error { - one, err := db.GetOne(query, args...) +func (bs *dbBase) GetStruct(obj interface{}, query string, args ...interface{}) error { + one, err := bs.GetOne(query, args...) if err != nil { return err } return one.ToStruct(obj) } - // 数据库查询,获取查询字段值 -func (db *dbBase) GetValue(query string, args ...interface{}) (Value, error) { - one, err := db.GetOne(query, args ...) +func (bs *dbBase) GetValue(query string, args ...interface{}) (Value, error) { + one, err := bs.GetOne(query, args ...) if err != nil { return nil, err } @@ -132,35 +186,20 @@ func (db *dbBase) GetValue(query string, args ...interface{}) (Value, error) { } // 数据库查询,获取查询数量 -func (db *dbBase) GetCount(query string, args ...interface{}) (int, error) { - val, err := db.GetValue(query, args ...) +func (bs *dbBase) GetCount(query string, args ...interface{}) (int, error) { + if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) { + query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query) + } + value, err := bs.GetValue(query, args ...) if err != nil { return 0, err } - return gconv.Int(val), nil -} - -// 数据表查询,其中tables可以是多个联表查询语句,这种查询方式较复杂,建议使用链式操作 -func (db *dbBase) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) { - s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables) - if condition != nil { - s += fmt.Sprintf("WHERE %s ", db.formatCondition(condition)) - } - if len(groupBy) > 0 { - s += fmt.Sprintf("GROUP BY %s ", groupBy) - } - if len(orderBy) > 0 { - s += fmt.Sprintf("ORDER BY %s ", orderBy) - } - if limit > 0 { - s += fmt.Sprintf("LIMIT %d,%d ", first, limit) - } - return db.GetAll(s, args ... ) + return value.Int(), nil } // ping一下,判断或保持数据库链接(master) -func (db *dbBase) PingMaster() error { - if master, err := db.Master(); err != nil { +func (bs *dbBase) PingMaster() error { + if master, err := bs.db.Master(); err != nil { return err } else { return master.Ping() @@ -168,8 +207,8 @@ func (db *dbBase) PingMaster() error { } // ping一下,判断或保持数据库链接(slave) -func (db *dbBase) PingSlave() error { - if slave, err := db.Slave(); err != nil { +func (bs *dbBase) PingSlave() error { + if slave, err := bs.db.Slave(); err != nil { return err } else { return slave.Ping() @@ -178,13 +217,13 @@ func (db *dbBase) PingSlave() error { // 事务操作,开启,会返回一个底层的事务操作对象链接如需要嵌套事务,那么可以使用该对象,否则请忽略 // 只有在tx.Commit/tx.Rollback时,链接会自动Close -func (db *dbBase) Begin() (*TX, error) { - if master, err := db.Master(); err != nil { +func (bs *dbBase) Begin() (*TX, error) { + if master, err := bs.db.Master(); err != nil { return nil, err } else { if tx, err := master.Begin(); err == nil { return &TX { - db : db.db, + db : bs.db, tx : tx, master : master, }, nil @@ -194,16 +233,31 @@ func (db *dbBase) Begin() (*TX, error) { } } +// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 +func (bs *dbBase) Insert(table string, data Map) (sql.Result, error) { + return bs.db.doInsert(nil, table, data, OPTION_INSERT) +} + +// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 +func (bs *dbBase) Replace(table string, data Map) (sql.Result, error) { + return bs.db.doInsert(nil, table, data, OPTION_REPLACE) +} + +// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 +func (bs *dbBase) Save(table string, data Map) (sql.Result, error) { + return bs.db.doInsert(nil, table, data, OPTION_SAVE) +} + // insert、replace, save, ignore操作 // 0: insert: 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 // 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 // 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 // 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做 -func (db *dbBase) insert(table string, data Map, option uint8) (sql.Result, error) { +func (bs *dbBase) doInsert(link dbLink, table string, data Map, option int) (result sql.Result, err error) { var fields []string var values []string var params []interface{} - charl, charr := db.db.getChars() + charl, charr := bs.db.getChars() for k, v := range data { fields = append(fields, charl + k + charr) values = append(values, "?") @@ -223,48 +277,53 @@ func (db *dbBase) insert(table string, data Map, option uint8) (sql.Result, erro } updatestr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ",")) } - return db.Exec( - fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s", - operation, table, strings.Join(fields, ","), - strings.Join(values, ","), - updatestr), - params... - ) + if link == nil { + if link, err = bs.db.Master(); err != nil { + return nil, err + } + } + return bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s", + operation, table, strings.Join(fields, ","), + strings.Join(values, ","), updatestr), + params...) } -// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 -func (db *dbBase) Insert(table string, data Map) (sql.Result, error) { - return db.insert(table, data, OPTION_INSERT) +// CURD操作:批量数据指定批次量写入 +func (bs *dbBase) BatchInsert(table string, list List, batch int) (sql.Result, error) { + return bs.db.doBatchInsert(nil, table, list, batch, OPTION_INSERT) } -// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 -func (db *dbBase) Replace(table string, data Map) (sql.Result, error) { - return db.insert(table, data, OPTION_REPLACE) +// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 +func (bs *dbBase) BatchReplace(table string, list List, batch int) (sql.Result, error) { + return bs.db.doBatchInsert(nil, table, list, batch, OPTION_REPLACE) } -// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 -func (db *dbBase) Save(table string, data Map) (sql.Result, error) { - return db.insert(table, data, OPTION_SAVE) +// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 +func (bs *dbBase) BatchSave(table string, list List, batch int) (sql.Result, error) { + return bs.db.doBatchInsert(nil, table, list, batch, OPTION_SAVE) } // 批量写入数据 -func (db *dbBase) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) { +func (bs *dbBase) doBatchInsert(link dbLink, table string, list List, batch int, option int) (result sql.Result, err error) { var keys []string var values []string var bvalues []string var params []interface{} - var result sql.Result - var size = len(list) // 判断长度 - if size < 1 { + if len(list) < 1 { return result, errors.New("empty data list") } + if link == nil { + if link, err = bs.db.Master(); err != nil { + return + } + } // 首先获取字段名称及记录长度 for k, _ := range list[0] { keys = append(keys, k) values = append(values, "?") } - charl, charr := db.db.getChars() + charl, charr := bs.db.getChars() keyStr := charl + strings.Join(keys, charl + "," + charr) + charr valueHolderStr := "(" + strings.Join(values, ",") + ")" // 操作判断 @@ -283,13 +342,13 @@ func (db *dbBase) batchInsert(table string, list List, batch int, option uint8) updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ",")) } // 构造批量写入数据格式(注意map的遍历是无序的) - for i := 0; i < size; i++ { + for i := 0; i < len(list); i++ { for _, k := range keys { params = append(params, list[i][k]) } bvalues = append(bvalues, valueHolderStr) if len(bvalues) == batch { - r, err := db.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s", + r, err := bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES%s %s", operation, table, keyStr, strings.Join(bvalues, ","), updatestr), params...) @@ -303,7 +362,7 @@ func (db *dbBase) batchInsert(table string, list List, batch int, option uint8) } // 处理最后不构成指定批量的数据 if len(bvalues) > 0 { - r, err := db.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s", + r, err := bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES%s %s", operation, table, keyStr, strings.Join(bvalues, ","), updatestr), params...) @@ -315,27 +374,22 @@ func (db *dbBase) batchInsert(table string, list List, batch int, option uint8) return result, nil } -// CURD操作:批量数据指定批次量写入 -func (db *dbBase) BatchInsert(table string, list List, batch int) (sql.Result, error) { - return db.batchInsert(table, list, batch, OPTION_INSERT) -} - -// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 -func (db *dbBase) BatchReplace(table string, list List, batch int) (sql.Result, error) { - return db.batchInsert(table, list, batch, OPTION_REPLACE) -} - -// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 -func (db *dbBase) BatchSave(table string, list List, batch int) (sql.Result, error) { - return db.batchInsert(table, list, batch, OPTION_SAVE) +// CURD操作:数据更新,统一采用sql预处理 +// data参数支持字符串或者关联数组类型,内部会自行做判断处理 +func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { + link, err := bs.db.Master() + if err != nil { + return nil, err + } + return bs.db.doUpdate(link, table, data, condition, args ...) } // CURD操作:数据更新,统一采用sql预处理 // data参数支持字符串或者关联数组类型,内部会自行做判断处理 -func (db *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { - var params []interface{} - var updates string - charl, charr := db.db.getChars() +func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition interface{}, args ...interface{}) (result sql.Result, err error) { + params := ([]interface{})(nil) + updates := "" + charl, charr := bs.db.getChars() refValue := reflect.ValueOf(data) if refValue.Kind() == reflect.Map { var fields []string @@ -351,44 +405,29 @@ func (db *dbBase) Update(table string, data interface{}, condition interface{}, for _, v := range args { params = append(params, gconv.String(v)) } - return db.Exec(fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, db.formatCondition(condition)), params...) + if link == nil { + if link, err = bs.db.Master(); err != nil { + return nil, err + } + } + return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, formatCondition(condition)), params...) } // CURD操作:删除数据 -func (db *dbBase) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) { - return db.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", table, db.formatCondition(condition)), args...) +func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) { + link, err := bs.db.Master() + if err != nil { + return nil, err + } + return bs.db.doDelete(link, table, condition, args ...) } -// 格式化SQL查询条件 -func (db *dbBase) formatCondition(condition interface{}) (where string) { - if reflect.ValueOf(condition).Kind() == reflect.Map { - ks := reflect.ValueOf(condition).MapKeys() - vs := reflect.ValueOf(condition) - for _, k := range ks { - key := gconv.String(k.Interface()) - value := gconv.String(vs.MapIndex(k).Interface()) - isNum := gstr.IsNumeric(value) - if len(where) > 0 { - where += " AND " - } - if isNum || value == "?" { - where += key + "=" + value - } else { - where += key + "='" + value + "'" - } - } - } else { - where += gconv.String(condition) - } - return +// CURD操作:删除数据 +func (bs *dbBase) doDelete(link dbLink, table string, condition interface{}, args ...interface{}) (result sql.Result, err error) { + return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, formatCondition(condition)), args...) } // 获得缓存对象 -func (db *dbBase) getCache() *gcache.Cache { - return db.cache +func (bs *dbBase) getCache() *gcache.Cache { + return bs.cache } - -// 记录执行的SQL -func (db *dbBase) putSql(s *Sql) { - db.sqls.Put(s) -} \ No newline at end of file diff --git a/g/database/gdb/gdb_config.go b/g/database/gdb/gdb_config.go index a4c9fb312..10ca88dec 100644 --- a/g/database/gdb/gdb_config.go +++ b/g/database/gdb/gdb_config.go @@ -123,19 +123,19 @@ func SetDefaultGroup (groupName string) { } // 设置数据库连接池中空闲链接的大小 -func (db *dbBase) SetMaxIdleConns(n int) { - db.maxIdleConnCount.Set(n) +func (bs *dbBase) SetMaxIdleConns(n int) { + bs.maxIdleConnCount.Set(n) } // 设置数据库连接池最大打开的链接数量 -func (db *dbBase) SetMaxOpenConns(n int) { - db.maxOpenConnCount.Set(n) +func (bs *dbBase) SetMaxOpenConns(n int) { + bs.maxOpenConnCount.Set(n) } // 设置数据库连接可重复利用的时间,超过该时间则被关闭废弃 // 如果 d <= 0 表示该链接会一直重复利用 -func (db *dbBase) SetConnMaxLifetime(n int) { - db.maxConnLifetime.Set(n) +func (bs *dbBase) SetConnMaxLifetime(n int) { + bs.maxConnLifetime.Set(n) } // 节点配置转换为字符串 @@ -150,14 +150,14 @@ func (node *ConfigNode) String() string { } // 是否开启调试服务 -func (db *dbBase) SetDebug(debug bool) { - db.debug.Set(debug) - if debug && db.sqls == nil { - db.sqls = gring.New(gDEFAULT_DEBUG_SQL_LENGTH) +func (bs *dbBase) SetDebug(debug bool) { + bs.debug.Set(debug) + if debug && bs.sqls == nil { + bs.sqls = gring.New(gDEFAULT_DEBUG_SQL_LENGTH) } } // 获取是否开启调试服务 -func (db *dbBase) getDebug() bool { - return db.debug.Val() +func (bs *dbBase) getDebug() bool { + return bs.debug.Val() } \ No newline at end of file diff --git a/g/database/gdb/gdb_func.go b/g/database/gdb/gdb_func.go index e6e8f103c..1d782c81d 100644 --- a/g/database/gdb/gdb_func.go +++ b/g/database/gdb/gdb_func.go @@ -13,65 +13,12 @@ import ( "gitee.com/johng/gf/g/container/gvar" "gitee.com/johng/gf/g/os/glog" "gitee.com/johng/gf/g/os/gtime" + "gitee.com/johng/gf/g/util/gconv" + "gitee.com/johng/gf/g/util/gstr" _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql" - "strings" + "reflect" ) -type dbLink interface { - Query(query string, args ...interface{}) (*sql.Rows, error) - Exec(sql string, args ...interface{}) (sql.Result, error) - Prepare(sql string) (*sql.Stmt, error) -} - -// 数据库sql查询操作,主要执行查询 -func doQuery(db DB, link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) { - query = db.handleSqlBeforeExec(query) - if db.getDebug() { - mTime1 := gtime.Millisecond() - rows, err = link.Query(query, args...) - mTime2 := gtime.Millisecond() - s := &Sql{ - Sql : query, - Args : args, - Error : err, - Start : mTime1, - End : mTime2, - } - db.putSql(s) - printSql(s) - } else { - rows, err = link.Query(query, args ...) - } - if err == nil { - return rows, nil - } else { - err = formatError(err, query, args...) - } - return nil, err -} - -// 执行一条sql,并返回执行情况,主要用于非查询操作 -func doExec(db DB, link dbLink, query string, args ...interface{}) (result sql.Result, err error) { - query = db.handleSqlBeforeExec(query) - if db.getDebug() { - mTime1 := gtime.Millisecond() - result, err = link.Exec(query, args ...) - mTime2 := gtime.Millisecond() - s := &Sql{ - Sql : query, - Args : args, - Error : err, - Start : mTime1, - End : mTime2, - } - db.putSql(s) - printSql(s) - } else { - result, err = link.Exec(query, args ...) - } - return result, formatError(err, query, args...) -} - // 将数据查询的列表数据*sql.Rows转换为Result类型 func rowsToResult(rows *sql.Rows) (Result, error) { // 列名称列表 @@ -103,34 +50,31 @@ func rowsToResult(rows *sql.Rows) (Result, error) { return records, nil } -func formatInsertQuery(db DB, table string, data Map, option uint8) (string, []interface{}) { - var fields []string - var values []string - var params []interface{} - charl, charr := db.getChars() - for k, v := range data { - fields = append(fields, charl + k + charr) - values = append(values, "?") - params = append(params, v) - } - operation := getInsertOperationByOption(option) - updatestr := "" - if option == OPTION_SAVE { - var updates []string - for k, _ := range data { - updates = append(updates, - fmt.Sprintf("%s%s%s=VALUES(%s%s%s)", - charl, k, charr, - charl, k, charr, - ), - ) +// 格式化SQL查询条件 +func formatCondition(condition interface{}) (where string) { + if reflect.ValueOf(condition).Kind() == reflect.Map { + ks := reflect.ValueOf(condition).MapKeys() + vs := reflect.ValueOf(condition) + for _, k := range ks { + key := gconv.String(k.Interface()) + value := gconv.String(vs.MapIndex(k).Interface()) + isNum := gstr.IsNumeric(value) + if len(where) > 0 { + where += " AND " + } + if isNum || value == "?" { + where += key + "=" + value + } else { + where += key + "='" + value + "'" + } } - updatestr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ",")) + } else { + where += gconv.String(condition) } - return fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s", - operation, table, strings.Join(fields, ","), - strings.Join(values, ","), updatestr), - params + if len(where) == 0 { + where = "1" + } + return } // 打印SQL对象(仅在debug=true时有效) @@ -163,7 +107,7 @@ func formatError(err error, query string, args ...interface{}) error { } // 根据insert选项获得操作名称 -func getInsertOperationByOption(option uint8) string { +func getInsertOperationByOption(option int) string { oper := "INSERT" switch option { case OPTION_REPLACE: diff --git a/g/database/gdb/gdb_model.go b/g/database/gdb/gdb_model.go index f15ea7562..8ce73c338 100644 --- a/g/database/gdb/gdb_model.go +++ b/g/database/gdb/gdb_model.go @@ -17,8 +17,8 @@ import ( // 数据库链式操作模型对象 type Model struct { - tx *Tx // 数据库事务对象 db DB // 数据库操作对象 + tx *TX // 数据库事务对象 tablesInit string // 初始化Model时的表名称(可以是多个) tables string // 数据库操作表 fields string // 操作字段 @@ -36,9 +36,9 @@ type Model struct { } // 链式操作,数据表字段,可支持多个表,以半角逗号连接 -func (db *dbBase) Table(tables string) (*Model) { - return &Model{ - db : db.db, +func (bs *dbBase) Table(tables string) (*Model) { + return &Model { + db : bs.db, tablesInit : tables, tables : tables, fields : "*", @@ -46,12 +46,12 @@ func (db *dbBase) Table(tables string) (*Model) { } // 链式操作,数据表字段,可支持多个表,以半角逗号连接 -func (db *dbBase) From(tables string) (*Model) { - return db.Table(tables) +func (bs *dbBase) From(tables string) (*Model) { + return bs.db.Table(tables) } // (事务)链式操作,数据表字段,可支持多个表,以半角逗号连接 -func (tx *Tx) Table(tables string) (*Model) { +func (tx *TX) Table(tables string) (*Model) { return &Model{ db : tx.db, tx : tx, @@ -61,7 +61,7 @@ func (tx *Tx) Table(tables string) (*Model) { } // (事务)链式操作,数据表字段,可支持多个表,以半角逗号连接 -func (tx *Tx) From(tables string) (*Model) { +func (tx *TX) From(tables string) (*Model) { return tx.Table(tables) } @@ -100,7 +100,7 @@ func (md *Model) Fields(fields string) (*Model) { // 链式操作,condition,支持string & gdb.Map func (md *Model) Where(where interface{}, args ...interface{}) (*Model) { - md.where = md.db.formatCondition(where) + md.where = formatCondition(where) md.whereArgs = append(md.whereArgs, args...) // 支持 Where("uid", 1)这种格式 if len(args) == 1 && strings.Index(md.where , "?") < 0 { @@ -111,14 +111,14 @@ func (md *Model) Where(where interface{}, args ...interface{}) (*Model) { // 链式操作,添加AND条件到Where中 func (md *Model) And(where interface{}, args ...interface{}) (*Model) { - md.where += " AND " + md.db.formatCondition(where) + md.where += " AND " + formatCondition(where) md.whereArgs = append(md.whereArgs, args...) return md } // 链式操作,添加OR条件到Where中 func (md *Model) Or(where interface{}, args ...interface{}) (*Model) { - md.where += " OR " + md.db.formatCondition(where) + md.where += " OR " + formatCondition(where) md.whereArgs = append(md.whereArgs, args...) return md } @@ -220,9 +220,9 @@ func (md *Model) Replace() (result sql.Result, err error) { } } else if dataMap, ok := md.data.(Map); ok { if md.tx == nil { - return md.db.Insert(md.tables, dataMap) + return md.db.Replace(md.tables, dataMap) } else { - return md.tx.Insert(md.tables, dataMap) + return md.tx.Replace(md.tables, dataMap) } } return nil, errors.New("replacing into table with invalid data type") @@ -286,9 +286,6 @@ func (md *Model) Delete() (result sql.Result, err error) { } md.clear() }() - if md.where == "" { - return nil, errors.New("where is required while deleting") - } if md.tx == nil { return md.db.Delete(md.tables, md.where, md.whereArgs...) } else { @@ -320,13 +317,13 @@ func (md *Model) Cache(time int, name ... string) *Model { // 链式操作,select func (md *Model) Select() (Result, error) { - defer md.clear() - return md.getAll(md.getFormattedSql(), md.whereArgs...) + return md.All() } // 链式操作,查询所有记录 func (md *Model) All() (Result, error) { - return md.Select() + defer md.clear() + return md.getAll(md.getFormattedSql(), md.whereArgs...) } // 链式操作,查询单条记录 diff --git a/g/database/gdb/gdb_mssql.go b/g/database/gdb/gdb_mssql.go index 5532b1b40..ba1f7a8ab 100644 --- a/g/database/gdb/gdb_mssql.go +++ b/g/database/gdb/gdb_mssql.go @@ -28,12 +28,13 @@ type dbMssql struct { } // 创建SQL操作对象 -func (db *dbMssql) open(c *ConfigNode) (*sql.DB, error) { - var source string - if c.Linkinfo != "" { - source = c.Linkinfo +func (db *dbMssql) Open(config *ConfigNode) (*sql.DB, error) { + source := "" + if config.Linkinfo != "" { + source = config.Linkinfo } else { - source = fmt.Sprintf("user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable", c.User, c.Pass, c.Host, c.Port, c.Name) + source = fmt.Sprintf("user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable", + config.User, config.Pass, config.Host, config.Port, config.Name) } if db, err := sql.Open("sqlserver", source); err == nil { return db, nil diff --git a/g/database/gdb/gdb_mysql.go b/g/database/gdb/gdb_mysql.go index 0c6283648..5ddfcc3c7 100644 --- a/g/database/gdb/gdb_mysql.go +++ b/g/database/gdb/gdb_mysql.go @@ -18,12 +18,12 @@ type dbMysql struct { } // 创建SQL操作对象,内部采用了lazy link处理 -func (db *dbMysql) open (c *ConfigNode) (*sql.DB, error) { +func (db *dbMysql) Open (config *ConfigNode) (*sql.DB, error) { var source string - if c.Linkinfo != "" { - source = c.Linkinfo + if config.Linkinfo != "" { + source = config.Linkinfo } else { - source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", c.User, c.Pass, c.Host, c.Port, c.Name) + source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?multiStatements=true", config.User, config.Pass, config.Host, config.Port, config.Name) } if db, err := sql.Open("mysql", source); err == nil { return db, nil diff --git a/g/database/gdb/gdb_oracle.go b/g/database/gdb/gdb_oracle.go index 962be5959..965a4f5c1 100644 --- a/g/database/gdb/gdb_oracle.go +++ b/g/database/gdb/gdb_oracle.go @@ -27,12 +27,12 @@ type dbOracle struct { } // 创建SQL操作对象 -func (db *dbOracle) open(c *ConfigNode) (*sql.DB, error) { +func (db *dbOracle) Open(config *ConfigNode) (*sql.DB, error) { var source string - if c.Linkinfo != "" { - source = c.Linkinfo + if config.Linkinfo != "" { + source = config.Linkinfo } else { - source = fmt.Sprintf("%s/%s@%s", c.User, c.Pass, c.Name) + source = fmt.Sprintf("%s/%s@%s", config.User, config.Pass, config.Name) } if db, err := sql.Open("oci8", source); err == nil { return db, nil diff --git a/g/database/gdb/gdb_pgsql.go b/g/database/gdb/gdb_pgsql.go index 874263465..13fc77cad 100644 --- a/g/database/gdb/gdb_pgsql.go +++ b/g/database/gdb/gdb_pgsql.go @@ -24,12 +24,12 @@ type dbPgsql struct { } // 创建SQL操作对象,内部采用了lazy link处理 -func (db *dbPgsql) open (c *ConfigNode) (*sql.DB, error) { +func (db *dbPgsql) Open (config *ConfigNode) (*sql.DB, error) { var source string - if c.Linkinfo != "" { - source = c.Linkinfo + if config.Linkinfo != "" { + source = config.Linkinfo } else { - source = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s", c.User, c.Pass, c.Host, c.Port, c.Name) + source = fmt.Sprintf("user=%s password=%s host=%s port=%s dbname=%s", config.User, config.Pass, config.Host, config.Port, config.Name) } if db, err := sql.Open("postgres", source); err == nil { return db, nil diff --git a/g/database/gdb/gdb_sqlite.go b/g/database/gdb/gdb_sqlite.go index 68810c517..70029bdc8 100644 --- a/g/database/gdb/gdb_sqlite.go +++ b/g/database/gdb/gdb_sqlite.go @@ -22,12 +22,12 @@ type dbSqlite struct { *dbBase } -func (db *dbSqlite) open(c *ConfigNode) (*sql.DB, error) { +func (db *dbSqlite) Open(config *ConfigNode) (*sql.DB, error) { var source string - if c.Linkinfo != "" { - source = c.Linkinfo + if config.Linkinfo != "" { + source = config.Linkinfo } else { - source = c.Name + source = config.Name } if db, err := sql.Open("sqlite3", source); err == nil { return db, nil diff --git a/g/database/gdb/gdb_transaction.go b/g/database/gdb/gdb_transaction.go index cd8e48976..b513189db 100644 --- a/g/database/gdb/gdb_transaction.go +++ b/g/database/gdb/gdb_transaction.go @@ -7,15 +7,9 @@ package gdb import ( - "fmt" - "errors" - "strings" - "reflect" "database/sql" - "gitee.com/johng/gf/g/os/gtime" - "gitee.com/johng/gf/g/util/gconv" + "gitee.com/johng/gf/g/util/gregex" _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql" - "gitee.com/johng/gf/g/container/gvar" ) // 数据库事务对象 @@ -37,49 +31,27 @@ func (tx *TX) Rollback() error { // (事务)数据库sql查询操作,主要执行查询 func (tx *TX) Query(query string, args ...interface{}) (rows *sql.Rows, err error) { - return doQuery(tx.db, tx.tx, query, args...) + return tx.db.doQuery(tx.tx, query, args...) } // (事务)执行一条sql,并返回执行情况,主要用于非查询操作 func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) { - return doExec(tx.db, tx.tx, query, args...) + return tx.db.doExec(tx.tx, query, args...) +} + +// sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作 +func (tx *TX) Prepare(query string) (*sql.Stmt, error) { + return tx.db.doPrepare(tx.tx, query) } // 数据库查询,获取查询结果集,以列表结构返回 func (tx *TX) GetAll(query string, args ...interface{}) (Result, error) { - // 执行sql rows, err := tx.Query(query, args ...) if err != nil || rows == nil { return nil, err } - // 列名称列表 - columns, err := rows.Columns() - if err != nil { - return nil, err - } - // 返回结构组装 - values := make([]sql.RawBytes, len(columns)) - scanArgs := make([]interface{}, len(values)) - records := make(Result, 0) - for i := range values { - scanArgs[i] = &values[i] - } - for rows.Next() { - err = rows.Scan(scanArgs...) - if err != nil { - return records, err - } - row := make(Record) - // 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址 - for i, col := range values { - v := make([]byte, len(col)) - copy(v, col) - row[columns[i]] = gvar.New(v, false) - } - //fmt.Printf("%p\n", row["typeid"]) - records = append(records, row) - } - return records, nil + defer rows.Close() + return rowsToResult(rows) } // 数据库查询,获取查询结果记录,以关联数组结构返回 @@ -103,7 +75,6 @@ func (tx *TX) GetStruct(obj interface{}, query string, args ...interface{}) erro return one.ToStruct(obj) } - // 数据库查询,获取查询字段值 func (tx *TX) GetValue(query string, args ...interface{}) (Value, error) { one, err := tx.GetOne(query, args ...) @@ -118,185 +89,54 @@ func (tx *TX) GetValue(query string, args ...interface{}) (Value, error) { // 数据库查询,获取查询数量 func (tx *TX) GetCount(query string, args ...interface{}) (int, error) { - val, err := tx.GetValue(query, args ...) + if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) { + query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query) + } + value, err := tx.GetValue(query, args ...) if err != nil { return 0, err } - return gconv.Int(val), nil -} - -// 数据表查询,其中tables可以是多个联表查询语句,这种查询方式较复杂,建议使用链式操作 -func (tx *TX) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) { - s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables) - if condition != nil { - s += fmt.Sprintf("WHERE %s ", tx.db.formatCondition(condition)) - } - if len(groupBy) > 0 { - s += fmt.Sprintf("GROUP BY %s ", groupBy) - } - if len(orderBy) > 0 { - s += fmt.Sprintf("ORDER BY %s ", orderBy) - } - if limit > 0 { - s += fmt.Sprintf("LIMIT %d,%d ", first, limit) - } - return tx.GetAll(s, args ... ) -} - -// sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作 -func (tx *TX) Prepare(query string) (*sql.Stmt, error) { - return tx.tx.Prepare(query) -} - -// insert、replace, save, ignore操作 -// 0: insert: 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 -// 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 -// 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 -// 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做 -func (tx *TX) insert(table string, data Map, option uint8) (sql.Result, error) { - var keys []string - var values []string - var params []interface{} - for k, v := range data { - keys = append(keys, tx.db.charl + k + tx.db.charr) - values = append(values, "?") - params = append(params, v) - } - operation := tx.db.getInsertOperationByOption(option) - updatestr := "" - if option == OPTION_SAVE { - var updates []string - for k, _ := range data { - updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s)", tx.db.charl, k, tx.db.charr, k)) - } - updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ",")) - } - return tx.Exec( - fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s", - operation, table, strings.Join(keys, ","), - strings.Join(values, ","), - updatestr), - params... - ) + return value.Int(), nil } // CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 func (tx *TX) Insert(table string, data Map) (sql.Result, error) { - return tx.insert(table, data, OPTION_INSERT) + return tx.db.doInsert(tx.tx, table, data, OPTION_INSERT) } // CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 func (tx *TX) Replace(table string, data Map) (sql.Result, error) { - return tx.insert(table, data, OPTION_REPLACE) + return tx.db.doInsert(tx.tx, table, data, OPTION_REPLACE) } // CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 func (tx *TX) Save(table string, data Map) (sql.Result, error) { - return tx.insert(table, data, OPTION_SAVE) -} - -// 批量写入数据 -func (tx *TX) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) { - var keys []string - var values []string - var bvalues []string - var params []interface{} - var result sql.Result - var size = len(list) - // 判断长度 - if size < 1 { - return result, errors.New("empty data list") - } - // 首先获取字段名称及记录长度 - for k, _ := range list[0] { - keys = append(keys, k) - values = append(values, "?") - } - keyStr := tx.db.charl + strings.Join(keys, tx.db.charl + "," + tx.db.charr) + tx.db.charr - valueHolderStr := "(" + strings.Join(values, ",") + ")" - // 操作判断 - operation := tx.db.getInsertOperationByOption(option) - updatestr := "" - if option == OPTION_SAVE { - var updates []string - for _, k := range keys { - updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s)", tx.db.charl, k, tx.db.charr, k)) - } - updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ",")) - } - // 构造批量写入数据格式(注意map的遍历是无序的) - for i := 0; i < size; i++ { - for _, k := range keys { - params = append(params, list[i][k]) - } - bvalues = append(bvalues, valueHolderStr) - if len(bvalues) == batch { - r, err := tx.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s", - operation, table, keyStr, strings.Join(bvalues, ","), - updatestr), - params...) - if err != nil { - return result, err - } - result = r - params = params[:0] - bvalues = bvalues[:0] - } - } - // 处理最后不构成指定批量的数据 - if len(bvalues) > 0 { - r, err := tx.Exec(fmt.Sprintf("%s INTO %s(%s) VALUES%s %s", - operation, table, keyStr, strings.Join(bvalues, ","), - updatestr), - params...) - if err != nil { - return result, err - } - result = r - } - return result, nil + return tx.db.doInsert(tx.tx, table, data, OPTION_SAVE) } // CURD操作:批量数据指定批次量写入 func (tx *TX) BatchInsert(table string, list List, batch int) (sql.Result, error) { - return tx.batchInsert(table, list, batch, OPTION_INSERT) + return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_INSERT) } // CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 func (tx *TX) BatchReplace(table string, list List, batch int) (sql.Result, error) { - return tx.batchInsert(table, list, batch, OPTION_REPLACE) + return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_REPLACE) } // CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 func (tx *TX) BatchSave(table string, list List, batch int) (sql.Result, error) { - return tx.batchInsert(table, list, batch, OPTION_SAVE) + return tx.db.doBatchInsert(tx.tx, table, list, batch, OPTION_SAVE) } // CURD操作:数据更新,统一采用sql预处理 // data参数支持字符串或者关联数组类型,内部会自行做判断处理 func (tx *TX) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { - var params []interface{} - var updates string - refValue := reflect.ValueOf(data) - if refValue.Kind() == reflect.Map { - var fields []string - keys := refValue.MapKeys() - for _, k := range keys { - fields = append(fields, fmt.Sprintf("%s%s%s=?", tx.db.charl, k, tx.db.charr)) - params = append(params, gconv.String(refValue.MapIndex(k).Interface())) - updates = strings.Join(fields, ",") - } - } else { - updates = gconv.String(data) - } - for _, v := range args { - params = append(params, gconv.String(v)) - } - return tx.Exec(fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, tx.db.formatCondition(condition)), params...) + return tx.db.doUpdate(tx.tx, table, data, condition, args ...) } // CURD操作:删除数据 func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) { - return tx.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", table, tx.db.formatCondition(condition)), args...) + return tx.db.doDelete(tx.tx, table, condition, args ...) } diff --git a/g/database/gdb/gdb_unit_1_test.go b/g/database/gdb/gdb_unit_1_test.go new file mode 100644 index 000000000..1755bf63b --- /dev/null +++ b/g/database/gdb/gdb_unit_1_test.go @@ -0,0 +1,215 @@ +package gdb_test + +import ( + "gitee.com/johng/gf/g" + "gitee.com/johng/gf/g/database/gdb" + "gitee.com/johng/gf/g/os/gtime" + "gitee.com/johng/gf/g/util/gtest" + "testing" +) + +var ( + // 数据库对象/接口 + db gdb.DB +) + +// 初始化连接参数。 +// 测试前需要修改连接参数。 +func init() { + gdb.AddDefaultConfigNode(gdb.ConfigNode{ + Host: "127.0.0.1", + Port: "3306", + User: "root", + Pass: "12345678", + Name: "test", + Type: "mysql", + Role: "master", + Charset: "utf8", + Priority: 1, + }) + if r, err := gdb.New(); err != nil { + gtest.Fatal(err) + } else { + db = r + } + // 准备测试数据结构 + if _, err := db.Exec("DROP TABLE `user`"); err != nil { + gtest.Fatal(err) + } + if _, err := db.Exec(` + CREATE TABLE user ( + id int(10) unsigned NOT NULL AUTO_INCREMENT COMMENT '用户ID', + passport varchar(45) NOT NULL COMMENT '账号', + password char(32) NOT NULL COMMENT '密码', + nickname varchar(45) NOT NULL COMMENT '昵称', + create_time timestamp NOT NULL COMMENT '创建时间/注册时间', + PRIMARY KEY (id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8; + `); err != nil { + gtest.Fatal(err) + } +} + +func TestDbBase_Query(t *testing.T) { + if _, err := db.Query("SELECT ?", 1); err != nil { + gtest.Fatal(err) + } + if _, err := db.Query("ERROR"); err == nil { + gtest.Fatal("FAIL") + } +} + +func TestDbBase_Exec(t *testing.T) { + if _, err := db.Exec("SELECT ?", 1); err != nil { + gtest.Fatal(err) + } + if _, err := db.Exec("ERROR"); err == nil { + gtest.Fatal("FAIL") + } +} + +func TestDbBase_Prepare(t *testing.T) { + st, err := db.Prepare("SELECT 100") + if err != nil { + gtest.Fatal(err) + } + rows, err := st.Query() + if err != nil { + gtest.Fatal(err) + } + array, err := rows.Columns() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(array[0], "100") + if err := rows.Close(); err != nil { + gtest.Fatal(err) + } +} + +func TestDbBase_Insert(t *testing.T) { + if _, err := db.Insert("user", g.Map{ + "id" : 1, + "passport" : "t1", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T1", + "create_time" : gtime.Now().String(), + }); err != nil { + gtest.Fatal(err) + } +} + +func TestDbBase_BatchInsert(t *testing.T) { + if _, err := db.BatchInsert("user", g.List { + { + "id" : 2, + "passport" : "t2", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T2", + "create_time" : gtime.Now().String(), + }, + { + "id" : 3, + "passport" : "t3", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T3", + "create_time" : gtime.Now().String(), + }, + }, 10); err != nil { + gtest.Fatal(err) + } +} + +func TestDbBase_Save(t *testing.T) { + if _, err := db.Save("user", g.Map{ + "id" : 1, + "passport" : "t1", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T11", + "create_time" : gtime.Now().String(), + }); err != nil { + gtest.Fatal(err) + } +} + +func TestDbBase_Replace(t *testing.T) { + if _, err := db.Save("user", g.Map{ + "id" : 1, + "passport" : "t1", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T111", + "create_time" : gtime.Now().String(), + }); err != nil { + gtest.Fatal(err) + } +} + +func TestDbBase_Update(t *testing.T) { + if result, err := db.Update("user", "create_time='2010-10-10 00:00:01'", "id=3"); err != nil { + gtest.Fatal(err) + } else { + n, _ := result.RowsAffected() + gtest.Assert(n, 1) + } +} + +func TestDbBase_GetAll(t *testing.T) { + if result, err := db.GetAll("SELECT * FROM user WHERE id=?", 1); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(len(result), 1) + } +} + +func TestDbBase_GetOne(t *testing.T) { + if record, err := db.GetOne("SELECT * FROM user WHERE passport=?", "t1"); err != nil { + gtest.Fatal(err) + } else { + if record == nil { + gtest.Fatal("FAIL") + } + gtest.Assert(record["nickname"].String(), "T111") + } +} + +func TestDbBase_GetValue(t *testing.T) { + if value, err := db.GetValue("SELECT id FROM user WHERE passport=?", "t3"); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(value.Int(), 3) + } +} + +func TestDbBase_GetCount(t *testing.T) { + if count, err := db.GetCount("SELECT * FROM user"); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(count, 3) + } +} + +func TestDbBase_GetStruct(t *testing.T) { + type User struct { + Id int + Passport string + Password string + NickName string + CreateTime gtime.Time + } + user := new(User) + if err := db.GetStruct(user, "SELECT * FROM user WHERE id=?", 3); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(user.CreateTime.String(), "2010-10-10 00:00:01") + } +} + +func TestDbBase_Delete(t *testing.T) { + if result, err := db.Delete("user", nil); err != nil { + gtest.Fatal(err) + } else { + n, _ := result.RowsAffected() + gtest.Assert(n, 3) + } +} + diff --git a/g/database/gdb/gdb_unit_2_test.go b/g/database/gdb/gdb_unit_2_test.go new file mode 100644 index 000000000..48bdb8e77 --- /dev/null +++ b/g/database/gdb/gdb_unit_2_test.go @@ -0,0 +1,177 @@ +package gdb_test + +import ( + "gitee.com/johng/gf/g" + "gitee.com/johng/gf/g/os/gtime" + "gitee.com/johng/gf/g/util/gtest" + "testing" +) + +func TestModel_Insert(t *testing.T) { + result, err := db.Table("user").Data(g.Map{ + "id" : 1, + "passport" : "t1", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T1", + "create_time" : gtime.Now().String(), + }).Insert() + if err != nil { + gtest.Fatal(err) + } + n, _ := result.LastInsertId() + gtest.Assert(n, 1) +} + +func TestModel_Batch(t *testing.T) { + result, err := db.Table("user").Data(g.List{ + { + "id" : 2, + "passport" : "t2", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T2", + "create_time" : gtime.Now().String(), + }, + { + "id" : 3, + "passport" : "t3", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T3", + "create_time" : gtime.Now().String(), + }, + }).Batch(10).Insert() + if err != nil { + gtest.Fatal(err) + } + n, _ := result.RowsAffected() + gtest.Assert(n, 2) +} + +func TestModel_Replace(t *testing.T) { + result, err := db.Table("user").Data(g.Map{ + "id" : 1, + "passport" : "t11", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T11", + "create_time" : gtime.Now().String(), + }).Replace() + if err != nil { + gtest.Fatal(err) + } + n, _ := result.RowsAffected() + gtest.Assert(n, 2) +} + +func TestModel_Save(t *testing.T) { + result, err := db.Table("user").Data(g.Map{ + "id" : 1, + "passport" : "t111", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T111", + "create_time" : gtime.Now().String(), + }).Save() + if err != nil { + gtest.Fatal(err) + } + n, _ := result.RowsAffected() + gtest.Assert(n, 2) +} + +func TestModel_Update(t *testing.T) { + result, err := db.Table("user").Data("passport", "t22").Where("passport=?", "t2").Update() + if err != nil { + gtest.Fatal(err) + } + n, _ := result.RowsAffected() + gtest.Assert(n, 1) +} + +func TestModel_All(t *testing.T) { + result, err := db.Table("user").All() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(len(result), 3) +} + +func TestModel_One(t *testing.T) { + record, err := db.Table("user").Where("id", 1).One() + if err != nil { + gtest.Fatal(err) + } + if record == nil { + gtest.Fatal("FAIL") + } + gtest.Assert(record["nickname"].String(), "T111") +} + +func TestModel_Value(t *testing.T) { + value, err := db.Table("user").Fields("nickname").Where("id", 1).Value() + if err != nil { + gtest.Fatal(err) + } + if value == nil { + gtest.Fatal("FAIL") + } + gtest.Assert(value.String(), "T111") +} + +func TestModel_Count(t *testing.T) { + count, err := db.Table("user").Count() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(count, 3) +} + +func TestModel_Select(t *testing.T) { + result, err := db.Table("user").Select() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(len(result), 3) +} + +func TestModel_Struct(t *testing.T) { + type User struct { + Id int + Passport string + Password string + NickName string + CreateTime gtime.Time + } + user := new(User) + err := db.Table("user").Where("id=1").Struct(user) + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(user.NickName, "T111") +} + +func TestModel_OrderBy(t *testing.T) { + result, err := db.Table("user").OrderBy("id DESC").Select() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(len(result), 3) + gtest.Assert(result[0]["nickname"].String(), "T3") +} + +func TestModel_GroupBy(t *testing.T) { + result, err := db.Table("user").GroupBy("id").Select() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(len(result), 3) + gtest.Assert(result[0]["nickname"].String(), "T111") +} + +func TestModel_Delete(t *testing.T) { + result, err := db.Table("user").Delete() + if err != nil { + gtest.Fatal(err) + } + n, _ := result.RowsAffected() + gtest.Assert(n, 3) +} + + diff --git a/g/database/gdb/gdb_unit_3_test.go b/g/database/gdb/gdb_unit_3_test.go new file mode 100644 index 000000000..3180c906d --- /dev/null +++ b/g/database/gdb/gdb_unit_3_test.go @@ -0,0 +1,372 @@ +package gdb_test + +import ( + "gitee.com/johng/gf/g" + "gitee.com/johng/gf/g/os/gtime" + "gitee.com/johng/gf/g/util/gtest" + "testing" +) + +func TestTX_Query(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if rows, err := tx.Query("SELECT ?", 1); err != nil { + gtest.Fatal(err) + } else { + rows.Close() + } + if _, err := tx.Query("ERROR"); err == nil { + gtest.Fatal("FAIL") + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_Exec(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.Exec("SELECT ?", 1); err != nil { + gtest.Fatal(err) + } + if _, err := tx.Exec("ERROR"); err == nil { + gtest.Fatal("FAIL") + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_Commit(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_Rollback(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if err := tx.Rollback(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_Prepare(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + st, err := tx.Prepare("SELECT 100") + if err != nil { + gtest.Fatal(err) + } + rows, err := st.Query() + if err != nil { + gtest.Fatal(err) + } + array, err := rows.Columns() + if err != nil { + gtest.Fatal(err) + } + gtest.Assert(array[0], "100") + if err := rows.Close(); err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_Insert(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.Insert("user", g.Map { + "id" : 1, + "passport" : "t1", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T1", + "create_time" : gtime.Now().String(), + }); err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } + if n, err := db.Table("user").Count(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(n, 1) + } +} + +func TestTX_BatchInsert(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.BatchInsert("user", g.List { + { + "id" : 2, + "passport" : "t", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T2", + "create_time" : gtime.Now().String(), + }, + { + "id" : 3, + "passport" : "t3", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T3", + "create_time" : gtime.Now().String(), + }, + }, 10); err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } + if n, err := db.Table("user").Count(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(n, 3) + } +} + +func TestTX_BatchReplace(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.BatchReplace("user", g.List { + { + "id" : 2, + "passport" : "t2", + "password" : "p2", + "nickname" : "T2", + "create_time" : gtime.Now().String(), + }, + { + "id" : 4, + "passport" : "t4", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T4", + "create_time" : gtime.Now().String(), + }, + }, 10); err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } + // 数据数量 + if n, err := db.Table("user").Count(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(n, 4) + } + // 检查replace后的数值 + if value, err := db.Table("user").Fields("password").Where("id", 2).Value(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(value.String(), "p2") + } +} + +func TestTX_BatchSave(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.BatchSave("user", g.List { + { + "id" : 4, + "passport" : "t4", + "password" : "p4", + "nickname" : "T4", + "create_time" : gtime.Now().String(), + }, + }, 10); err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } + // 数据数量 + if n, err := db.Table("user").Count(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(n, 4) + } + // 检查replace后的数值 + if value, err := db.Table("user").Fields("password").Where("id", 4).Value(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(value.String(), "p4") + } +} + +func TestTX_Replace(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.Replace("user", g.Map { + "id" : 1, + "passport" : "t11", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T11", + "create_time" : gtime.Now().String(), + }); err != nil { + gtest.Fatal(err) + } + if err := tx.Rollback(); err != nil { + gtest.Fatal(err) + } + if value, err := db.Table("user").Fields("nickname").Where("id", 1).Value(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(value.String(), "T1") + } +} + +func TestTX_Save(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.Save("user", g.Map { + "id" : 1, + "passport" : "t11", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T11", + "create_time" : gtime.Now().String(), + }); err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } + if value, err := db.Table("user").Fields("nickname").Where("id", 1).Value(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(value.String(), "T11") + } +} + +func TestTX_GetAll(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if result, err := tx.GetAll("SELECT * FROM user WHERE id=?", 1); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(len(result), 1) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_GetOne(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if record, err := tx.GetOne("SELECT * FROM user WHERE passport=?", "t2"); err != nil { + gtest.Fatal(err) + } else { + if record == nil { + gtest.Fatal("FAIL") + } + gtest.Assert(record["nickname"].String(), "T2") + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_GetValue(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if value, err := tx.GetValue("SELECT id FROM user WHERE passport=?", "t3"); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(value.Int(), 3) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_GetCount(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if count, err := tx.GetCount("SELECT * FROM user"); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(count, 4) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_GetStruct(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + type User struct { + Id int + Passport string + Password string + NickName string + CreateTime gtime.Time + } + user := new(User) + if err := tx.GetStruct(user, "SELECT * FROM user WHERE id=?", 1); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(user.NickName, "T11") + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } +} + +func TestTX_Delete(t *testing.T) { + tx, err := db.Begin() + if err != nil { + gtest.Fatal(err) + } + if _, err := tx.Delete("user", nil); err != nil { + gtest.Fatal(err) + } + if err := tx.Commit(); err != nil { + gtest.Fatal(err) + } + if n, err := db.Table("user").Count(); err != nil { + gtest.Fatal(err) + } else { + gtest.Assert(n, 0) + } +} + + diff --git a/g/frame/gins/gins.go b/g/frame/gins/gins.go index 0c8ba867c..66598e940 100644 --- a/g/frame/gins/gins.go +++ b/g/frame/gins/gins.go @@ -120,7 +120,7 @@ func Config(file...string) *gcfg.Config { } // 数据库操作对象,使用了连接池 -func Database(name...string) *gdb.Db { +func Database(name...string) gdb.DB { config := Config() group := gdb.DEFAULT_GROUP_NAME if len(name) > 0 { @@ -195,7 +195,7 @@ func Database(name...string) *gdb.Db { return nil }) if db != nil { - return db.(*gdb.Db) + return db.(gdb.DB) } return nil } diff --git a/g/g_object.go b/g/g_object.go index ff298fdb4..012cf1a16 100644 --- a/g/g_object.go +++ b/g/g_object.go @@ -44,12 +44,12 @@ func Config(file...string) *gcfg.Config { } // 数据库操作对象,使用了连接池 -func Database(name...string) *gdb.Db { +func Database(name...string) gdb.DB { return gins.Database(name...) } // (别名)Database -func DB(name...string) *gdb.Db { +func DB(name...string) gdb.DB { return gins.Database(name...) } diff --git a/g/os/glog/glog_logger.go b/g/os/glog/glog_logger.go index 92959f2b1..4b022cf48 100644 --- a/g/os/glog/glog_logger.go +++ b/g/os/glog/glog_logger.go @@ -31,7 +31,7 @@ type Logger struct { file *gtype.String // 日志文件名称格式 level *gtype.Int // 日志输出等级 btSkip *gtype.Int // 错误产生时的backtrace回调信息skip条数 - btEnabled *gtype.Bool // 是否当打印错误时同时开启backtrace打印 + btStatus *gtype.Int // 是否当打印错误时同时开启backtrace打印(默认-1,表示默认打印逻辑 - 错误才打印) printHeader *gtype.Bool // 是否不打印前缀信息(时间,级别等) alsoStdPrint *gtype.Bool // 控制台打印开关,当输出到文件/自定义输出时也同时打印到终端 } @@ -65,7 +65,7 @@ func New() *Logger { file : gtype.NewString(gDEFAULT_FILE_FORMAT), level : gtype.NewInt(defaultLevel.Val()), btSkip : gtype.NewInt(), - btEnabled : gtype.NewBool(true), + btStatus : gtype.NewInt(-1), printHeader : gtype.NewBool(true), alsoStdPrint : gtype.NewBool(true), } @@ -80,7 +80,7 @@ func (l *Logger) Clone() *Logger { file : l.file.Clone(), level : l.level.Clone(), btSkip : l.btSkip.Clone(), - btEnabled : l.btEnabled.Clone(), + btStatus : l.btStatus.Clone(), printHeader : l.printHeader.Clone(), alsoStdPrint : l.alsoStdPrint.Clone(), } @@ -106,7 +106,12 @@ func (l *Logger) SetDebug(debug bool) { } func (l *Logger) SetBacktrace(enabled bool) { - l.btEnabled.Set(enabled) + if enabled { + l.btStatus.Set(1) + } else { + l.btStatus.Set(0) + } + } // 设置BacktraceSkip @@ -214,27 +219,37 @@ func (l *Logger) doStdLockPrint(std io.Writer, s string) { // 核心打印数据方法(标准输出) func (l *Logger) stdPrint(s string) { + if l.btStatus.Val() == 1 { + s = l.appendBacktrace(s) + } l.print(os.Stdout, s) } // 核心打印数据方法(标准错误) func (l *Logger) errPrint(s string) { // 记录调用回溯信息 - if l.btEnabled.Val() { - tracestr := l.GetBacktrace() - if tracestr != "" { - backtrace := "Backtrace:" + ln + tracestr - if s[len(s) - 1] == byte('\n') { - s = s + backtrace + ln - } else { - s = s + ln + backtrace + ln - } - } + status := l.btStatus.Val() + if status == -1 || status == 1 { + s = l.appendBacktrace(s) } // 防止串日志情况,这里不使用stderr,而是使用stdout l.print(os.Stdout, s) } +// 输出内容中添加回溯信息 +func (l *Logger) appendBacktrace(s string) string { + trace := l.GetBacktrace() + if trace != "" { + backtrace := "Backtrace:" + ln + trace + if s[len(s) - 1] == byte('\n') { + s = s + backtrace + ln + } else { + s = s + ln + backtrace + ln + } + } + return s +} + // 直接打印回溯信息,参数skip表示调用端往上多少级开始回溯 func (l *Logger) PrintBacktrace(skip...int) { l.Println(l.GetBacktrace(skip...)) diff --git a/g/util/gconv/gconv_struct.go b/g/util/gconv/gconv_struct.go index 1261065f1..d0cd15dbf 100644 --- a/g/util/gconv/gconv_struct.go +++ b/g/util/gconv/gconv_struct.go @@ -157,12 +157,10 @@ func bindVarToStruct(elem reflect.Value, name string, value interface{}) (err er structFieldValue := elem.FieldByName(name) // 键名与对象属性匹配检测,map中如果有struct不存在的属性,那么不做处理,直接return if !structFieldValue.IsValid() { - //return errors.New(fmt.Sprintf(`invalid struct attribute of name "%s"`, name)) return nil } // CanSet的属性必须为公开属性(首字母大写) if !structFieldValue.CanSet() { - //return errors.New(fmt.Sprintf(`struct attribute of name "%s" cannot be set`, name)) return nil } // 必须将value转换为struct属性的数据类型,这里必须用到gconv包 @@ -181,12 +179,10 @@ func bindVarToStructByIndex(elem reflect.Value, index int, value interface{}) (e structFieldValue := elem.FieldByIndex([]int{index}) // 键名与对象属性匹配检测 if !structFieldValue.IsValid() { - //return errors.New(fmt.Sprintf("invalid struct attribute at index %d", index)) return nil } // CanSet的属性必须为公开属性(首字母大写) if !structFieldValue.CanSet() { - //return errors.New(fmt.Sprintf("struct attribute cannot be set at index %d", index)) return nil } // 必须将value转换为struct属性的数据类型,这里必须用到gconv包 diff --git a/g/util/gtest/gtest.go b/g/util/gtest/gtest.go new file mode 100644 index 000000000..2b054afa1 --- /dev/null +++ b/g/util/gtest/gtest.go @@ -0,0 +1,30 @@ +// Copyright 2018 gf Author(https://gitee.com/johng/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://gitee.com/johng/gf. + +// Package gtest provides useful test utils. +// 测试模块. +package gtest + +import ( + "fmt" + "gitee.com/johng/gf/g/os/glog" + "gitee.com/johng/gf/g/util/gconv" + "os" +) + +// 断言判断 +func Assert(value, expect interface{}) { + if gconv.String(value) != gconv.String(expect) { + glog.Backtrace(true, 1).Printfln(`[ASSERT] VALUE: %v, EXPECT: %v`, value, expect) + os.Exit(1) + } +} + +// 提示错误并退出 +func Fatal(message...interface{}) { + glog.Backtrace(true, 1).Println(`[FATAL] `, fmt.Sprint(message...)) + os.Exit(1) +} \ No newline at end of file diff --git a/geg/database/orm/mysql/gdb.go b/geg/database/orm/mysql/gdb.go index d951d1821..8c276459d 100644 --- a/geg/database/orm/mysql/gdb.go +++ b/geg/database/orm/mysql/gdb.go @@ -9,7 +9,7 @@ import ( // 本文件用于gf框架的mysql数据库操作示例,不作为单元测试使用 -var db *gdb.Db +var db gdb.DB // 初始化配置及创建数据库 func init () { @@ -17,7 +17,7 @@ func init () { Host : "127.0.0.1", Port : "3306", User : "root", - Pass : "8692651", + Pass : "12345678", Name : "test", Type : "mysql", Role : "master", diff --git a/geg/database/orm/mysql/gdb_pool.go b/geg/database/orm/mysql/gdb_pool.go index 8f4c7dabf..ea99415d1 100644 --- a/geg/database/orm/mysql/gdb_pool.go +++ b/geg/database/orm/mysql/gdb_pool.go @@ -1,28 +1,16 @@ package main import ( - "gitee.com/johng/gf/g/database/gdb" + "gitee.com/johng/gf/g" "time" ) func main() { - gdb.AddDefaultConfigNode(gdb.ConfigNode { - Host : "127.0.0.1", - Port : "3306", - User : "root", - Pass : "12345678", - Name : "test", - Type : "mysql", - Role : "master", - Charset : "utf8", - MaxIdleConnCount : 10, - MaxOpenConnCount : 10, - MaxConnLifetime : 10, - }) - db, err := gdb.New() - if err != nil { - panic(err) - } + db := g.DB() + db.SetMaxIdleConns(10) + db.SetMaxOpenConns(10) + db.SetConnMaxLifetime(10) + // 开启调试模式,以便于记录所有执行的SQL db.SetDebug(true) diff --git a/geg/database/orm/mysql/gdb_value.go b/geg/database/orm/mysql/gdb_value.go index 7851f6dd6..82a48b0c0 100644 --- a/geg/database/orm/mysql/gdb_value.go +++ b/geg/database/orm/mysql/gdb_value.go @@ -2,24 +2,11 @@ package main import ( "fmt" - "gitee.com/johng/gf/g/database/gdb" + "gitee.com/johng/gf/g" ) func main() { - gdb.AddDefaultConfigNode(gdb.ConfigNode { - Host : "192.168.1.11", - Port : "3306", - User : "root", - Pass : "8692651", - Name : "test", - Type : "mysql", - Role : "master", - Charset : "utf8", - }) - db, err := gdb.New() - if err != nil { - panic(err) - } + db := g.DB() // 开启调试模式,以便于记录所有执行的SQL db.SetDebug(true) diff --git a/geg/other/test/test_test.go b/geg/other/test/test_test.go deleted file mode 100644 index 30da58b5a..000000000 --- a/geg/other/test/test_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package test - -import ( - "strings" - "testing" -) - -var s = "/name/john///." -var c = "./" - -func t1(s string) string { - if len(s) == 0 { - return s - } - for _, cut := range c { - for s[len(s) - 1] == uint8(cut) { - s = s[:len(s) - 1] - if len(s) == 0 { - return s - } - } - } - return s -} - -func t2(s string) string { - return strings.TrimRight(s, c) -} - -func Benchmark_t1(b *testing.B) { - for i := 0; i < b.N; i++ { - t1(s) - } -} - -func Benchmark_t2(b *testing.B) { - for i := 0; i < b.N; i++ { - t2(s) - } -} - diff --git a/geg/other/test2.go b/geg/other/test2.go index a1aad8c3c..f6a556927 100644 --- a/geg/other/test2.go +++ b/geg/other/test2.go @@ -1,25 +1,15 @@ package main -import "fmt" +import ( + "fmt" + "gitee.com/johng/gf/g/util/gregex" +) -type User struct { - Uid int -} -func New() *User { - return &User{ - 100, - } -} - -func (user *User) Clear() { - user = New() -} func main() { - user := New() - user.Uid = 10000 - fmt.Println(user) - user.Clear() - fmt.Println(user) + query := "select * from user" + q, err := gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query) + fmt.Println(err) + fmt.Println(q) } \ No newline at end of file diff --git a/third/github.com/go-sql-driver/mysql/.travis.yml b/third/github.com/go-sql-driver/mysql/.travis.yml index 47dd289a0..75505f144 100644 --- a/third/github.com/go-sql-driver/mysql/.travis.yml +++ b/third/github.com/go-sql-driver/mysql/.travis.yml @@ -4,6 +4,7 @@ go: - 1.8.x - 1.9.x - 1.10.x + - 1.11.x - master before_install: diff --git a/third/github.com/go-sql-driver/mysql/AUTHORS b/third/github.com/go-sql-driver/mysql/AUTHORS index fbe4ec442..5ce4f7eca 100644 --- a/third/github.com/go-sql-driver/mysql/AUTHORS +++ b/third/github.com/go-sql-driver/mysql/AUTHORS @@ -35,6 +35,7 @@ Hanno Braun Henri Yandell Hirotaka Yamamoto ICHINOSE Shogo +Ilia Cimpoes INADA Naoki Jacek Szwec James Harr @@ -72,7 +73,9 @@ Shuode Li Soroush Pour Stan Putrya Stanley Gunawan +Steven Hartland Thomas Wodarek +Tom Jenkinson Xiangyu Hu Xiaobing Jiang Xiuming Chen @@ -88,3 +91,4 @@ Keybase Inc. Percona LLC Pivotal Inc. Stripe Inc. +Multiplay Ltd. diff --git a/third/github.com/go-sql-driver/mysql/README.md b/third/github.com/go-sql-driver/mysql/README.md index bda11031e..6e4816e63 100644 --- a/third/github.com/go-sql-driver/mysql/README.md +++ b/third/github.com/go-sql-driver/mysql/README.md @@ -58,7 +58,7 @@ _Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`: ```go import "database/sql" -import _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql" +import _ "github.com/go-sql-driver/mysql" db, err := sql.Open("mysql", "user:password@/dbname") ``` @@ -328,11 +328,11 @@ Timeout for establishing connections, aka dial timeout. The value must be a deci ``` Type: bool / string -Valid Values: true, false, skip-verify, +Valid Values: true, false, skip-verify, preferred, Default: false ``` -`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). +`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use `preferred` to use TLS only when advertised by the server, this is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). ##### `writeTimeout` @@ -431,7 +431,7 @@ See [context support in the database/sql package](https://golang.org/doc/go1.8#d ### `LOAD DATA LOCAL INFILE` support For this feature you need direct access to the package. Therefore you must change the import path (no `_`): ```go -import "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql" +import "github.com/go-sql-driver/mysql" ``` Files must be whitelisted by registering them with `mysql.RegisterLocalFile(filepath)` (recommended) or the Whitelist check must be deactivated by using the DSN parameter `allowAllFiles=true` ([*Might be insecure!*](http://dev.mysql.com/doc/refman/5.7/en/load-data-local.html)). diff --git a/third/github.com/go-sql-driver/mysql/auth.go b/third/github.com/go-sql-driver/mysql/auth.go index 2f61ecd4f..fec7040d4 100644 --- a/third/github.com/go-sql-driver/mysql/auth.go +++ b/third/github.com/go-sql-driver/mysql/auth.go @@ -234,64 +234,64 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro if err != nil { return err } - return mc.writeAuthSwitchPacket(enc, false) + return mc.writeAuthSwitchPacket(enc) } -func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) { +func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { switch plugin { case "caching_sha2_password": authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) - return authResp, false, nil + return authResp, nil case "mysql_old_password": if !mc.cfg.AllowOldPasswords { - return nil, false, ErrOldPassword + return nil, ErrOldPassword } // Note: there are edge cases where this should work but doesn't; // this is currently "wontfix": // https://github.com/go-sql-driver/mysql/issues/184 - authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd) - return authResp, true, nil + authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0) + return authResp, nil case "mysql_clear_password": if !mc.cfg.AllowCleartextPasswords { - return nil, false, ErrCleartextPassword + return nil, ErrCleartextPassword } // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - return []byte(mc.cfg.Passwd), true, nil + return append([]byte(mc.cfg.Passwd), 0), nil case "mysql_native_password": if !mc.cfg.AllowNativePasswords { - return nil, false, ErrNativePassword + return nil, ErrNativePassword } // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html // Native password authentication only need and will need 20-byte challenge. authResp := scramblePassword(authData[:20], mc.cfg.Passwd) - return authResp, false, nil + return authResp, nil case "sha256_password": if len(mc.cfg.Passwd) == 0 { - return nil, true, nil + return []byte{0}, nil } if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - return []byte(mc.cfg.Passwd), true, nil + return append([]byte(mc.cfg.Passwd), 0), nil } pubKey := mc.cfg.pubKey if pubKey == nil { // request public key from server - return []byte{1}, false, nil + return []byte{1}, nil } // encrypted password enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) - return enc, false, err + return enc, err default: errLog.Print("unknown auth plugin:", plugin) - return nil, false, ErrUnknownPlugin + return nil, ErrUnknownPlugin } } @@ -315,11 +315,11 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { plugin = newPlugin - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { return err } - if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil { + if err = mc.writeAuthSwitchPacket(authResp); err != nil { return err } @@ -352,7 +352,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { case cachingSha2PasswordPerformFullAuthentication: if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true) + err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) if err != nil { return err } @@ -360,13 +360,15 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { pubKey := mc.cfg.pubKey if pubKey == nil { // request public key from server - data := mc.buf.takeSmallBuffer(4 + 1) + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { + return err + } data[4] = cachingSha2PasswordRequestPublicKey mc.writePacket(data) // parse public key - data, err := mc.readPacket() - if err != nil { + if data, err = mc.readPacket(); err != nil { return err } diff --git a/third/github.com/go-sql-driver/mysql/auth_test.go b/third/github.com/go-sql-driver/mysql/auth_test.go index bd0e2189c..1920ef39f 100644 --- a/third/github.com/go-sql-driver/mysql/auth_test.go +++ b/third/github.com/go-sql-driver/mysql/auth_test.go @@ -85,11 +85,11 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -130,11 +130,11 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -172,11 +172,11 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -228,11 +228,11 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -280,11 +280,11 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { plugin := "caching_sha2_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -336,7 +336,7 @@ func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - _, _, err := mc.auth(authData, plugin) + _, err := mc.auth(authData, plugin) if err != ErrCleartextPassword { t.Errorf("expected ErrCleartextPassword, got %v", err) } @@ -353,11 +353,11 @@ func TestAuthFastCleartextPassword(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -367,8 +367,8 @@ func TestAuthFastCleartextPassword(t *testing.T) { authRespEnd := authRespStart + 1 + len(authResp) writtenAuthRespLen := conn.written[authRespStart] writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - expectedAuthResp := []byte{115, 101, 99, 114, 101, 116} - if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) } conn.written = nil @@ -396,11 +396,11 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { plugin := "mysql_clear_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -410,9 +410,9 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) { authRespEnd := authRespStart + 1 + len(authResp) writtenAuthRespLen := conn.written[authRespStart] writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - if writtenAuthRespLen != 0 { - t.Fatalf("unexpected written auth response (%d bytes): %v", - writtenAuthRespLen, writtenAuthResp) + expectedAuthResp := []byte{0} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) } conn.written = nil @@ -439,7 +439,7 @@ func TestAuthFastNativePasswordNotAllowed(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - _, _, err := mc.auth(authData, plugin) + _, err := mc.auth(authData, plugin) if err != ErrNativePassword { t.Errorf("expected ErrNativePassword, got %v", err) } @@ -455,11 +455,11 @@ func TestAuthFastNativePassword(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -498,11 +498,11 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) { plugin := "mysql_native_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -540,11 +540,11 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -554,7 +554,8 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) { authRespEnd := authRespStart + 1 + len(authResp) writtenAuthRespLen := conn.written[authRespStart] writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] - if writtenAuthRespLen != 0 { + expectedAuthResp := []byte{0} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) } conn.written = nil @@ -587,11 +588,11 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -636,11 +637,11 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { plugin := "sha256_password" // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } @@ -669,7 +670,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { plugin := "sha256_password" // send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { t.Fatal(err) } @@ -677,18 +678,18 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) { // unset TLS config to prevent the actual establishment of a TLS wrapper mc.cfg.tls = nil - err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin) + err = mc.writeHandshakeResponsePacket(authResp, plugin) if err != nil { t.Fatal(err) } // check written auth response authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 - authRespEnd := authRespStart + 1 + len(authResp) + 1 + authRespEnd := authRespStart + 1 + len(authResp) writtenAuthRespLen := conn.written[authRespStart] writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} - if writtenAuthRespLen != 6 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) } conn.written = nil @@ -1064,6 +1065,22 @@ func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) { } } +// Same to TestAuthSwitchOldPasswordNotAllowed, but use OldAuthSwitch request. +func TestOldAuthSwitchNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + // OldAuthSwitch request + conn.data = []byte{1, 0, 0, 2, 0xfe} + conn.maxReads = 1 + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrOldPassword { + t.Errorf("expected ErrOldPassword, got %v", err) + } +} + func TestAuthSwitchOldPassword(t *testing.T) { conn, mc := newRWMockConn(2) mc.cfg.AllowOldPasswords = true @@ -1092,6 +1109,32 @@ func TestAuthSwitchOldPassword(t *testing.T) { } } +// Same to TestAuthSwitchOldPassword, but use OldAuthSwitch request. +func TestOldAuthSwitch(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "secret" + + // OldAuthSwitch request + conn.data = []byte{1, 0, 0, 2, 0xfe} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} func TestAuthSwitchOldPasswordEmpty(t *testing.T) { conn, mc := newRWMockConn(2) mc.cfg.AllowOldPasswords = true @@ -1120,6 +1163,33 @@ func TestAuthSwitchOldPasswordEmpty(t *testing.T) { } } +// Same to TestAuthSwitchOldPasswordEmpty, but use OldAuthSwitch request. +func TestOldAuthSwitchPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "" + + // OldAuthSwitch request. + conn.data = []byte{1, 0, 0, 2, 0xfe} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{1, 0, 0, 3, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) { conn, mc := newRWMockConn(2) mc.cfg.Passwd = "" diff --git a/third/github.com/go-sql-driver/mysql/buffer.go b/third/github.com/go-sql-driver/mysql/buffer.go index eb4748bf4..19486bd6f 100644 --- a/third/github.com/go-sql-driver/mysql/buffer.go +++ b/third/github.com/go-sql-driver/mysql/buffer.go @@ -22,17 +22,17 @@ const defaultBufSize = 4096 // The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. type buffer struct { - buf []byte + buf []byte // buf is a byte buffer who's length and capacity are equal. nc net.Conn idx int length int timeout time.Duration } +// newBuffer allocates and returns a new buffer. func newBuffer(nc net.Conn) buffer { - var b [defaultBufSize]byte return buffer{ - buf: b[:], + buf: make([]byte, defaultBufSize), nc: nc, } } @@ -105,43 +105,56 @@ func (b *buffer) readNext(need int) ([]byte, error) { return b.buf[offset:b.idx], nil } -// returns a buffer with the requested size. +// takeBuffer returns a buffer with the requested size. // If possible, a slice from the existing buffer is returned. // Otherwise a bigger buffer is made. // Only one buffer (total) can be used at a time. -func (b *buffer) takeBuffer(length int) []byte { +func (b *buffer) takeBuffer(length int) ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrBusyBuffer } // test (cheap) general case first - if length <= defaultBufSize || length <= cap(b.buf) { - return b.buf[:length] + if length <= cap(b.buf) { + return b.buf[:length], nil } if length < maxPacketSize { b.buf = make([]byte, length) - return b.buf + return b.buf, nil } - return make([]byte, length) + + // buffer is larger than we want to store. + return make([]byte, length), nil } -// shortcut which can be used if the requested buffer is guaranteed to be -// smaller than defaultBufSize +// takeSmallBuffer is shortcut which can be used if length is +// known to be smaller than defaultBufSize. // Only one buffer (total) can be used at a time. -func (b *buffer) takeSmallBuffer(length int) []byte { +func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrBusyBuffer } - return b.buf[:length] + return b.buf[:length], nil } // takeCompleteBuffer returns the complete existing buffer. // This can be used if the necessary buffer size is unknown. +// cap and len of the returned buffer will be equal. // Only one buffer (total) can be used at a time. -func (b *buffer) takeCompleteBuffer() []byte { +func (b *buffer) takeCompleteBuffer() ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrBusyBuffer } - return b.buf + return b.buf, nil +} + +// store stores buf, an updated buffer, if its suitable to do so. +func (b *buffer) store(buf []byte) error { + if b.length > 0 { + return ErrBusyBuffer + } else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) { + b.buf = buf[:cap(buf)] + } + return nil } diff --git a/third/github.com/go-sql-driver/mysql/connection.go b/third/github.com/go-sql-driver/mysql/connection.go index 911be2060..fc4ec7597 100644 --- a/third/github.com/go-sql-driver/mysql/connection.go +++ b/third/github.com/go-sql-driver/mysql/connection.go @@ -19,16 +19,6 @@ import ( "time" ) -// a copy of context.Context for Go 1.7 and earlier -type mysqlContext interface { - Done() <-chan struct{} - Err() error - - // defined in context.Context, but not used in this driver: - // Deadline() (deadline time.Time, ok bool) - // Value(key interface{}) interface{} -} - type mysqlConn struct { buf buffer netConn net.Conn @@ -45,7 +35,7 @@ type mysqlConn struct { // for context support (Go 1.8+) watching bool - watcher chan<- mysqlContext + watcher chan<- context.Context closech chan struct{} finished chan<- struct{} canceled atomicError // set non-nil if conn is canceled @@ -192,10 +182,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin return "", driver.ErrSkip } - buf := mc.buf.takeCompleteBuffer() - if buf == nil { + buf, err := mc.buf.takeCompleteBuffer() + if err != nil { // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return "", ErrInvalidConn } buf = buf[:0] @@ -475,7 +465,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) { defer mc.finish() if err = mc.writeCommandPacket(comPing); err != nil { - return + return mc.markBadConn(err) } return mc.readResultOK() @@ -595,33 +585,32 @@ func (mc *mysqlConn) watchCancel(ctx context.Context) error { mc.cleanup() return nil } + // When ctx is already cancelled, don't watch it. + if err := ctx.Err(); err != nil { + return err + } + // When ctx is not cancellable, don't watch it. if ctx.Done() == nil { return nil } - - mc.watching = true - select { - default: - case <-ctx.Done(): - return ctx.Err() - } + // When watcher is not alive, can't watch it. if mc.watcher == nil { return nil } + mc.watching = true mc.watcher <- ctx - return nil } func (mc *mysqlConn) startWatcher() { - watcher := make(chan mysqlContext, 1) + watcher := make(chan context.Context, 1) mc.watcher = watcher finished := make(chan struct{}) mc.finished = finished go func() { for { - var ctx mysqlContext + var ctx context.Context select { case ctx = <-watcher: case <-mc.closech: diff --git a/third/github.com/go-sql-driver/mysql/connection_test.go b/third/github.com/go-sql-driver/mysql/connection_test.go index dec376117..2a1c8e888 100644 --- a/third/github.com/go-sql-driver/mysql/connection_test.go +++ b/third/github.com/go-sql-driver/mysql/connection_test.go @@ -9,7 +9,10 @@ package mysql import ( + "context" "database/sql/driver" + "errors" + "net" "testing" ) @@ -79,3 +82,76 @@ func TestCheckNamedValue(t *testing.T) { t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value) } } + +// TestCleanCancel tests passed context is cancelled at start. +// No packet should be sent. Connection should keep current status. +func TestCleanCancel(t *testing.T) { + mc := &mysqlConn{ + closech: make(chan struct{}), + } + mc.startWatcher() + defer mc.cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + for i := 0; i < 3; i++ { // Repeat same behavior + err := mc.Ping(ctx) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %#v", err) + } + + if mc.closed.IsSet() { + t.Error("expected mc is not closed, closed actually") + } + + if mc.watching { + t.Error("expected watching is false, but true") + } + } +} + +func TestPingMarkBadConnection(t *testing.T) { + nc := badConnection{err: errors.New("boom")} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + } + + err := ms.Ping(context.Background()) + + if err != driver.ErrBadConn { + t.Errorf("expected driver.ErrBadConn, got %#v", err) + } +} + +func TestPingErrInvalidConn(t *testing.T) { + nc := badConnection{err: errors.New("failed to write"), n: 10} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + closech: make(chan struct{}), + } + + err := ms.Ping(context.Background()) + + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %#v", err) + } +} + +type badConnection struct { + n int + err error + net.Conn +} + +func (bc badConnection) Write(b []byte) (n int, err error) { + return bc.n, bc.err +} + +func (bc badConnection) Close() error { + return nil +} diff --git a/third/github.com/go-sql-driver/mysql/driver.go b/third/github.com/go-sql-driver/mysql/driver.go index 53782ed24..9f4967087 100644 --- a/third/github.com/go-sql-driver/mysql/driver.go +++ b/third/github.com/go-sql-driver/mysql/driver.go @@ -9,7 +9,7 @@ // The driver should be used via the database/sql package: // // import "database/sql" -// import _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql" +// import _ "github.com/go-sql-driver/mysql" // // db, err := sql.Open("mysql", "user:password@/dbname") // @@ -50,7 +50,7 @@ func RegisterDial(net string, dial DialFunc) { // Open new Connection. // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how -// the DSN string is formated +// the DSN string is formatted func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { var err error @@ -77,6 +77,10 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) } if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + errLog.Print("net.Error from Dial()': ", nerr.Error()) + return nil, driver.ErrBadConn + } return nil, err } @@ -110,18 +114,18 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { } // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { // try the default auth plugin, if using the requested plugin failed errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) plugin = defaultAuthPlugin - authResp, addNUL, err = mc.auth(authData, plugin) + authResp, err = mc.auth(authData, plugin) if err != nil { mc.cleanup() return nil, err } } - if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { + if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { mc.cleanup() return nil, err } diff --git a/third/github.com/go-sql-driver/mysql/driver_test.go b/third/github.com/go-sql-driver/mysql/driver_test.go index f2bf344e5..46d1f7ff4 100644 --- a/third/github.com/go-sql-driver/mysql/driver_test.go +++ b/third/github.com/go-sql-driver/mysql/driver_test.go @@ -85,6 +85,23 @@ type DBTest struct { db *sql.DB } +type netErrorMock struct { + temporary bool + timeout bool +} + +func (e netErrorMock) Temporary() bool { + return e.temporary +} + +func (e netErrorMock) Timeout() bool { + return e.timeout +} + +func (e netErrorMock) Error() string { + return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout) +} + func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { if !available { t.Skipf("MySQL server not running on %s", netAddr) @@ -1287,7 +1304,7 @@ func TestFoundRows(t *testing.T) { } func TestTLS(t *testing.T) { - tlsTest := func(dbt *DBTest) { + tlsTestReq := func(dbt *DBTest) { if err := dbt.db.Ping(); err != nil { if err == ErrNoTLS { dbt.Skip("server does not support TLS") @@ -1304,19 +1321,27 @@ func TestTLS(t *testing.T) { dbt.Fatal(err.Error()) } - if value == nil { - dbt.Fatal("no Cipher") + if (*value == nil) || (len(*value) == 0) { + dbt.Fatalf("no Cipher") + } else { + dbt.Logf("Cipher: %s", *value) } } } + tlsTestOpt := func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.Fatalf("error on Ping: %s", err.Error()) + } + } - runTests(t, dsn+"&tls=skip-verify", tlsTest) + runTests(t, dsn+"&tls=preferred", tlsTestOpt) + runTests(t, dsn+"&tls=skip-verify", tlsTestReq) // Verify that registering / using a custom cfg works RegisterTLSConfig("custom-skip-verify", &tls.Config{ InsecureSkipVerify: true, }) - runTests(t, dsn+"&tls=custom-skip-verify", tlsTest) + runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq) } func TestReuseClosedConnection(t *testing.T) { @@ -1801,6 +1826,38 @@ func TestConcurrent(t *testing.T) { }) } +func testDialError(t *testing.T, dialErr error, expectErr error) { + RegisterDial("mydial", func(addr string) (net.Conn, error) { + return nil, dialErr + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + _, err = db.Exec("DO 1") + if err != expectErr { + t.Fatalf("was expecting %s. Got: %s", dialErr, err) + } +} + +func TestDialUnknownError(t *testing.T) { + testErr := fmt.Errorf("test") + testDialError(t, testErr, testErr) +} + +func TestDialNonRetryableNetErr(t *testing.T) { + testErr := netErrorMock{} + testDialError(t, testErr, testErr) +} + +func TestDialTemporaryNetErr(t *testing.T) { + testErr := netErrorMock{temporary: true} + testDialError(t, testErr, driver.ErrBadConn) +} + // Tests custom dial functions func TestCustomDial(t *testing.T) { if !available { diff --git a/third/github.com/go-sql-driver/mysql/dsn.go b/third/github.com/go-sql-driver/mysql/dsn.go index be014babe..b9134722e 100644 --- a/third/github.com/go-sql-driver/mysql/dsn.go +++ b/third/github.com/go-sql-driver/mysql/dsn.go @@ -560,7 +560,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { } else { cfg.TLSConfig = "false" } - } else if vl := strings.ToLower(value); vl == "skip-verify" { + } else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" { cfg.TLSConfig = vl cfg.tls = &tls.Config{InsecureSkipVerify: true} } else { diff --git a/third/github.com/go-sql-driver/mysql/packets.go b/third/github.com/go-sql-driver/mysql/packets.go index 170aaa02b..5e0853767 100644 --- a/third/github.com/go-sql-driver/mysql/packets.go +++ b/third/github.com/go-sql-driver/mysql/packets.go @@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { mc.sequence++ // packets with length 0 terminate a previous packet which is a - // multiple of (2^24)−1 bytes long + // multiple of (2^24)-1 bytes long if pktLen == 0 { // there was no previous packet if prevData == nil { @@ -194,7 +194,11 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro return nil, "", ErrOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - return nil, "", ErrNoTLS + if mc.cfg.TLSConfig == "preferred" { + mc.cfg.tls = nil + } else { + return nil, "", ErrNoTLS + } } pos += 2 @@ -243,7 +247,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -269,7 +273,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, // encode length of the auth plugin data var authRespLEIBuf [9]byte - authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp))) + authRespLen := len(authResp) + authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) if len(authRespLEI) > 1 { // if the length can not be written in 1 byte, it must be written as a // length encoded integer @@ -277,9 +282,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, } pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 - if addNUL { - pktLen++ - } // To specify a db name if n := len(mc.cfg.DBName); n > 0 { @@ -288,10 +290,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, } // Calculate packet length and get buffer with that size - data := mc.buf.takeSmallBuffer(pktLen + 4) - if data == nil { + data, err := mc.buf.takeSmallBuffer(pktLen + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -350,10 +352,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, // Auth Data [length encoded integer] pos += copy(data[pos:], authRespLEI) pos += copy(data[pos:], authResp) - if addNUL { - data[pos] = 0x00 - pos++ - } // Databasename [null terminated string] if len(mc.cfg.DBName) > 0 { @@ -364,30 +362,24 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, pos += copy(data[pos:], plugin) data[pos] = 0x00 + pos++ // Send Auth packet - return mc.writePacket(data) + return mc.writePacket(data[:pos]) } // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error { +func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { pktLen := 4 + len(authData) - if addNUL { - pktLen++ - } - data := mc.buf.takeSmallBuffer(pktLen) - if data == nil { + data, err := mc.buf.takeSmallBuffer(pktLen) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } // Add the auth data [EOF] copy(data[4:], authData) - if addNUL { - data[pktLen-1] = 0x00 - } - return mc.writePacket(data) } @@ -399,10 +391,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1) - if data == nil { + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -418,10 +410,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { mc.sequence = 0 pktLen := 1 + len(arg) - data := mc.buf.takeBuffer(pktLen + 4) - if data == nil { + data, err := mc.buf.takeBuffer(pktLen + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -439,10 +431,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1 + 4) - if data == nil { + data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -479,7 +471,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { return data[1:], "", err case iEOF: - if len(data) < 1 { + if len(data) == 1 { // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest return nil, "mysql_old_password", nil } @@ -895,7 +887,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { const minPktLen = 4 + 1 + 4 + 1 + 4 mc := stmt.mc - // Determine threshould dynamically to avoid packet size shortage. + // Determine threshold dynamically to avoid packet size shortage. longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) if longDataSize < 64 { longDataSize = 64 @@ -905,15 +897,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { mc.sequence = 0 var data []byte + var err error if len(args) == 0 { - data = mc.buf.takeBuffer(minPktLen) + data, err = mc.buf.takeBuffer(minPktLen) } else { - data = mc.buf.takeCompleteBuffer() + data, err = mc.buf.takeCompleteBuffer() + // In this case the len(data) == cap(data) which is used to optimise the flow below. } - if data == nil { + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -939,7 +933,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { pos := minPktLen var nullMask []byte - if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { + if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) { // buffer has to be extended but we don't know by how much so // we depend on append after all data with known sizes fit. // We stop at that because we deal with a lot of columns here @@ -948,10 +942,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { copy(tmp[:pos], data[:pos]) data = tmp nullMask = data[pos : pos+maskLen] + // No need to clean nullMask as make ensures that. pos += maskLen } else { nullMask = data[pos : pos+maskLen] - for i := 0; i < maskLen; i++ { + for i := range nullMask { nullMask[i] = 0 } pos += maskLen @@ -1088,7 +1083,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) - mc.buf.buf = data + if err = mc.buf.store(data); err != nil { + errLog.Print(err) + return errBadConnNoWrite + } } pos += len(paramValues)