diff --git a/g/database/gdb/gdb.go b/g/database/gdb/gdb.go index e66f23931..5dd863618 100644 --- a/g/database/gdb/gdb.go +++ b/g/database/gdb/gdb.go @@ -12,7 +12,6 @@ import ( "database/sql" "errors" "fmt" - "gitee.com/johng/gf/g/container/gmap" "gitee.com/johng/gf/g/container/gring" "gitee.com/johng/gf/g/container/gtype" "gitee.com/johng/gf/g/container/gvar" @@ -30,19 +29,16 @@ const ( ) // 数据库操作接口 -type Link interface { - // 打开数据库连接,建立数据库操作对象 - Open(c *ConfigNode) (*sql.DB, error) - +type DB interface { // SQL操作方法 - Query(q string, args ...interface{}) (*sql.Rows, error) - Exec(q string, args ...interface{}) (sql.Result, error) - Prepare(q string) (*sql.Stmt, error) + Query(query string, args ...interface{}) (*sql.Rows, error) + Exec(sql string, args ...interface{}) (sql.Result, error) + Prepare(sql string, execOnMaster...bool) (*sql.Stmt, error) // 数据库查询 - GetAll(q string, args ...interface{}) (Result, error) - GetOne(q string, args ...interface{}) (Record, error) - GetValue(q string, args ...interface{}) (Value, error) + GetAll(query string, args ...interface{}) (Result, error) + GetOne(query string, args ...interface{}) (Record, error) + GetValue(query string, args ...interface{}) (Value, error) // Ping PingMaster() error @@ -74,21 +70,23 @@ type Link interface { Table(tables string) *Model From(tables string) *Model - // 内部方法 - insert(table string, data Map, option uint8) (sql.Result, error) - batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) + // 设置管理 + SetDebug(debug bool) - getQuoteCharLeft() string - getQuoteCharRight() string - handleSqlBeforeExec(q *string) *string + // 内部方法接口 + open(c *ConfigNode) (*sql.DB, error) + getCache() (*gcache.Cache) + getChars() (charLeft string, charRight string) + getDebug() bool + putSql(s *Sql) + formatCondition(condition interface{}) (where string) + handleSqlBeforeExec(sql string) string } // 数据库链接对象 -type Db struct { - link Link // 底层数据库类型管理对象 +type dbBase struct { + db DB // 数据库对象 group string // 配置分组名称 - charl string // SQL安全符号(左) - charr string // SQL安全符号(右) debug *gtype.Bool // (默认关闭)是否开启调试模式,当开启时会启用一些调试特性 sqls *gring.Ring // (debug=true时有效)已执行的SQL列表 cache *gcache.Cache // 数据库缓存,包括底层连接池对象缓存及查询缓存;需要注意的是,事务查询不支持查询缓存 @@ -104,7 +102,7 @@ type Sql struct { Error error // 执行结果(nil为成功) Start int64 // 执行开始时间(毫秒) End int64 // 执行结束时间(毫秒) - Func string // 执行方法名称 + Func string // 执行方法 } // 返回数据表记录值 @@ -117,27 +115,13 @@ type Record map[string]Value type Result []Record // 关联数组,绑定一条数据表记录(使用别名) -type Map = map[string]interface{} +type Map = map[string]interface{} // 关联数组列表(索引从0开始的数组),绑定多条记录(使用别名) type List = []Map -var ( - // 支持的数据库类型map - driverMap = make(map[string]interface{}) - // 数据库查询缓存对象map,使用数据库连接名称作为键名,键值为查询缓存对象 - dbCaches = gmap.NewStringInterfaceMap() -) -func init() { - driverMap["mysql"] = linkMysql - driverMap["oracle"] = linkOracle - driverMap["sqlite"] = linkSqlite - driverMap["pgsql"] = linkPgsql - driverMap["mssql"] = linkMssql -} - // 使用默认/指定分组配置进行连接,数据库集群配置项:default -func New(groupName ...string) (*Db, error) { +func New(groupName ...string) (db DB, err error) { group := config.d if len(groupName) > 0 { group = groupName[0] @@ -150,24 +134,29 @@ func New(groupName ...string) (*Db, error) { } if _, ok := config.c[group]; ok { if node, err := getConfigNodeByGroup(group, true); err == nil { - link, err := getLinkByType(node.Type) - if err != nil { - return nil, err - } - db := &Db { - link : link, + base := &dbBase { group : group, - charl : link.getQuoteCharLeft(), - charr : link.getQuoteCharRight(), debug : gtype.NewBool(), + cache : gcache.New(), maxIdleConnCount : gtype.NewInt(), maxOpenConnCount : gtype.NewInt(), maxConnLifetime : gtype.NewInt(), } - db.cache = dbCaches.GetOrSetFuncLock(group, func() interface{} { - return gcache.New() - }).(*gcache.Cache) - return db, nil + switch node.Type { + case "mysql": + base.db = &dbMysql{dbBase : base} + case "pgsql": + base.db = &dbPgsql{dbBase : base} + case "mssql": + base.db = &dbMssql{dbBase : base} + case "sqlite": + base.db = &dbSqlite{dbBase : base} + case "oracle": + base.db = &dbOracle{dbBase : base} + default: + return nil, errors.New(fmt.Sprintf(`unsupported database type "%s"`, node.Type)) + } + return base.db, nil } else { return nil, err } @@ -238,34 +227,20 @@ func getConfigNodeByPriority(cg ConfigGroup) *ConfigNode { return nil } -// 根据配置的数据库;类型获得Link接口对象 -func getLinkByType(dbType string) (Link, error) { - if dblink, ok := driverMap[dbType]; ok == false { - return nil, errors.New(fmt.Sprintf("unsupported db type '%s'", dbType)) - } else { - return dblink.(Link), nil - } -} - // 获得底层数据库链接对象 -func (db *Db) getSqlDb(master bool) (sqlDb *sql.DB, err error) { +func (db *dbBase) getSqlDb(master bool) (sqlDb *sql.DB, err error) { // 负载均衡 node, err := getConfigNodeByGroup(db.group, master) if err != nil { return nil, err } - // 类型对象 - link, err := getLinkByType(node.Type) - if err != nil { - return nil, err - } // 检查缓存连接池对象 cacheKey := node.String() if v := db.cache.Get(cacheKey); v != nil { return v.(*sql.DB), nil } v := db.cache.GetOrSetFuncLock(node.String(), func() interface{} { - sqlDb, err = link.Open(node) + sqlDb, err = db.db.open(node) if err != nil { return nil } @@ -296,11 +271,11 @@ func (db *Db) getSqlDb(master bool) (sqlDb *sql.DB, err error) { } // 创建底层数据库master链接对象 -func (db *Db) Master() (*sql.DB, error) { +func (db *dbBase) Master() (*sql.DB, error) { return db.getSqlDb(true) } // 创建底层数据库slave链接对象 -func (db *Db) Slave() (*sql.DB, error) { +func (db *dbBase) Slave() (*sql.DB, error) { return db.getSqlDb(false) } diff --git a/g/database/gdb/gdb_base.go b/g/database/gdb/gdb_base.go index ade922df1..e11618f33 100644 --- a/g/database/gdb/gdb_base.go +++ b/g/database/gdb/gdb_base.go @@ -8,33 +8,23 @@ package gdb import ( - "fmt" - "errors" - "strings" - "reflect" "database/sql" - "gitee.com/johng/gf/g/util/gstr" - "gitee.com/johng/gf/g/util/gconv" - "gitee.com/johng/gf/g/container/gring" + "errors" + "fmt" + "gitee.com/johng/gf/g/os/gcache" "gitee.com/johng/gf/g/os/gtime" - "gitee.com/johng/gf/g/os/glog" - "gitee.com/johng/gf/g/container/gvar" + "gitee.com/johng/gf/g/util/gconv" + "gitee.com/johng/gf/g/util/gstr" + "reflect" + "strings" ) const ( gDEFAULT_DEBUG_SQL_LENGTH = 1000 // 默认调试模式下记录的SQL条数 ) -// 是否开启调试服务 -func (db *Db) SetDebug(debug bool) { - db.debug.Set(debug) - if debug && db.sqls == nil { - db.sqls = gring.New(gDEFAULT_DEBUG_SQL_LENGTH) - } -} - // 获取已经执行的SQL列表(仅在debug=true时有效) -func (db *Db) GetQueriedSqls() []*Sql { +func (db *dbBase) GetQueriedSqls() []*Sql { if db.sqls == nil { return nil } @@ -51,7 +41,7 @@ func (db *Db) GetQueriedSqls() []*Sql { } // 打印已经执行的SQL列表(仅在debug=true时有效) -func (db *Db) PrintQueriedSqls() { +func (db *dbBase) PrintQueriedSqls() { sqls := db.GetQueriedSqls() for k, v := range sqls { fmt.Println(len(sqls) - k, ":") @@ -61,142 +51,54 @@ func (db *Db) PrintQueriedSqls() { fmt.Println(" Start:", gtime.NewFromTimeStamp(v.Start).Format("Y-m-d H:i:s.u")) fmt.Println(" End :", gtime.NewFromTimeStamp(v.End).Format("Y-m-d H:i:s.u")) fmt.Println(" Cost :", v.End - v.Start, "ms") - fmt.Println(" Func :", v.Func) - } -} - -// 打印SQL对象(仅在debug=true时有效) -func (db *Db) printSql(v *Sql) { - s := fmt.Sprintf("%s, %v, %s, %s, %d ms, %s", v.Sql, v.Args, - gtime.NewFromTimeStamp(v.Start).Format("Y-m-d H:i:s.u"), - gtime.NewFromTimeStamp(v.End).Format("Y-m-d H:i:s.u"), - v.End - v.Start, v.Func, - ) - if v.Error != nil { - s += "\nError: " + v.Error.Error() - glog.Backtrace(true, 2).Error(s) - } else { - glog.Debug(s) } } // 数据库sql查询操作,主要执行查询 -func (db *Db) Query(query string, args ...interface{}) (*sql.Rows, error) { - var err error - var rows *sql.Rows - var slave *sql.DB - slave, err = db.Slave(); +func (db *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err error) { + link, err := db.Slave() if err != nil { return nil,err } - p := db.link.handleSqlBeforeExec(&query) - if db.debug.Val() { - militime1 := gtime.Millisecond() - rows, err = slave.Query(*p, args ...) - militime2 := gtime.Millisecond() - s := &Sql{ - Sql : *p, - Args : args, - Error : err, - Start : militime1, - End : militime2, - Func : "DB:Query", - } - db.sqls.Put(s) - db.printSql(s) - } else { - rows, err = slave.Query(*p, args ...) - } - if err == nil { - return rows, nil - } else { - err = db.formatError(err, p, args...) - } - return nil, err + return doQuery(db.db, link, query, args...) } // 执行一条sql,并返回执行情况,主要用于非查询操作 -func (db *Db) Exec(query string, args ...interface{}) (sql.Result, error) { - var err error - var result sql.Result - var master *sql.DB - master, err = db.Master(); +func (db *dbBase) Exec(query string, args ...interface{}) (result sql.Result, err error) { + link, err := db.Master() if err != nil { return nil,err } - p := db.link.handleSqlBeforeExec(&query) - if db.debug.Val() { - militime1 := gtime.Millisecond() - result, err = master.Exec(*p, args ...) - militime2 := gtime.Millisecond() - s := &Sql{ - Sql : *p, - Args : args, - Error : err, - Start : militime1, - End : militime2, - Func : "DB:Exec", + return doExec(db.db, link, query, args...) +} + +// SQL预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作; 默认执行在Slave上, 通过第二个参数指定执行在Master上 +func (db *dbBase) Prepare(query string, execOnMaster...bool) (*sql.Stmt, error) { + err := (error)(nil) + sqldb := (*sql.DB)(nil) + if len(execOnMaster) > 0 && execOnMaster[0] { + if sqldb, err = db.Master(); err != nil { + return nil, err } - db.sqls.Put(s) - db.printSql(s) } else { - result, err = master.Exec(*p, args ...) - } - return result, db.formatError(err, p, args...) -} - -// 格式化错误信息 -func (db *Db) formatError(err error, query *string, args ...interface{}) error { - if err != nil { - errstr := fmt.Sprintf("DB ERROR: %s\n", err.Error()) - errstr += fmt.Sprintf("DB QUERY: %s\n", *query) - if len(args) > 0 { - errstr += fmt.Sprintf("DB PARAM: %v\n", args) + if sqldb, err = db.Slave(); err != nil { + return nil, err } - err = errors.New(errstr) } - return err + return sqldb.Prepare(query) } - // 数据库查询,获取查询结果集,以列表结构返回 -func (db *Db) GetAll(query string, args ...interface{}) (Result, error) { - // 执行sql +func (db *dbBase) GetAll(query string, args ...interface{}) (Result, error) { rows, err := db.Query(query, args ...) if err != nil || rows == nil { return nil, err } - // 列名称列表 - columns, err := rows.Columns() - if err != nil { - return nil, err - } - // 返回结构组装 - values := make([]sql.RawBytes, len(columns)) - scanArgs := make([]interface{}, len(values)) - records := make(Result, 0) - for i := range values { - scanArgs[i] = &values[i] - } - for rows.Next() { - err = rows.Scan(scanArgs...) - if err != nil { - return records, err - } - row := make(Record) - // 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址 - for i, col := range values { - v := make([]byte, len(col)) - copy(v, col) - row[columns[i]] = gvar.New(v, false) - } - records = append(records, row) - } - return records, nil + return rowsToResult(rows) } // 数据库查询,获取查询结果记录,以关联数组结构返回 -func (db *Db) GetOne(query string, args ...interface{}) (Record, error) { +func (db *dbBase) GetOne(query string, args ...interface{}) (Record, error) { list, err := db.GetAll(query, args ...) if err != nil { return nil, err @@ -208,7 +110,7 @@ func (db *Db) GetOne(query string, args ...interface{}) (Record, error) { } // 数据库查询,获取查询结果记录,自动映射数据到给定的struct对象中 -func (db *Db) GetStruct(obj interface{}, query string, args ...interface{}) error { +func (db *dbBase) GetStruct(obj interface{}, query string, args ...interface{}) error { one, err := db.GetOne(query, args...) if err != nil { return err @@ -218,7 +120,7 @@ func (db *Db) GetStruct(obj interface{}, query string, args ...interface{}) erro // 数据库查询,获取查询字段值 -func (db *Db) GetValue(query string, args ...interface{}) (Value, error) { +func (db *dbBase) GetValue(query string, args ...interface{}) (Value, error) { one, err := db.GetOne(query, args ...) if err != nil { return nil, err @@ -230,7 +132,7 @@ func (db *Db) GetValue(query string, args ...interface{}) (Value, error) { } // 数据库查询,获取查询数量 -func (db *Db) GetCount(query string, args ...interface{}) (int, error) { +func (db *dbBase) GetCount(query string, args ...interface{}) (int, error) { val, err := db.GetValue(query, args ...) if err != nil { return 0, err @@ -239,7 +141,7 @@ func (db *Db) GetCount(query string, args ...interface{}) (int, error) { } // 数据表查询,其中tables可以是多个联表查询语句,这种查询方式较复杂,建议使用链式操作 -func (db *Db) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) { +func (db *dbBase) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) { s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables) if condition != nil { s += fmt.Sprintf("WHERE %s ", db.formatCondition(condition)) @@ -256,17 +158,8 @@ func (db *Db) Select(tables, fields string, condition interface{}, groupBy, orde return db.GetAll(s, args ... ) } -// sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作 -func (db *Db) Prepare(query string) (*sql.Stmt, error) { - if master, err := db.Master(); err != nil { - return nil, err - } else { - return master.Prepare(query) - } -} - // ping一下,判断或保持数据库链接(master) -func (db *Db) PingMaster() error { +func (db *dbBase) PingMaster() error { if master, err := db.Master(); err != nil { return err } else { @@ -275,7 +168,7 @@ func (db *Db) PingMaster() error { } // ping一下,判断或保持数据库链接(slave) -func (db *Db) PingSlave() error { +func (db *dbBase) PingSlave() error { if slave, err := db.Slave(); err != nil { return err } else { @@ -285,13 +178,13 @@ func (db *Db) PingSlave() error { // 事务操作,开启,会返回一个底层的事务操作对象链接如需要嵌套事务,那么可以使用该对象,否则请忽略 // 只有在tx.Commit/tx.Rollback时,链接会自动Close -func (db *Db) Begin() (*Tx, error) { +func (db *dbBase) Begin() (*TX, error) { if master, err := db.Master(); err != nil { return nil, err } else { if tx, err := master.Begin(); err == nil { - return &Tx { - db : db, + return &TX { + db : db.db, tx : tx, master : master, }, nil @@ -301,42 +194,30 @@ func (db *Db) Begin() (*Tx, error) { } } -// 根据insert选项获得操作名称 -func (db *Db) getInsertOperationByOption(option uint8) string { - oper := "INSERT" - switch option { - case OPTION_REPLACE: - oper = "REPLACE" - case OPTION_SAVE: - case OPTION_IGNORE: - oper = "INSERT IGNORE" - } - return oper -} - // insert、replace, save, ignore操作 // 0: insert: 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 // 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 // 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 // 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做 -func (db *Db) insert(table string, data Map, option uint8) (sql.Result, error) { +func (db *dbBase) insert(table string, data Map, option uint8) (sql.Result, error) { var fields []string var values []string var params []interface{} + charl, charr := db.db.getChars() for k, v := range data { - fields = append(fields, db.charl + k + db.charr) + fields = append(fields, charl + k + charr) values = append(values, "?") params = append(params, v) } - operation := db.getInsertOperationByOption(option) + operation := getInsertOperationByOption(option) updatestr := "" if option == OPTION_SAVE { var updates []string for k, _ := range data { updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s%s%s)", - db.charl, k, db.charr, - db.charl, k, db.charr, + charl, k, charr, + charl, k, charr, ), ) } @@ -352,22 +233,22 @@ func (db *Db) insert(table string, data Map, option uint8) (sql.Result, error) { } // CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 -func (db *Db) Insert(table string, data Map) (sql.Result, error) { +func (db *dbBase) Insert(table string, data Map) (sql.Result, error) { return db.insert(table, data, OPTION_INSERT) } // CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 -func (db *Db) Replace(table string, data Map) (sql.Result, error) { +func (db *dbBase) Replace(table string, data Map) (sql.Result, error) { return db.insert(table, data, OPTION_REPLACE) } // CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 -func (db *Db) Save(table string, data Map) (sql.Result, error) { +func (db *dbBase) Save(table string, data Map) (sql.Result, error) { return db.insert(table, data, OPTION_SAVE) } // 批量写入数据 -func (db *Db) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) { +func (db *dbBase) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) { var keys []string var values []string var bvalues []string @@ -383,18 +264,19 @@ func (db *Db) batchInsert(table string, list List, batch int, option uint8) (sql keys = append(keys, k) values = append(values, "?") } - keyStr := db.charl + strings.Join(keys, db.charl + "," + db.charr) + db.charr + charl, charr := db.db.getChars() + keyStr := charl + strings.Join(keys, charl + "," + charr) + charr valueHolderStr := "(" + strings.Join(values, ",") + ")" // 操作判断 - operation := db.getInsertOperationByOption(option) + operation := getInsertOperationByOption(option) updatestr := "" if option == OPTION_SAVE { var updates []string for _, k := range keys { updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s%s%s)", - db.charl, k, db.charr, - db.charl, k, db.charr, + charl, k, charr, + charl, k, charr, ), ) } @@ -434,31 +316,32 @@ func (db *Db) batchInsert(table string, list List, batch int, option uint8) (sql } // CURD操作:批量数据指定批次量写入 -func (db *Db) BatchInsert(table string, list List, batch int) (sql.Result, error) { +func (db *dbBase) BatchInsert(table string, list List, batch int) (sql.Result, error) { return db.batchInsert(table, list, batch, OPTION_INSERT) } // CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 -func (db *Db) BatchReplace(table string, list List, batch int) (sql.Result, error) { +func (db *dbBase) BatchReplace(table string, list List, batch int) (sql.Result, error) { return db.batchInsert(table, list, batch, OPTION_REPLACE) } // CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 -func (db *Db) BatchSave(table string, list List, batch int) (sql.Result, error) { +func (db *dbBase) BatchSave(table string, list List, batch int) (sql.Result, error) { return db.batchInsert(table, list, batch, OPTION_SAVE) } // CURD操作:数据更新,统一采用sql预处理 // data参数支持字符串或者关联数组类型,内部会自行做判断处理 -func (db *Db) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { +func (db *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { var params []interface{} var updates string - refValue := reflect.ValueOf(data) + charl, charr := db.db.getChars() + refValue := reflect.ValueOf(data) if refValue.Kind() == reflect.Map { var fields []string keys := refValue.MapKeys() for _, k := range keys { - fields = append(fields, fmt.Sprintf("%s%s%s=?", db.charl, k, db.charr)) + fields = append(fields, fmt.Sprintf("%s%s%s=?", charl, k, charr)) params = append(params, gconv.String(refValue.MapIndex(k).Interface())) } updates = strings.Join(fields, ",") @@ -472,12 +355,12 @@ func (db *Db) Update(table string, data interface{}, condition interface{}, args } // CURD操作:删除数据 -func (db *Db) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) { +func (db *dbBase) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) { return db.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", table, db.formatCondition(condition)), args...) } // 格式化SQL查询条件 -func (db *Db) formatCondition(condition interface{}) (where string) { +func (db *dbBase) formatCondition(condition interface{}) (where string) { if reflect.ValueOf(condition).Kind() == reflect.Map { ks := reflect.ValueOf(condition).MapKeys() vs := reflect.ValueOf(condition) @@ -498,4 +381,14 @@ func (db *Db) formatCondition(condition interface{}) (where string) { where += gconv.String(condition) } return +} + +// 获得缓存对象 +func (db *dbBase) getCache() *gcache.Cache { + return db.cache +} + +// 记录执行的SQL +func (db *dbBase) putSql(s *Sql) { + db.sqls.Put(s) } \ No newline at end of file diff --git a/g/database/gdb/gdb_config.go b/g/database/gdb/gdb_config.go index f3a6f490c..a4c9fb312 100644 --- a/g/database/gdb/gdb_config.go +++ b/g/database/gdb/gdb_config.go @@ -9,6 +9,7 @@ package gdb import ( "fmt" + "gitee.com/johng/gf/g/container/gring" "sync" ) @@ -122,18 +123,18 @@ func SetDefaultGroup (groupName string) { } // 设置数据库连接池中空闲链接的大小 -func (db *Db) SetMaxIdleConns(n int) { +func (db *dbBase) SetMaxIdleConns(n int) { db.maxIdleConnCount.Set(n) } // 设置数据库连接池最大打开的链接数量 -func (db *Db) SetMaxOpenConns(n int) { +func (db *dbBase) SetMaxOpenConns(n int) { db.maxOpenConnCount.Set(n) } // 设置数据库连接可重复利用的时间,超过该时间则被关闭废弃 // 如果 d <= 0 表示该链接会一直重复利用 -func (db *Db) SetConnMaxLifetime(n int) { +func (db *dbBase) SetConnMaxLifetime(n int) { db.maxConnLifetime.Set(n) } @@ -146,4 +147,17 @@ func (node *ConfigNode) String() string { node.Name, node.Type, node.Role, node.Charset, node.MaxIdleConnCount, node.MaxOpenConnCount, node.MaxConnLifetime, ) +} + +// 是否开启调试服务 +func (db *dbBase) SetDebug(debug bool) { + db.debug.Set(debug) + if debug && db.sqls == nil { + db.sqls = gring.New(gDEFAULT_DEBUG_SQL_LENGTH) + } +} + +// 获取是否开启调试服务 +func (db *dbBase) getDebug() bool { + return db.debug.Val() } \ No newline at end of file diff --git a/g/database/gdb/gdb_func.go b/g/database/gdb/gdb_func.go new file mode 100644 index 000000000..e6e8f103c --- /dev/null +++ b/g/database/gdb/gdb_func.go @@ -0,0 +1,176 @@ +// Copyright 2017-2018 gf Author(https://gitee.com/johng/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://gitee.com/johng/gf. + +package gdb + +import ( + "database/sql" + "errors" + "fmt" + "gitee.com/johng/gf/g/container/gvar" + "gitee.com/johng/gf/g/os/glog" + "gitee.com/johng/gf/g/os/gtime" + _ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql" + "strings" +) + +type dbLink interface { + Query(query string, args ...interface{}) (*sql.Rows, error) + Exec(sql string, args ...interface{}) (sql.Result, error) + Prepare(sql string) (*sql.Stmt, error) +} + +// 数据库sql查询操作,主要执行查询 +func doQuery(db DB, link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) { + query = db.handleSqlBeforeExec(query) + if db.getDebug() { + mTime1 := gtime.Millisecond() + rows, err = link.Query(query, args...) + mTime2 := gtime.Millisecond() + s := &Sql{ + Sql : query, + Args : args, + Error : err, + Start : mTime1, + End : mTime2, + } + db.putSql(s) + printSql(s) + } else { + rows, err = link.Query(query, args ...) + } + if err == nil { + return rows, nil + } else { + err = formatError(err, query, args...) + } + return nil, err +} + +// 执行一条sql,并返回执行情况,主要用于非查询操作 +func doExec(db DB, link dbLink, query string, args ...interface{}) (result sql.Result, err error) { + query = db.handleSqlBeforeExec(query) + if db.getDebug() { + mTime1 := gtime.Millisecond() + result, err = link.Exec(query, args ...) + mTime2 := gtime.Millisecond() + s := &Sql{ + Sql : query, + Args : args, + Error : err, + Start : mTime1, + End : mTime2, + } + db.putSql(s) + printSql(s) + } else { + result, err = link.Exec(query, args ...) + } + return result, formatError(err, query, args...) +} + +// 将数据查询的列表数据*sql.Rows转换为Result类型 +func rowsToResult(rows *sql.Rows) (Result, error) { + // 列名称列表 + columns, err := rows.Columns() + if err != nil { + return nil, err + } + // 返回结构组装 + values := make([]sql.RawBytes, len(columns)) + scanArgs := make([]interface{}, len(values)) + records := make(Result, 0) + for i := range values { + scanArgs[i] = &values[i] + } + for rows.Next() { + err = rows.Scan(scanArgs...) + if err != nil { + return records, err + } + row := make(Record) + // 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址 + for i, col := range values { + v := make([]byte, len(col)) + copy(v, col) + row[columns[i]] = gvar.New(v, false) + } + records = append(records, row) + } + return records, nil +} + +func formatInsertQuery(db DB, table string, data Map, option uint8) (string, []interface{}) { + var fields []string + var values []string + var params []interface{} + charl, charr := db.getChars() + for k, v := range data { + fields = append(fields, charl + k + charr) + values = append(values, "?") + params = append(params, v) + } + operation := getInsertOperationByOption(option) + updatestr := "" + if option == OPTION_SAVE { + var updates []string + for k, _ := range data { + updates = append(updates, + fmt.Sprintf("%s%s%s=VALUES(%s%s%s)", + charl, k, charr, + charl, k, charr, + ), + ) + } + updatestr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ",")) + } + return fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s", + operation, table, strings.Join(fields, ","), + strings.Join(values, ","), updatestr), + params +} + +// 打印SQL对象(仅在debug=true时有效) +func printSql(v *Sql) { + s := fmt.Sprintf("%s, %v, %s, %s, %d ms, %s", v.Sql, v.Args, + gtime.NewFromTimeStamp(v.Start).Format("Y-m-d H:i:s.u"), + gtime.NewFromTimeStamp(v.End).Format("Y-m-d H:i:s.u"), + v.End - v.Start, + v.Func, + ) + if v.Error != nil { + s += "\nError: " + v.Error.Error() + glog.Backtrace(true, 2).Error(s) + } else { + glog.Debug(s) + } +} + +// 格式化错误信息 +func formatError(err error, query string, args ...interface{}) error { + if err != nil { + errstr := fmt.Sprintf("DB ERROR: %s\n", err.Error()) + errstr += fmt.Sprintf("DB QUERY: %s\n", query) + if len(args) > 0 { + errstr += fmt.Sprintf("DB PARAM: %v\n", args) + } + err = errors.New(errstr) + } + return err +} + +// 根据insert选项获得操作名称 +func getInsertOperationByOption(option uint8) string { + oper := "INSERT" + switch option { + case OPTION_REPLACE: + oper = "REPLACE" + case OPTION_SAVE: + case OPTION_IGNORE: + oper = "INSERT IGNORE" + } + return oper +} diff --git a/g/database/gdb/gdb_model.go b/g/database/gdb/gdb_model.go index a86f4dbe9..f15ea7562 100644 --- a/g/database/gdb/gdb_model.go +++ b/g/database/gdb/gdb_model.go @@ -18,7 +18,7 @@ import ( // 数据库链式操作模型对象 type Model struct { tx *Tx // 数据库事务对象 - db *Db // 数据库操作对象 + db DB // 数据库操作对象 tablesInit string // 初始化Model时的表名称(可以是多个) tables string // 数据库操作表 fields string // 操作字段 @@ -36,9 +36,9 @@ type Model struct { } // 链式操作,数据表字段,可支持多个表,以半角逗号连接 -func (db *Db) Table(tables string) (*Model) { +func (db *dbBase) Table(tables string) (*Model) { return &Model{ - db : db, + db : db.db, tablesInit : tables, tables : tables, fields : "*", @@ -46,7 +46,7 @@ func (db *Db) Table(tables string) (*Model) { } // 链式操作,数据表字段,可支持多个表,以半角逗号连接 -func (db *Db) From(tables string) (*Model) { +func (db *dbBase) From(tables string) (*Model) { return db.Table(tables) } @@ -396,7 +396,7 @@ func (md *Model) getAll(sql string, args ...interface{}) (result Result, err err if len(cacheKey) == 0 { cacheKey = sql + "/" + gconv.String(args) } - if v := md.db.cache.Get(cacheKey); v != nil { + if v := md.db.getCache().Get(cacheKey); v != nil { return v.(Result), nil } } @@ -408,9 +408,9 @@ func (md *Model) getAll(sql string, args ...interface{}) (result Result, err err // 查询缓存保存处理 if len(cacheKey) > 0 && err == nil { if md.cacheTime < 0 { - md.db.cache.Remove(cacheKey) + md.db.getCache().Remove(cacheKey) } else { - md.db.cache.Set(cacheKey, result, md.cacheTime*1000) + md.db.getCache().Set(cacheKey, result, md.cacheTime*1000) } } return result, err @@ -419,7 +419,7 @@ func (md *Model) getAll(sql string, args ...interface{}) (result Result, err err // 检查是否需要查询查询缓存 func (md *Model) checkAndRemoveCache() { if md.cacheEnabled && md.cacheTime < 0 && len(md.cacheName) > 0 { - md.db.cache.Remove(md.cacheName) + md.db.getCache().Remove(md.cacheName) } } diff --git a/g/database/gdb/gdb_mssql.go b/g/database/gdb/gdb_mssql.go index 4ffbecc9a..5532b1b40 100644 --- a/g/database/gdb/gdb_mssql.go +++ b/g/database/gdb/gdb_mssql.go @@ -22,15 +22,13 @@ import ( ) -var linkMssql = &dbmssql{} - // 数据库链接对象 -type dbmssql struct { - Db +type dbMssql struct { + *dbBase } // 创建SQL操作对象 -func (db *dbmssql) Open(c *ConfigNode) (*sql.DB, error) { +func (db *dbMssql) open(c *ConfigNode) (*sql.DB, error) { var source string if c.Linkinfo != "" { source = c.Linkinfo @@ -44,43 +42,38 @@ func (db *dbmssql) Open(c *ConfigNode) (*sql.DB, error) { } } -// 获得关键字操作符 - 左 -func (db *dbmssql) getQuoteCharLeft() string { - return "\"" -} - -// 获得关键字操作符 - 右 -func (db *dbmssql) getQuoteCharRight() string { - return "\"" +// 获得关键字操作符 +func (db *dbMssql) getChars () (charLeft string, charRight string) { + return "\"", "\"" } // 在执行sql之前对sql进行进一步处理 -func (db *dbmssql) handleSqlBeforeExec(q *string) *string { +func (db *dbMssql) handleSqlBeforeExec(query string) string { index := 0 - str, _ := gregex.ReplaceStringFunc("\\?", *q, func(s string) string { + str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string { index++ return fmt.Sprintf("@p%d", index) }) str, _ = gregex.ReplaceString("\"", "", str) - return db.parseSql(&str) + return db.parseSql(str) } //将MYSQL的SQL语法转换为MSSQL的语法 //1.由于mssql不支持limit写法所以需要对mysql中的limit用法做转换 -func (db *dbmssql) parseSql(sql *string) *string { +func (db *dbMssql) parseSql(sql string) string { //下面的正则表达式匹配出SELECT和INSERT的关键字后分别做不同的处理,如有LIMIT则将LIMIT的关键字也匹配出 patten := `^\s*(?i)(SELECT)|(LIMIT\s*(\d+)\s*,\s*(\d+))` - if gregex.IsMatchString(patten, *sql) == false { + if gregex.IsMatchString(patten, sql) == false { fmt.Println("not matched..") return sql } - res, err := gregex.MatchAllString(patten, *sql) + res, err := gregex.MatchAllString(patten, sql) if err != nil { fmt.Println("MatchString error.", err) - return nil + return "" } index := 0 @@ -96,17 +89,17 @@ func (db *dbmssql) parseSql(sql *string) *string { } //不含LIMIT则不处理 - if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql) == false { + if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) == false { break } //判断SQL中是否含有order by selectStr := "" orderbyStr := "" - haveOrderby := gregex.IsMatchString("((?i)SELECT)(.+)((?i)ORDER BY)", *sql) + haveOrderby := gregex.IsMatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql) if haveOrderby { //取order by 前面的字符串 - queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)ORDER BY)", *sql) + queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql) if len(queryExpr) != 4 || strings.EqualFold(queryExpr[1], "SELECT") == false || strings.EqualFold(queryExpr[3], "ORDER BY") == false{ break @@ -114,13 +107,13 @@ func (db *dbmssql) parseSql(sql *string) *string { selectStr = queryExpr[2] //取order by表达式的值 - orderbyExpr, _ := gregex.MatchString("((?i)ORDER BY)(.+)((?i)LIMIT)", *sql) + orderbyExpr, _ := gregex.MatchString("((?i)ORDER BY)(.+)((?i)LIMIT)", sql) if len(orderbyExpr) != 4 || strings.EqualFold(orderbyExpr[1], "ORDER BY") == false || strings.EqualFold(orderbyExpr[3], "LIMIT") == false{ break } orderbyStr = orderbyExpr[2] } else { - queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql) + queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) if len(queryExpr) != 4 || strings.EqualFold(queryExpr[1], "SELECT") == false || strings.EqualFold(queryExpr[3], "LIMIT") == false{ break } @@ -144,14 +137,14 @@ func (db *dbmssql) parseSql(sql *string) *string { } if haveOrderby { - *sql = fmt.Sprintf("SELECT * FROM (SELECT ROW_NUMBER() OVER (ORDER BY %s) as ROWNUMBER_, %s ) as TMP_ WHERE TMP_.ROWNUMBER_ > %d AND TMP_.ROWNUMBER_ <= %d", orderbyStr, selectStr, first, limit) + sql = fmt.Sprintf("SELECT * FROM (SELECT ROW_NUMBER() OVER (ORDER BY %s) as ROWNUMBER_, %s ) as TMP_ WHERE TMP_.ROWNUMBER_ > %d AND TMP_.ROWNUMBER_ <= %d", orderbyStr, selectStr, first, limit) } else { if first == 0 { first = limit } else { first = limit - first } - *sql = fmt.Sprintf("SELECT * FROM (SELECT TOP %d * FROM (SELECT TOP %d %s) as TMP1_ ) as TMP2_ ", first, limit, selectStr) + sql = fmt.Sprintf("SELECT * FROM (SELECT TOP %d * FROM (SELECT TOP %d %s) as TMP1_ ) as TMP2_ ", first, limit, selectStr) } default: } diff --git a/g/database/gdb/gdb_mysql.go b/g/database/gdb/gdb_mysql.go index 76775312b..0c6283648 100644 --- a/g/database/gdb/gdb_mysql.go +++ b/g/database/gdb/gdb_mysql.go @@ -12,17 +12,13 @@ import ( "database/sql" ) -// MySQL接口对象 -var linkMysql = &dbmysql{} - - // 数据库链接对象 -type dbmysql struct { - Db +type dbMysql struct { + *dbBase } // 创建SQL操作对象,内部采用了lazy link处理 -func (db *dbmysql) Open (c *ConfigNode) (*sql.DB, error) { +func (db *dbMysql) open (c *ConfigNode) (*sql.DB, error) { var source string if c.Linkinfo != "" { source = c.Linkinfo @@ -36,17 +32,12 @@ func (db *dbmysql) Open (c *ConfigNode) (*sql.DB, error) { } } -// 获得关键字操作符 - 左 -func (db *dbmysql) getQuoteCharLeft () string { - return "`" -} - -// 获得关键字操作符 - 右 -func (db *dbmysql) getQuoteCharRight () string { - return "`" +// 获得关键字操作符 +func (db *dbMysql) getChars () (charLeft string, charRight string) { + return "`", "`" } // 在执行sql之前对sql进行进一步处理 -func (db *dbmysql) handleSqlBeforeExec(q *string) *string { - return q +func (db *dbMysql) handleSqlBeforeExec(query string) string { + return query } \ No newline at end of file diff --git a/g/database/gdb/gdb_oracle.go b/g/database/gdb/gdb_oracle.go index 93ad1ac56..962be5959 100644 --- a/g/database/gdb/gdb_oracle.go +++ b/g/database/gdb/gdb_oracle.go @@ -21,15 +21,13 @@ import ( "strings" ) -var linkOracle = &dboracle{} - // 数据库链接对象 -type dboracle struct { - Db +type dbOracle struct { + *dbBase } // 创建SQL操作对象 -func (db *dboracle) Open(c *ConfigNode) (*sql.DB, error) { +func (db *dbOracle) open(c *ConfigNode) (*sql.DB, error) { var source string if c.Linkinfo != "" { source = c.Linkinfo @@ -43,42 +41,37 @@ func (db *dboracle) Open(c *ConfigNode) (*sql.DB, error) { } } -// 获得关键字操作符 - 左 -func (db *dboracle) getQuoteCharLeft() string { - return "\"" -} - -// 获得关键字操作符 - 右 -func (db *dboracle) getQuoteCharRight() string { - return "\"" +// 获得关键字操作符 +func (db *dbOracle) getChars () (charLeft string, charRight string) { + return "\"", "\"" } // 在执行sql之前对sql进行进一步处理 -func (db *dboracle) handleSqlBeforeExec(q *string) *string { +func (db *dbOracle) handleSqlBeforeExec(query string) string { index := 0 - str, _ := gregex.ReplaceStringFunc("\\?", *q, func(s string) string { + str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string { index++ return fmt.Sprintf(":%d", index) }) str, _ = gregex.ReplaceString("\"", "", str) - return db.parseSql(&str) + return db.parseSql(str) } //由于ORACLE中对LIMIT和批量插入的语法与MYSQL不一致,所以这里需要对LIMIT和批量插入做语法上的转换 -func (db *dboracle) parseSql(sql *string) *string { +func (db *dbOracle) parseSql(sql string) string { //下面的正则表达式匹配出SELECT和INSERT的关键字后分别做不同的处理,如有LIMIT则将LIMIT的关键字也匹配出 patten := `^\s*(?i)(SELECT)|(INSERT)|(LIMIT\s*(\d+)\s*,\s*(\d+))` - if gregex.IsMatchString(patten, *sql) == false { + if gregex.IsMatchString(patten, sql) == false { fmt.Println("not matched..") return sql } - res, err := gregex.MatchAllString(patten, *sql) + res, err := gregex.MatchAllString(patten, sql) if err != nil { fmt.Println("MatchString error.", err) - return nil + return "" } index := 0 @@ -94,11 +87,11 @@ func (db *dboracle) parseSql(sql *string) *string { } //取limit前面的字符串 - if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql) == false { + if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) == false { break } - queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", *sql) + queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) if len(queryExpr) != 4 || strings.EqualFold(queryExpr[1], "SELECT") == false || strings.EqualFold(queryExpr[3], "LIMIT") == false{ break } @@ -118,10 +111,10 @@ func (db *dboracle) parseSql(sql *string) *string { } //也可以使用between,据说这种写法的性能会比between好点,里层SQL中的ROWNUM_ >= limit可以缩小查询后的数据集规模 - *sql = fmt.Sprintf("SELECT * FROM (SELECT GFORM.*, ROWNUM ROWNUM_ FROM (%s %s) GFORM WHERE ROWNUM <= %d) WHERE ROWNUM_ >= %d", queryExpr[1], queryExpr[2], limit, first) + sql = fmt.Sprintf("SELECT * FROM (SELECT GFORM.*, ROWNUM ROWNUM_ FROM (%s %s) GFORM WHERE ROWNUM <= %d) WHERE ROWNUM_ >= %d", queryExpr[1], queryExpr[2], limit, first) case "INSERT": //获取VALUE的值,匹配所有带括号的值,会将INSERT INTO后的值匹配到,所以下面的判断语句会判断数组长度是否小于3 - valueExpr, err := gregex.MatchAllString(`(\s*\(([^\(\)]*)\))`, *sql) + valueExpr, err := gregex.MatchAllString(`(\s*\(([^\(\)]*)\))`, sql) if err != nil { return sql } @@ -132,17 +125,17 @@ func (db *dboracle) parseSql(sql *string) *string { } //获取INTO后面的值 - tableExpr, err := gregex.MatchString(`(?i)\s*(INTO\s+\w+\(([^\(\)]*)\))`, *sql) + tableExpr, err := gregex.MatchString(`(?i)\s*(INTO\s+\w+\(([^\(\)]*)\))`, sql) if err != nil { return sql } tableExpr[0] = strings.TrimSpace(tableExpr[0]) - *sql = "INSERT ALL" + sql = "INSERT ALL" for i := 1; i < len(valueExpr); i++ { - *sql += fmt.Sprintf(" %s VALUES%s", tableExpr[0], strings.TrimSpace(valueExpr[i][0])) + sql += fmt.Sprintf(" %s VALUES%s", tableExpr[0], strings.TrimSpace(valueExpr[i][0])) } - *sql += " SELECT 1 FROM DUAL" + sql += " SELECT 1 FROM DUAL" default: } diff --git a/g/database/gdb/gdb_pgsql.go b/g/database/gdb/gdb_pgsql.go index a16a1efe4..874263465 100644 --- a/g/database/gdb/gdb_pgsql.go +++ b/g/database/gdb/gdb_pgsql.go @@ -18,17 +18,13 @@ import ( // _ "gitee.com/johng/gf/third/github.com/lib/pq" // @todo 需要完善replace和save的操作覆盖 -// PostgreSQL接口对象 -var linkPgsql = &dbpgsql{} - - // 数据库链接对象 -type dbpgsql struct { - Db +type dbPgsql struct { + *dbBase } // 创建SQL操作对象,内部采用了lazy link处理 -func (db *dbpgsql) Open (c *ConfigNode) (*sql.DB, error) { +func (db *dbPgsql) open (c *ConfigNode) (*sql.DB, error) { var source string if c.Linkinfo != "" { source = c.Linkinfo @@ -42,23 +38,18 @@ func (db *dbpgsql) Open (c *ConfigNode) (*sql.DB, error) { } } -// 获得关键字操作符 - 左 -func (db *dbpgsql) getQuoteCharLeft () string { - return "\"" -} - -// 获得关键字操作符 - 右 -func (db *dbpgsql) getQuoteCharRight () string { - return "\"" +// 获得关键字操作符 +func (db *dbPgsql) getChars () (charLeft string, charRight string) { + return "\"", "\"" } // 在执行sql之前对sql进行进一步处理 -func (db *dbpgsql) handleSqlBeforeExec(q *string) *string { +func (db *dbPgsql) handleSqlBeforeExec(query string) string { reg := regexp.MustCompile("\\?") index := 0 - str := reg.ReplaceAllStringFunc(*q, func (s string) string { + str := reg.ReplaceAllStringFunc(query, func (s string) string { index ++ return fmt.Sprintf("$%d", index) }) - return &str + return str } \ No newline at end of file diff --git a/g/database/gdb/gdb_sqlite.go b/g/database/gdb/gdb_sqlite.go index 0a92972bb..68810c517 100644 --- a/g/database/gdb/gdb_sqlite.go +++ b/g/database/gdb/gdb_sqlite.go @@ -16,15 +16,13 @@ import ( // Sqlite接口对象 // @author wxkj -var linkSqlite = &dbsqlite{} - // 数据库链接对象 -type dbsqlite struct { - Db +type dbSqlite struct { + *dbBase } -func (db *dbsqlite) Open(c *ConfigNode) (*sql.DB, error) { +func (db *dbSqlite) open(c *ConfigNode) (*sql.DB, error) { var source string if c.Linkinfo != "" { source = c.Linkinfo @@ -38,20 +36,14 @@ func (db *dbsqlite) Open(c *ConfigNode) (*sql.DB, error) { } } -// 获得关键字操作符 - 左 -func (db *dbsqlite) getQuoteCharLeft() string { - return "`" -} - -// 获得关键字操作符 - 右 -func (db *dbsqlite) getQuoteCharRight() string { - return "`" +// 获得关键字操作符 +func (db *dbSqlite) getChars () (charLeft string, charRight string) { + return "`", "`" } // 在执行sql之前对sql进行进一步处理 // @todo 需要增加对Save方法的支持,可使用正则来实现替换, // @todo 将ON DUPLICATE KEY UPDATE触发器修改为两条SQL语句(INSERT OR IGNORE & UPDATE) -func (db *dbsqlite) handleSqlBeforeExec(q *string) *string { - - return q +func (db *dbSqlite) handleSqlBeforeExec(query string) string { + return query } \ No newline at end of file diff --git a/g/database/gdb/gdb_transaction.go b/g/database/gdb/gdb_transaction.go index 692b3bb5b..cd8e48976 100644 --- a/g/database/gdb/gdb_transaction.go +++ b/g/database/gdb/gdb_transaction.go @@ -19,79 +19,34 @@ import ( ) // 数据库事务对象 -type Tx struct { - db *Db +type TX struct { + db DB tx *sql.Tx master *sql.DB } // 事务操作,提交 -func (tx *Tx) Commit() error { +func (tx *TX) Commit() error { return tx.tx.Commit() } // 事务操作,回滚 -func (tx *Tx) Rollback() error { +func (tx *TX) Rollback() error { return tx.tx.Rollback() } // (事务)数据库sql查询操作,主要执行查询 -func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { - var err error - var rows *sql.Rows - p := tx.db.link.handleSqlBeforeExec(&query) - if tx.db.debug.Val() { - militime1 := gtime.Millisecond() - rows, err = tx.tx.Query(*p, args ...) - militime2 := gtime.Millisecond() - s := &Sql{ - Sql : *p, - Args : args, - Error : err, - Start : militime1, - End : militime2, - Func : "TX:Query", - } - tx.db.sqls.Put(s) - tx.db.printSql(s) - } else { - rows, err = tx.tx.Query(*p, args ...) - } - if err == nil { - return rows, nil - } else { - err = tx.db.formatError(err, p, args...) - } - return nil, err +func (tx *TX) Query(query string, args ...interface{}) (rows *sql.Rows, err error) { + return doQuery(tx.db, tx.tx, query, args...) } // (事务)执行一条sql,并返回执行情况,主要用于非查询操作 -func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { - var err error - var result sql.Result - p := tx.db.link.handleSqlBeforeExec(&query) - if tx.db.debug.Val() { - militime1 := gtime.Millisecond() - result, err = tx.tx.Exec(*p, args ...) - militime2 := gtime.Millisecond() - s := &Sql{ - Sql : *p, - Args : args, - Error : err, - Start : militime1, - End : militime2, - Func : "TX:Exec", - } - tx.db.sqls.Put(s) - tx.db.printSql(s) - } else { - result, err = tx.tx.Exec(*p, args ...) - } - return result, tx.db.formatError(err, p, args...) +func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) { + return doExec(tx.db, tx.tx, query, args...) } // 数据库查询,获取查询结果集,以列表结构返回 -func (tx *Tx) GetAll(query string, args ...interface{}) (Result, error) { +func (tx *TX) GetAll(query string, args ...interface{}) (Result, error) { // 执行sql rows, err := tx.Query(query, args ...) if err != nil || rows == nil { @@ -128,7 +83,7 @@ func (tx *Tx) GetAll(query string, args ...interface{}) (Result, error) { } // 数据库查询,获取查询结果记录,以关联数组结构返回 -func (tx *Tx) GetOne(query string, args ...interface{}) (Record, error) { +func (tx *TX) GetOne(query string, args ...interface{}) (Record, error) { list, err := tx.GetAll(query, args ...) if err != nil { return nil, err @@ -140,7 +95,7 @@ func (tx *Tx) GetOne(query string, args ...interface{}) (Record, error) { } // 数据库查询,获取查询结果记录,自动映射数据到给定的struct对象中 -func (tx *Tx) GetStruct(obj interface{}, query string, args ...interface{}) error { +func (tx *TX) GetStruct(obj interface{}, query string, args ...interface{}) error { one, err := tx.GetOne(query, args...) if err != nil { return err @@ -150,7 +105,7 @@ func (tx *Tx) GetStruct(obj interface{}, query string, args ...interface{}) erro // 数据库查询,获取查询字段值 -func (tx *Tx) GetValue(query string, args ...interface{}) (Value, error) { +func (tx *TX) GetValue(query string, args ...interface{}) (Value, error) { one, err := tx.GetOne(query, args ...) if err != nil { return nil, err @@ -162,7 +117,7 @@ func (tx *Tx) GetValue(query string, args ...interface{}) (Value, error) { } // 数据库查询,获取查询数量 -func (tx *Tx) GetCount(query string, args ...interface{}) (int, error) { +func (tx *TX) GetCount(query string, args ...interface{}) (int, error) { val, err := tx.GetValue(query, args ...) if err != nil { return 0, err @@ -171,7 +126,7 @@ func (tx *Tx) GetCount(query string, args ...interface{}) (int, error) { } // 数据表查询,其中tables可以是多个联表查询语句,这种查询方式较复杂,建议使用链式操作 -func (tx *Tx) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) { +func (tx *TX) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) { s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables) if condition != nil { s += fmt.Sprintf("WHERE %s ", tx.db.formatCondition(condition)) @@ -189,7 +144,7 @@ func (tx *Tx) Select(tables, fields string, condition interface{}, groupBy, orde } // sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作 -func (tx *Tx) Prepare(query string) (*sql.Stmt, error) { +func (tx *TX) Prepare(query string) (*sql.Stmt, error) { return tx.tx.Prepare(query) } @@ -198,7 +153,7 @@ func (tx *Tx) Prepare(query string) (*sql.Stmt, error) { // 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 // 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 // 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做 -func (tx *Tx) insert(table string, data Map, option uint8) (sql.Result, error) { +func (tx *TX) insert(table string, data Map, option uint8) (sql.Result, error) { var keys []string var values []string var params []interface{} @@ -226,22 +181,22 @@ func (tx *Tx) insert(table string, data Map, option uint8) (sql.Result, error) { } // CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 -func (tx *Tx) Insert(table string, data Map) (sql.Result, error) { +func (tx *TX) Insert(table string, data Map) (sql.Result, error) { return tx.insert(table, data, OPTION_INSERT) } // CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 -func (tx *Tx) Replace(table string, data Map) (sql.Result, error) { +func (tx *TX) Replace(table string, data Map) (sql.Result, error) { return tx.insert(table, data, OPTION_REPLACE) } // CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 -func (tx *Tx) Save(table string, data Map) (sql.Result, error) { +func (tx *TX) Save(table string, data Map) (sql.Result, error) { return tx.insert(table, data, OPTION_SAVE) } // 批量写入数据 -func (tx *Tx) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) { +func (tx *TX) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) { var keys []string var values []string var bvalues []string @@ -303,23 +258,23 @@ func (tx *Tx) batchInsert(table string, list List, batch int, option uint8) (sql } // CURD操作:批量数据指定批次量写入 -func (tx *Tx) BatchInsert(table string, list List, batch int) (sql.Result, error) { +func (tx *TX) BatchInsert(table string, list List, batch int) (sql.Result, error) { return tx.batchInsert(table, list, batch, OPTION_INSERT) } // CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 -func (tx *Tx) BatchReplace(table string, list List, batch int) (sql.Result, error) { +func (tx *TX) BatchReplace(table string, list List, batch int) (sql.Result, error) { return tx.batchInsert(table, list, batch, OPTION_REPLACE) } // CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 -func (tx *Tx) BatchSave(table string, list List, batch int) (sql.Result, error) { +func (tx *TX) BatchSave(table string, list List, batch int) (sql.Result, error) { return tx.batchInsert(table, list, batch, OPTION_SAVE) } // CURD操作:数据更新,统一采用sql预处理 // data参数支持字符串或者关联数组类型,内部会自行做判断处理 -func (tx *Tx) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { +func (tx *TX) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { var params []interface{} var updates string refValue := reflect.ValueOf(data) @@ -341,7 +296,7 @@ func (tx *Tx) Update(table string, data interface{}, condition interface{}, args } // CURD操作:删除数据 -func (tx *Tx) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) { +func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) { return tx.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", table, tx.db.formatCondition(condition)), args...) }