improve gdb.Update/Delete feature to support orderby/limit features

This commit is contained in:
John
2019-07-02 23:14:40 +08:00
parent e63e989d41
commit c948f0c287
31 changed files with 136 additions and 72 deletions

View File

@ -14,13 +14,14 @@ import (
"database/sql"
"errors"
"fmt"
"time"
"github.com/gogf/gf/g/container/gmap"
"github.com/gogf/gf/g/container/gring"
"github.com/gogf/gf/g/container/gtype"
"github.com/gogf/gf/g/container/gvar"
"github.com/gogf/gf/g/os/gcache"
"github.com/gogf/gf/g/util/grand"
"time"
)
// 数据库操作接口

View File

@ -11,13 +11,14 @@ import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"github.com/gogf/gf/g/container/gvar"
"github.com/gogf/gf/g/os/gcache"
"github.com/gogf/gf/g/os/gtime"
"github.com/gogf/gf/g/text/gregex"
"github.com/gogf/gf/g/util/gconv"
"reflect"
"strings"
)
const (
@ -498,7 +499,10 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
// CURD操作:数据更新统一采用sql预处理。
// data参数支持string/map/struct/*struct类型。
func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
newWhere, newArgs := formatCondition(condition, args)
newWhere, newArgs := formatWhere(condition, args)
if newWhere != "" {
newWhere = " WHERE " + newWhere
}
return bs.db.doUpdate(nil, table, data, newWhere, newArgs...)
}
@ -537,15 +541,15 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio
return nil, err
}
}
if len(condition) == 0 {
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s", table, updates), args...)
}
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, condition), args...)
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s%s", table, updates, condition), args...)
}
// CURD操作:删除数据
func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
newWhere, newArgs := formatCondition(condition, args)
newWhere, newArgs := formatWhere(condition, args)
if newWhere != "" {
newWhere = " WHERE " + newWhere
}
return bs.db.doDelete(nil, table, newWhere, newArgs...)
}
@ -556,10 +560,7 @@ func (bs *dbBase) doDelete(link dbLink, table string, condition string, args ...
return nil, err
}
}
if len(condition) == 0 {
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s", table), args...)
}
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, condition), args...)
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...)
}
// 获得缓存对象
@ -570,12 +571,15 @@ func (bs *dbBase) getCache() *gcache.Cache {
// 将数据查询的列表数据*sql.Rows转换为Result类型
func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) {
// 列信息列表, 名称与类型
types := make([]string, 0)
columns := make([]string, 0)
columnTypes, _ := rows.ColumnTypes()
for _, t := range columnTypes {
types = append(types, t.DatabaseTypeName())
columns = append(columns, t.Name())
columnTypes, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
types := make([]string, len(columnTypes))
columns := make([]string, len(columnTypes))
for k, v := range columnTypes {
types[k] = v.DatabaseTypeName()
columns[k] = v.Name()
}
// 返回结构组装
values := make([]sql.RawBytes, len(columns))
@ -589,14 +593,15 @@ func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) {
return records, err
}
row := make(Record)
// 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址
for i, col := range values {
if col == nil {
// 注意col字段是一个[]byte类型(slice类型本身是一个引用类型)
// 多个记录循环时该变量指向的是同一个内存地址
for i, column := range values {
if column == nil {
row[columns[i]] = gvar.New(nil, true)
} else {
// 由于 sql.RawBytes 是slice类型, 这里必须使用值复制
v := make([]byte, len(col))
copy(v, col)
v := make([]byte, len(column))
copy(v, column)
row[columns[i]] = gvar.New(bs.db.convertValue(v, types[i]), true)
}
}

View File

@ -10,14 +10,15 @@ import (
"bytes"
"errors"
"fmt"
"reflect"
"strings"
"time"
"github.com/gogf/gf/g/os/glog"
"github.com/gogf/gf/g/os/gtime"
"github.com/gogf/gf/g/text/gregex"
"github.com/gogf/gf/g/text/gstr"
"github.com/gogf/gf/g/util/gconv"
"reflect"
"strings"
"time"
)
// Type assert api for String().
@ -25,8 +26,8 @@ type apiString interface {
String() string
}
// 格式化SQL查询条件
func formatCondition(where interface{}, args []interface{}) (newWhere string, newArgs []interface{}) {
// 格式化Where查询条件
func formatWhere(where interface{}, args []interface{}) (newWhere string, newArgs []interface{}) {
// 条件字符串处理
buffer := bytes.NewBuffer(nil)
// 使用反射进行类型判断

View File

@ -12,8 +12,9 @@ import (
"database/sql"
"errors"
"fmt"
"github.com/gogf/gf/g/util/gconv"
"reflect"
"github.com/gogf/gf/g/util/gconv"
)
// 数据库链式操作模型对象
@ -29,13 +30,13 @@ type Model struct {
orderBy string // 排序语句
start int // 分页开始
limit int // 分页条数
data interface{} // 操作记录(支持Map/List/string类型)
data interface{} // 操作数据(注意仅支持Map/List/string类型)
batch int // 批量操作条数
filter bool // 是否按照表字段过滤data参数
cacheEnabled bool // 当前SQL操作是否开启查询缓存功能
cacheTime int // 查询缓存时间
cacheName string // 查询缓存名称
safe bool // 当前模型是否运行安全模式(修改当前模型否则每一次链式操作都是返回新的模型对象)
safe bool // 当前模型是否安全模式(默认非安全表示链式操作直接修改当前模型属性;否则每一次链式操作都是返回新的模型对象)
}
// 链式操作,数据表字段,可支持多个表,以半角逗号连接
@ -45,6 +46,7 @@ func (bs *dbBase) Table(tables string) *Model {
tablesInit: tables,
tables: tables,
fields: "*",
start: -1,
safe: false,
}
}
@ -149,7 +151,7 @@ func (md *Model) Where(where interface{}, args ...interface{}) *Model {
if model.where != "" {
return md.And(where, args...)
}
newWhere, newArgs := formatCondition(where, args)
newWhere, newArgs := formatWhere(where, args)
model.where = newWhere
model.whereArgs = newArgs
return model
@ -158,7 +160,7 @@ func (md *Model) Where(where interface{}, args ...interface{}) *Model {
// 链式操作添加AND条件到Where中
func (md *Model) And(where interface{}, args ...interface{}) *Model {
model := md.getModel()
newWhere, newArgs := formatCondition(where, args)
newWhere, newArgs := formatWhere(where, args)
if len(model.where) > 0 && model.where[0] == '(' {
model.where = fmt.Sprintf(`%s AND (%s)`, model.where, newWhere)
} else {
@ -171,7 +173,7 @@ func (md *Model) And(where interface{}, args ...interface{}) *Model {
// 链式操作添加OR条件到Where中
func (md *Model) Or(where interface{}, args ...interface{}) *Model {
model := md.getModel()
newWhere, newArgs := formatCondition(where, args)
newWhere, newArgs := formatWhere(where, args)
if len(model.where) > 0 && model.where[0] == '(' {
model.where = fmt.Sprintf(`%s OR (%s)`, model.where, newWhere)
} else {
@ -195,11 +197,20 @@ func (md *Model) OrderBy(orderBy string) *Model {
return model
}
// 链式操作limit
func (md *Model) Limit(start int, limit int) *Model {
// 链式操作limit
//
// 如果给定一个参数那么生成的SQL为LIMIT limit[0]
//
// 如果给定两个参数那么生成的SQL为LIMIT limit[0], limit[1]
func (md *Model) Limit(limit ...int) *Model {
model := md.getModel()
model.start = start
model.limit = limit
switch len(limit) {
case 1:
model.limit = limit[0]
case 2:
model.start = limit[0]
model.limit = limit[1]
}
return model
}
@ -425,9 +436,9 @@ func (md *Model) Update() (result sql.Result, err error) {
}
}
if md.tx == nil {
return md.db.doUpdate(nil, md.tables, md.data, md.where, md.whereArgs...)
return md.db.doUpdate(nil, md.tables, md.data, md.getConditionSql(), md.whereArgs...)
} else {
return md.tx.doUpdate(md.tables, md.data, md.where, md.whereArgs...)
return md.tx.doUpdate(md.tables, md.data, md.getConditionSql(), md.whereArgs...)
}
}
@ -439,9 +450,9 @@ func (md *Model) Delete() (result sql.Result, err error) {
}
}()
if md.tx == nil {
return md.db.doDelete(nil, md.tables, md.where, md.whereArgs...)
return md.db.doDelete(nil, md.tables, md.getConditionSql(), md.whereArgs...)
} else {
return md.tx.doDelete(md.tables, md.where, md.whereArgs...)
return md.tx.doDelete(md.tables, md.getConditionSql(), md.whereArgs...)
}
}
@ -452,7 +463,7 @@ func (md *Model) Select() (Result, error) {
// 链式操作,查询所有记录
func (md *Model) All() (Result, error) {
return md.getAll(md.getFormattedSql(), md.whereArgs...)
return md.getAll(fmt.Sprintf("SELECT %s FROM %s %s", md.fields, md.tables, md.getConditionSql()), md.whereArgs...)
}
// 链式操作,查询单条记录
@ -530,7 +541,7 @@ func (md *Model) Count() (int, error) {
} else {
md.fields = fmt.Sprintf(`COUNT(%s)`, md.fields)
}
s := md.getFormattedSql()
s := fmt.Sprintf("SELECT %s FROM %s %s", md.fields, md.tables, md.getConditionSql())
if len(md.groupBy) > 0 {
s = fmt.Sprintf("SELECT COUNT(1) FROM (%s) count_alias", s)
}
@ -583,12 +594,9 @@ func (md *Model) checkAndRemoveCache() {
}
}
// 格式化当前输入参数,返回可执行的SQL语句不带参数
func (md *Model) getFormattedSql() string {
if md.fields == "" {
md.fields = "*"
}
s := fmt.Sprintf("SELECT %s FROM %s", md.fields, md.tables)
// 格式化当前输入参数返回SQL条件语句(不带参数)
func (md *Model) getConditionSql() string {
s := ""
if md.where != "" {
s += " WHERE " + md.where
}
@ -599,7 +607,12 @@ func (md *Model) getFormattedSql() string {
s += " ORDER BY " + md.orderBy
}
if md.limit != 0 {
s += fmt.Sprintf(" LIMIT %d, %d", md.start, md.limit)
if md.start >= 0 {
s += fmt.Sprintf(" LIMIT %d, %d", md.start, md.limit)
} else {
s += fmt.Sprintf(" LIMIT %d", md.limit)
}
}
return s
}

View File

@ -9,8 +9,9 @@ package gdb
import (
"database/sql"
"fmt"
"github.com/gogf/gf/g/text/gregex"
"reflect"
"github.com/gogf/gf/g/text/gregex"
)
// 数据库事务对象
@ -164,7 +165,10 @@ func (tx *TX) BatchSave(table string, list interface{}, batch ...int) (sql.Resul
// CURD操作:数据更新统一采用sql预处理,
// data参数支持字符串或者关联数组类型内部会自行做判断处理.
func (tx *TX) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
newWhere, newArgs := formatCondition(condition, args)
newWhere, newArgs := formatWhere(condition, args)
if newWhere != "" {
newWhere = " WHERE " + newWhere
}
return tx.doUpdate(table, data, newWhere, newArgs...)
}
@ -175,7 +179,10 @@ func (tx *TX) doUpdate(table string, data interface{}, condition string, args ..
// CURD操作:删除数据
func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
newWhere, newArgs := formatCondition(condition, args)
newWhere, newArgs := formatWhere(condition, args)
if newWhere != "" {
newWhere = " WHERE " + newWhere
}
return tx.doDelete(table, newWhere, newArgs...)
}

View File

@ -8,12 +8,14 @@ package gdb_test
import (
"fmt"
"os"
"github.com/gogf/gf/g"
"github.com/gogf/gf/g/container/garray"
"github.com/gogf/gf/g/database/gdb"
"github.com/gogf/gf/g/os/gtime"
"github.com/gogf/gf/g/test/gtest"
"os"
)
const (
@ -88,14 +90,6 @@ func createTable(table ...string) (name string) {
return
}
// 删除指定表.
func dropTable(table string) {
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil {
gtest.Fatal(err)
}
}
// See createTable.
// 创建测试表,并初始化默认数据。
func createInitTable(table ...string) (name string) {
name = createTable(table...)
@ -117,3 +111,10 @@ func createInitTable(table ...string) (name string) {
gtest.Assert(n, INIT_DATA_SIZE)
return
}
// 删除指定表.
func dropTable(table string) {
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil {
gtest.Fatal(err)
}
}

View File

@ -7,10 +7,11 @@
package gdb_test
import (
"testing"
"github.com/gogf/gf/g"
"github.com/gogf/gf/g/os/gtime"
"github.com/gogf/gf/g/test/gtest"
"testing"
)
// 基本测试
@ -187,6 +188,25 @@ func TestModel_Save(t *testing.T) {
}
func TestModel_Update(t *testing.T) {
table := createInitTable()
// UPDATE...LIMIT
gtest.Case(t, func() {
result, err := db.Table(table).Data("nickname", "T100").OrderBy("id desc").Limit(2).Update()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.RowsAffected()
gtest.Assert(n, 2)
v1, err := db.Table(table).Fields("nickname").Where("id", 10).Value()
gtest.Assert(err, nil)
gtest.Assert(v1.String(), "T100")
v2, err := db.Table(table).Fields("nickname").Where("id", 8).Value()
gtest.Assert(err, nil)
gtest.Assert(v2.String(), "T8")
})
gtest.Case(t, func() {
result, err := db.Table("user").Data("passport", "t22").Where("passport=?", "t2").Update()
if err != nil {
@ -644,10 +664,22 @@ func TestModel_Where(t *testing.T) {
}
func TestModel_Delete(t *testing.T) {
result, err := db.Table("user").Delete()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.RowsAffected()
gtest.Assert(n, 3)
// DELETE...LIMIT
gtest.Case(t, func() {
result, err := db.Table("user").Limit(2).Delete()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.RowsAffected()
gtest.Assert(n, 2)
})
gtest.Case(t, func() {
result, err := db.Table("user").Delete()
if err != nil {
gtest.Fatal(err)
}
n, _ := result.RowsAffected()
gtest.Assert(n, 1)
})
}

View File

@ -2,6 +2,7 @@ package main
import (
"fmt"
"github.com/gogf/gf/g"
"github.com/gogf/gf/g/database/gdb"
"github.com/gogf/gf/g/os/glog"
@ -22,8 +23,11 @@ func main() {
if err != nil {
panic(err)
}
glog.SetPath("/tmp")
db.SetDebug(true)
db.Table("user").Limit(2).Delete()
return
glog.SetPath("/tmp")
// 执行3条SQL查询
for i := 1; i <= 3; i++ {
db.Table("user").Where("uid=?", i).One()