Files
gf/g/database/gdb/gdb_transaction.go

353 lines
12 KiB
Go
Raw Normal View History

2018-03-09 17:55:42 +08:00
// Copyright 2017 gf Author(https://gitee.com/johng/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://gitee.com/johng/gf.
package gdb
import (
"fmt"
"errors"
2018-03-12 11:46:12 +08:00
"strings"
"reflect"
2018-08-08 20:09:52 +08:00
"database/sql"
"gitee.com/johng/gf/g/os/gtime"
2018-08-08 20:09:52 +08:00
"gitee.com/johng/gf/g/util/gconv"
_ "gitee.com/johng/gf/third/github.com/go-sql-driver/mysql"
"gitee.com/johng/gf/g/container/gvar"
2018-03-09 17:55:42 +08:00
)
// 数据库事务对象
type Tx struct {
db *Db
tx *sql.Tx
master *sql.DB
2018-03-12 11:46:12 +08:00
}
// 事务操作,提交
func (tx *Tx) Commit() error {
err := tx.tx.Commit()
tx.master.Close()
return err
2018-03-12 11:46:12 +08:00
}
2018-03-09 17:55:42 +08:00
2018-03-12 11:46:12 +08:00
// 事务操作,回滚
func (tx *Tx) Rollback() error {
err := tx.tx.Rollback()
tx.master.Close()
return err
2018-03-09 17:55:42 +08:00
}
2018-03-12 15:38:27 +08:00
// (事务)数据库sql查询操作主要执行查询
2018-03-12 15:12:38 +08:00
func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
var err error
var rows *sql.Rows
p := tx.db.link.handleSqlBeforeExec(&query)
if tx.db.debug.Val() {
militime1 := gtime.Millisecond()
rows, err = tx.tx.Query(*p, args ...)
militime2 := gtime.Millisecond()
s := &Sql{
Sql : *p,
Args : args,
Error : err,
Start : militime1,
End : militime2,
Func : "TX:Query",
}
tx.db.sqls.Put(s)
tx.db.printSql(s)
} else {
rows, err = tx.tx.Query(*p, args ...)
}
2018-03-09 17:55:42 +08:00
if err == nil {
return rows, nil
} else {
err = tx.db.formatError(err, p, args...)
2018-03-09 17:55:42 +08:00
}
return nil, err
}
2018-03-12 15:38:27 +08:00
// (事务)执行一条sql并返回执行情况主要用于非查询操作
2018-03-12 15:12:38 +08:00
func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
var err error
var result sql.Result
p := tx.db.link.handleSqlBeforeExec(&query)
if tx.db.debug.Val() {
militime1 := gtime.Millisecond()
result, err = tx.tx.Exec(*p, args ...)
militime2 := gtime.Millisecond()
s := &Sql{
Sql : *p,
Args : args,
Error : err,
Start : militime1,
End : militime2,
Func : "TX:Exec",
}
tx.db.sqls.Put(s)
tx.db.printSql(s)
} else {
result, err = tx.tx.Exec(*p, args ...)
}
return result, tx.db.formatError(err, p, args...)
2018-03-09 17:55:42 +08:00
}
// 数据库查询,获取查询结果集,以列表结构返回
func (tx *Tx) GetAll(query string, args ...interface{}) (Result, error) {
2018-03-09 17:55:42 +08:00
// 执行sql
2018-03-12 15:12:38 +08:00
rows, err := tx.Query(query, args ...)
2018-03-09 17:55:42 +08:00
if err != nil || rows == nil {
return nil, err
}
// 列名称列表
columns, err := rows.Columns()
if err != nil {
return nil, err
}
// 返回结构组装
values := make([]sql.RawBytes, len(columns))
scanArgs := make([]interface{}, len(values))
records := make(Result, 0)
2018-03-09 17:55:42 +08:00
for i := range values {
scanArgs[i] = &values[i]
}
for rows.Next() {
err = rows.Scan(scanArgs...)
if err != nil {
return records, err
2018-03-09 17:55:42 +08:00
}
row := make(Record)
// 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址
2018-03-09 17:55:42 +08:00
for i, col := range values {
v := make([]byte, len(col))
copy(v, col)
row[columns[i]] = gvar.New(v)
2018-03-09 17:55:42 +08:00
}
//fmt.Printf("%p\n", row["typeid"])
records = append(records, row)
2018-03-09 17:55:42 +08:00
}
return records, nil
2018-03-09 17:55:42 +08:00
}
// 数据库查询,获取查询结果记录,以关联数组结构返回
func (tx *Tx) GetOne(query string, args ...interface{}) (Record, error) {
2018-03-12 15:12:38 +08:00
list, err := tx.GetAll(query, args ...)
2018-03-09 17:55:42 +08:00
if err != nil {
return nil, err
}
if len(list) > 0 {
return list[0], nil
}
return nil, nil
}
// 数据库查询获取查询结果记录自动映射数据到给定的struct对象中
func (tx *Tx) GetStruct(obj interface{}, query string, args ...interface{}) error {
one, err := tx.GetOne(query, args...)
if err != nil {
return err
}
return one.ToStruct(obj)
2018-03-09 17:55:42 +08:00
}
// 数据库查询,获取查询字段值
2018-05-02 18:55:23 +08:00
func (tx *Tx) GetValue(query string, args ...interface{}) (Value, error) {
2018-03-12 15:12:38 +08:00
one, err := tx.GetOne(query, args ...)
2018-03-09 17:55:42 +08:00
if err != nil {
2018-03-12 15:12:38 +08:00
return nil, err
2018-03-09 17:55:42 +08:00
}
for _, v := range one {
return v, nil
}
2018-03-12 15:12:38 +08:00
return nil, nil
2018-03-09 17:55:42 +08:00
}
// 数据库查询,获取查询数量
func (tx *Tx) GetCount(query string, args ...interface{}) (int, error) {
val, err := tx.GetValue(query, args ...)
if err != nil {
return 0, err
}
return gconv.Int(val), nil
}
// 数据表查询其中tables可以是多个联表查询语句这种查询方式较复杂建议使用链式操作
func (tx *Tx) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) {
s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables)
if condition != nil {
s += fmt.Sprintf("WHERE %s ", tx.db.formatCondition(condition))
}
if len(groupBy) > 0 {
s += fmt.Sprintf("GROUP BY %s ", groupBy)
}
if len(orderBy) > 0 {
s += fmt.Sprintf("ORDER BY %s ", orderBy)
}
if limit > 0 {
s += fmt.Sprintf("LIMIT %d,%d ", first, limit)
}
return tx.GetAll(s, args ... )
}
// sql预处理执行完成后调用返回值sql.Stmt.Exec完成sql操作
2018-03-09 17:55:42 +08:00
// 记得调用sql.Stmt.Close关闭操作对象
2018-03-12 15:12:38 +08:00
func (tx *Tx) Prepare(query string) (*sql.Stmt, error) {
return tx.tx.Prepare(query)
2018-03-09 17:55:42 +08:00
}
// insert、replace, save ignore操作
2018-03-09 17:55:42 +08:00
// 0: insert: 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
// 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
// 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
// 3: ignore: 如果数据存在(主键或者唯一索引),那么什么也不做
func (tx *Tx) insert(table string, data Map, option uint8) (sql.Result, error) {
var keys []string
var values []string
var params []interface{}
for k, v := range data {
keys = append(keys, tx.db.charl + k + tx.db.charr)
values = append(values, "?")
params = append(params, v)
2018-03-09 17:55:42 +08:00
}
operation := tx.db.getInsertOperationByOption(option)
updatestr := ""
if option == OPTION_SAVE {
var updates []string
for k, _ := range data {
2018-03-12 11:46:12 +08:00
updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s)", tx.db.charl, k, tx.db.charr, k))
2018-03-09 17:55:42 +08:00
}
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
}
return tx.Exec(
fmt.Sprintf("%s INTO %s%s%s(%s) VALUES(%s) %s",
operation, tx.db.charl, table, tx.db.charr, strings.Join(keys, ","),
strings.Join(values, ","),
updatestr),
params...
2018-03-09 17:55:42 +08:00
)
}
// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回
2018-03-09 17:55:42 +08:00
func (tx *Tx) Insert(table string, data Map) (sql.Result, error) {
2018-03-12 11:46:12 +08:00
return tx.insert(table, data, OPTION_INSERT)
2018-03-09 17:55:42 +08:00
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
2018-03-09 17:55:42 +08:00
func (tx *Tx) Replace(table string, data Map) (sql.Result, error) {
2018-03-12 11:46:12 +08:00
return tx.insert(table, data, OPTION_REPLACE)
2018-03-09 17:55:42 +08:00
}
// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
2018-03-09 17:55:42 +08:00
func (tx *Tx) Save(table string, data Map) (sql.Result, error) {
2018-03-12 11:46:12 +08:00
return tx.insert(table, data, OPTION_SAVE)
2018-03-09 17:55:42 +08:00
}
// 批量写入数据
2018-03-09 17:55:42 +08:00
func (tx *Tx) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) {
var keys []string
var values []string
var bvalues []string
var params []interface{}
var result sql.Result
var size = len(list)
// 判断长度
if size < 1 {
return result, errors.New("empty data list")
}
// 首先获取字段名称及记录长度
for k, _ := range list[0] {
keys = append(keys, k)
values = append(values, "?")
}
2018-06-30 23:12:37 +08:00
keyStr := tx.db.charl + strings.Join(keys, tx.db.charl + "," + tx.db.charr) + tx.db.charr
valueHolderStr := "(" + strings.Join(values, ",") + ")"
2018-03-09 17:55:42 +08:00
// 操作判断
2018-03-12 11:46:12 +08:00
operation := tx.db.getInsertOperationByOption(option)
2018-03-09 17:55:42 +08:00
updatestr := ""
if option == OPTION_SAVE {
var updates []string
for _, k := range keys {
2018-06-30 22:52:39 +08:00
updates = append(updates, fmt.Sprintf("%s%s%s=VALUES(%s)", tx.db.charl, k, tx.db.charr, k))
2018-03-09 17:55:42 +08:00
}
updatestr = fmt.Sprintf(" ON DUPLICATE KEY UPDATE %s", strings.Join(updates, ","))
}
// 构造批量写入数据格式(注意map的遍历是无序的)
for i := 0; i < size; i++ {
for _, k := range keys {
params = append(params, list[i][k])
2018-03-09 17:55:42 +08:00
}
2018-06-30 23:12:37 +08:00
bvalues = append(bvalues, valueHolderStr)
2018-03-09 17:55:42 +08:00
if len(bvalues) == batch {
2018-06-30 23:12:37 +08:00
r, err := tx.Exec(fmt.Sprintf("%s INTO %s%s%s(%s) VALUES%s %s",
operation, tx.db.charl, table, tx.db.charr, keyStr, strings.Join(bvalues, ","),
updatestr),
params...)
2018-03-09 17:55:42 +08:00
if err != nil {
return result, err
}
result = r
2018-06-30 23:12:37 +08:00
params = params[:0]
2018-03-09 17:55:42 +08:00
bvalues = bvalues[:0]
}
}
// 处理最后不构成指定批量的数据
if len(bvalues) > 0 {
2018-06-30 23:12:37 +08:00
r, err := tx.Exec(fmt.Sprintf("%s INTO %s%s%s(%s) VALUES%s %s",
2018-06-30 23:14:54 +08:00
operation, tx.db.charl, table, tx.db.charr, keyStr, strings.Join(bvalues, ","),
2018-06-30 23:12:37 +08:00
updatestr),
params...)
2018-03-09 17:55:42 +08:00
if err != nil {
return result, err
}
result = r
}
return result, nil
}
// CURD操作:批量数据指定批次量写入
2018-03-09 17:55:42 +08:00
func (tx *Tx) BatchInsert(table string, list List, batch int) (sql.Result, error) {
2018-03-12 11:46:12 +08:00
return tx.batchInsert(table, list, batch, OPTION_INSERT)
2018-03-09 17:55:42 +08:00
}
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条
2018-03-09 17:55:42 +08:00
func (tx *Tx) BatchReplace(table string, list List, batch int) (sql.Result, error) {
2018-03-12 11:46:12 +08:00
return tx.batchInsert(table, list, batch, OPTION_REPLACE)
2018-03-09 17:55:42 +08:00
}
// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据
2018-03-09 17:55:42 +08:00
func (tx *Tx) BatchSave(table string, list List, batch int) (sql.Result, error) {
2018-03-12 11:46:12 +08:00
return tx.batchInsert(table, list, batch, OPTION_SAVE)
2018-03-09 17:55:42 +08:00
}
// CURD操作:数据更新统一采用sql预处理
2018-03-09 17:55:42 +08:00
// data参数支持字符串或者关联数组类型内部会自行做判断处理
func (tx *Tx) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
var params []interface{}
var updates string
refValue := reflect.ValueOf(data)
if refValue.Kind() == reflect.Map {
var fields []string
keys := refValue.MapKeys()
for _, k := range keys {
fields = append(fields, fmt.Sprintf("%s%s%s=?", tx.db.charl, k, tx.db.charr))
params = append(params, gconv.String(refValue.MapIndex(k).Interface()))
updates = strings.Join(fields, ",")
}
} else {
updates = gconv.String(data)
2018-03-09 17:55:42 +08:00
}
for _, v := range args {
params = append(params, gconv.String(v))
2018-03-09 17:55:42 +08:00
}
return tx.Exec(fmt.Sprintf("UPDATE %s%s%s SET %s WHERE %s", tx.db.charl, table, tx.db.charr, updates, tx.db.formatCondition(condition)), params...)
2018-03-09 17:55:42 +08:00
}
// CURD操作:删除数据
2018-03-09 17:55:42 +08:00
func (tx *Tx) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
return tx.Exec(fmt.Sprintf("DELETE FROM %s%s%s WHERE %s", tx.db.charl, table, tx.db.charr, tx.db.formatCondition(condition)), args...)
2018-03-09 17:55:42 +08:00
}