mirror of
https://gitee.com/johng/gf
synced 2026-06-07 02:12:11 +08:00
improve gdb.Update/Delete feature to support orderby/limit features
This commit is contained in:
@ -14,13 +14,14 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/g/container/gmap"
|
||||
"github.com/gogf/gf/g/container/gring"
|
||||
"github.com/gogf/gf/g/container/gtype"
|
||||
"github.com/gogf/gf/g/container/gvar"
|
||||
"github.com/gogf/gf/g/os/gcache"
|
||||
"github.com/gogf/gf/g/util/grand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 数据库操作接口
|
||||
|
||||
@ -11,13 +11,14 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/g/container/gvar"
|
||||
"github.com/gogf/gf/g/os/gcache"
|
||||
"github.com/gogf/gf/g/os/gtime"
|
||||
"github.com/gogf/gf/g/text/gregex"
|
||||
"github.com/gogf/gf/g/util/gconv"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -498,7 +499,10 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt
|
||||
// CURD操作:数据更新,统一采用sql预处理。
|
||||
// data参数支持string/map/struct/*struct类型。
|
||||
func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
newWhere, newArgs := formatCondition(condition, args)
|
||||
newWhere, newArgs := formatWhere(condition, args)
|
||||
if newWhere != "" {
|
||||
newWhere = " WHERE " + newWhere
|
||||
}
|
||||
return bs.db.doUpdate(nil, table, data, newWhere, newArgs...)
|
||||
}
|
||||
|
||||
@ -537,15 +541,15 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if len(condition) == 0 {
|
||||
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s", table, updates), args...)
|
||||
}
|
||||
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s WHERE %s", table, updates, condition), args...)
|
||||
return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s%s", table, updates, condition), args...)
|
||||
}
|
||||
|
||||
// CURD操作:删除数据
|
||||
func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) {
|
||||
newWhere, newArgs := formatCondition(condition, args)
|
||||
newWhere, newArgs := formatWhere(condition, args)
|
||||
if newWhere != "" {
|
||||
newWhere = " WHERE " + newWhere
|
||||
}
|
||||
return bs.db.doDelete(nil, table, newWhere, newArgs...)
|
||||
}
|
||||
|
||||
@ -556,10 +560,7 @@ func (bs *dbBase) doDelete(link dbLink, table string, condition string, args ...
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if len(condition) == 0 {
|
||||
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s", table), args...)
|
||||
}
|
||||
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s WHERE %s", table, condition), args...)
|
||||
return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...)
|
||||
}
|
||||
|
||||
// 获得缓存对象
|
||||
@ -570,12 +571,15 @@ func (bs *dbBase) getCache() *gcache.Cache {
|
||||
// 将数据查询的列表数据*sql.Rows转换为Result类型
|
||||
func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) {
|
||||
// 列信息列表, 名称与类型
|
||||
types := make([]string, 0)
|
||||
columns := make([]string, 0)
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
for _, t := range columnTypes {
|
||||
types = append(types, t.DatabaseTypeName())
|
||||
columns = append(columns, t.Name())
|
||||
columnTypes, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
types := make([]string, len(columnTypes))
|
||||
columns := make([]string, len(columnTypes))
|
||||
for k, v := range columnTypes {
|
||||
types[k] = v.DatabaseTypeName()
|
||||
columns[k] = v.Name()
|
||||
}
|
||||
// 返回结构组装
|
||||
values := make([]sql.RawBytes, len(columns))
|
||||
@ -589,14 +593,15 @@ func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) {
|
||||
return records, err
|
||||
}
|
||||
row := make(Record)
|
||||
// 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址
|
||||
for i, col := range values {
|
||||
if col == nil {
|
||||
// 注意col字段是一个[]byte类型(slice类型本身是一个引用类型),
|
||||
// 多个记录循环时该变量指向的是同一个内存地址
|
||||
for i, column := range values {
|
||||
if column == nil {
|
||||
row[columns[i]] = gvar.New(nil, true)
|
||||
} else {
|
||||
// 由于 sql.RawBytes 是slice类型, 这里必须使用值复制
|
||||
v := make([]byte, len(col))
|
||||
copy(v, col)
|
||||
v := make([]byte, len(column))
|
||||
copy(v, column)
|
||||
row[columns[i]] = gvar.New(bs.db.convertValue(v, types[i]), true)
|
||||
}
|
||||
}
|
||||
|
||||
@ -10,14 +10,15 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/g/os/glog"
|
||||
"github.com/gogf/gf/g/os/gtime"
|
||||
"github.com/gogf/gf/g/text/gregex"
|
||||
"github.com/gogf/gf/g/text/gstr"
|
||||
"github.com/gogf/gf/g/util/gconv"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Type assert api for String().
|
||||
@ -25,8 +26,8 @@ type apiString interface {
|
||||
String() string
|
||||
}
|
||||
|
||||
// 格式化SQL查询条件
|
||||
func formatCondition(where interface{}, args []interface{}) (newWhere string, newArgs []interface{}) {
|
||||
// 格式化Where查询条件
|
||||
func formatWhere(where interface{}, args []interface{}) (newWhere string, newArgs []interface{}) {
|
||||
// 条件字符串处理
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
// 使用反射进行类型判断
|
||||
|
||||
@ -12,8 +12,9 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gogf/gf/g/util/gconv"
|
||||
"reflect"
|
||||
|
||||
"github.com/gogf/gf/g/util/gconv"
|
||||
)
|
||||
|
||||
// 数据库链式操作模型对象
|
||||
@ -29,13 +30,13 @@ type Model struct {
|
||||
orderBy string // 排序语句
|
||||
start int // 分页开始
|
||||
limit int // 分页条数
|
||||
data interface{} // 操作记录(支持Map/List/string类型)
|
||||
data interface{} // 操作数据(注意仅支持Map/List/string类型)
|
||||
batch int // 批量操作条数
|
||||
filter bool // 是否按照表字段过滤data参数
|
||||
cacheEnabled bool // 当前SQL操作是否开启查询缓存功能
|
||||
cacheTime int // 查询缓存时间
|
||||
cacheName string // 查询缓存名称
|
||||
safe bool // 当前模型是否运行安全模式(可修改当前模型,否则每一次链式操作都是返回新的模型对象)
|
||||
safe bool // 当前模型是否安全模式(默认非安全表示链式操作直接修改当前模型属性;否则每一次链式操作都是返回新的模型对象)
|
||||
}
|
||||
|
||||
// 链式操作,数据表字段,可支持多个表,以半角逗号连接
|
||||
@ -45,6 +46,7 @@ func (bs *dbBase) Table(tables string) *Model {
|
||||
tablesInit: tables,
|
||||
tables: tables,
|
||||
fields: "*",
|
||||
start: -1,
|
||||
safe: false,
|
||||
}
|
||||
}
|
||||
@ -149,7 +151,7 @@ func (md *Model) Where(where interface{}, args ...interface{}) *Model {
|
||||
if model.where != "" {
|
||||
return md.And(where, args...)
|
||||
}
|
||||
newWhere, newArgs := formatCondition(where, args)
|
||||
newWhere, newArgs := formatWhere(where, args)
|
||||
model.where = newWhere
|
||||
model.whereArgs = newArgs
|
||||
return model
|
||||
@ -158,7 +160,7 @@ func (md *Model) Where(where interface{}, args ...interface{}) *Model {
|
||||
// 链式操作,添加AND条件到Where中
|
||||
func (md *Model) And(where interface{}, args ...interface{}) *Model {
|
||||
model := md.getModel()
|
||||
newWhere, newArgs := formatCondition(where, args)
|
||||
newWhere, newArgs := formatWhere(where, args)
|
||||
if len(model.where) > 0 && model.where[0] == '(' {
|
||||
model.where = fmt.Sprintf(`%s AND (%s)`, model.where, newWhere)
|
||||
} else {
|
||||
@ -171,7 +173,7 @@ func (md *Model) And(where interface{}, args ...interface{}) *Model {
|
||||
// 链式操作,添加OR条件到Where中
|
||||
func (md *Model) Or(where interface{}, args ...interface{}) *Model {
|
||||
model := md.getModel()
|
||||
newWhere, newArgs := formatCondition(where, args)
|
||||
newWhere, newArgs := formatWhere(where, args)
|
||||
if len(model.where) > 0 && model.where[0] == '(' {
|
||||
model.where = fmt.Sprintf(`%s OR (%s)`, model.where, newWhere)
|
||||
} else {
|
||||
@ -195,11 +197,20 @@ func (md *Model) OrderBy(orderBy string) *Model {
|
||||
return model
|
||||
}
|
||||
|
||||
// 链式操作,limit
|
||||
func (md *Model) Limit(start int, limit int) *Model {
|
||||
// 链式操作,limit。
|
||||
//
|
||||
// 如果给定一个参数,那么生成的SQL为:LIMIT limit[0]
|
||||
//
|
||||
// 如果给定两个参数,那么生成的SQL为:LIMIT limit[0], limit[1]
|
||||
func (md *Model) Limit(limit ...int) *Model {
|
||||
model := md.getModel()
|
||||
model.start = start
|
||||
model.limit = limit
|
||||
switch len(limit) {
|
||||
case 1:
|
||||
model.limit = limit[0]
|
||||
case 2:
|
||||
model.start = limit[0]
|
||||
model.limit = limit[1]
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
@ -425,9 +436,9 @@ func (md *Model) Update() (result sql.Result, err error) {
|
||||
}
|
||||
}
|
||||
if md.tx == nil {
|
||||
return md.db.doUpdate(nil, md.tables, md.data, md.where, md.whereArgs...)
|
||||
return md.db.doUpdate(nil, md.tables, md.data, md.getConditionSql(), md.whereArgs...)
|
||||
} else {
|
||||
return md.tx.doUpdate(md.tables, md.data, md.where, md.whereArgs...)
|
||||
return md.tx.doUpdate(md.tables, md.data, md.getConditionSql(), md.whereArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
@ -439,9 +450,9 @@ func (md *Model) Delete() (result sql.Result, err error) {
|
||||
}
|
||||
}()
|
||||
if md.tx == nil {
|
||||
return md.db.doDelete(nil, md.tables, md.where, md.whereArgs...)
|
||||
return md.db.doDelete(nil, md.tables, md.getConditionSql(), md.whereArgs...)
|
||||
} else {
|
||||
return md.tx.doDelete(md.tables, md.where, md.whereArgs...)
|
||||
return md.tx.doDelete(md.tables, md.getConditionSql(), md.whereArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
@ -452,7 +463,7 @@ func (md *Model) Select() (Result, error) {
|
||||
|
||||
// 链式操作,查询所有记录
|
||||
func (md *Model) All() (Result, error) {
|
||||
return md.getAll(md.getFormattedSql(), md.whereArgs...)
|
||||
return md.getAll(fmt.Sprintf("SELECT %s FROM %s %s", md.fields, md.tables, md.getConditionSql()), md.whereArgs...)
|
||||
}
|
||||
|
||||
// 链式操作,查询单条记录
|
||||
@ -530,7 +541,7 @@ func (md *Model) Count() (int, error) {
|
||||
} else {
|
||||
md.fields = fmt.Sprintf(`COUNT(%s)`, md.fields)
|
||||
}
|
||||
s := md.getFormattedSql()
|
||||
s := fmt.Sprintf("SELECT %s FROM %s %s", md.fields, md.tables, md.getConditionSql())
|
||||
if len(md.groupBy) > 0 {
|
||||
s = fmt.Sprintf("SELECT COUNT(1) FROM (%s) count_alias", s)
|
||||
}
|
||||
@ -583,12 +594,9 @@ func (md *Model) checkAndRemoveCache() {
|
||||
}
|
||||
}
|
||||
|
||||
// 格式化当前输入参数,返回可执行的SQL语句(不带参数)
|
||||
func (md *Model) getFormattedSql() string {
|
||||
if md.fields == "" {
|
||||
md.fields = "*"
|
||||
}
|
||||
s := fmt.Sprintf("SELECT %s FROM %s", md.fields, md.tables)
|
||||
// 格式化当前输入参数,返回SQL条件语句(不带参数)
|
||||
func (md *Model) getConditionSql() string {
|
||||
s := ""
|
||||
if md.where != "" {
|
||||
s += " WHERE " + md.where
|
||||
}
|
||||
@ -599,7 +607,12 @@ func (md *Model) getFormattedSql() string {
|
||||
s += " ORDER BY " + md.orderBy
|
||||
}
|
||||
if md.limit != 0 {
|
||||
s += fmt.Sprintf(" LIMIT %d, %d", md.start, md.limit)
|
||||
if md.start >= 0 {
|
||||
s += fmt.Sprintf(" LIMIT %d, %d", md.start, md.limit)
|
||||
} else {
|
||||
s += fmt.Sprintf(" LIMIT %d", md.limit)
|
||||
}
|
||||
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
@ -9,8 +9,9 @@ package gdb
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/gogf/gf/g/text/gregex"
|
||||
"reflect"
|
||||
|
||||
"github.com/gogf/gf/g/text/gregex"
|
||||
)
|
||||
|
||||
// 数据库事务对象
|
||||
@ -164,7 +165,10 @@ func (tx *TX) BatchSave(table string, list interface{}, batch ...int) (sql.Resul
|
||||
// CURD操作:数据更新,统一采用sql预处理,
|
||||
// data参数支持字符串或者关联数组类型,内部会自行做判断处理.
|
||||
func (tx *TX) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
newWhere, newArgs := formatCondition(condition, args)
|
||||
newWhere, newArgs := formatWhere(condition, args)
|
||||
if newWhere != "" {
|
||||
newWhere = " WHERE " + newWhere
|
||||
}
|
||||
return tx.doUpdate(table, data, newWhere, newArgs...)
|
||||
}
|
||||
|
||||
@ -175,7 +179,10 @@ func (tx *TX) doUpdate(table string, data interface{}, condition string, args ..
|
||||
|
||||
// CURD操作:删除数据
|
||||
func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) {
|
||||
newWhere, newArgs := formatCondition(condition, args)
|
||||
newWhere, newArgs := formatWhere(condition, args)
|
||||
if newWhere != "" {
|
||||
newWhere = " WHERE " + newWhere
|
||||
}
|
||||
return tx.doDelete(table, newWhere, newArgs...)
|
||||
}
|
||||
|
||||
|
||||
@ -8,12 +8,14 @@ package gdb_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/gogf/gf/g"
|
||||
"github.com/gogf/gf/g/container/garray"
|
||||
|
||||
"github.com/gogf/gf/g/database/gdb"
|
||||
"github.com/gogf/gf/g/os/gtime"
|
||||
"github.com/gogf/gf/g/test/gtest"
|
||||
"os"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -88,14 +90,6 @@ func createTable(table ...string) (name string) {
|
||||
return
|
||||
}
|
||||
|
||||
// 删除指定表.
|
||||
func dropTable(table string) {
|
||||
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// See createTable.
|
||||
// 创建测试表,并初始化默认数据。
|
||||
func createInitTable(table ...string) (name string) {
|
||||
name = createTable(table...)
|
||||
@ -117,3 +111,10 @@ func createInitTable(table ...string) (name string) {
|
||||
gtest.Assert(n, INIT_DATA_SIZE)
|
||||
return
|
||||
}
|
||||
|
||||
// 删除指定表.
|
||||
func dropTable(table string) {
|
||||
if _, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS `%s`", table)); err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,10 +7,11 @@
|
||||
package gdb_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/gogf/gf/g"
|
||||
"github.com/gogf/gf/g/os/gtime"
|
||||
"github.com/gogf/gf/g/test/gtest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// 基本测试
|
||||
@ -187,6 +188,25 @@ func TestModel_Save(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModel_Update(t *testing.T) {
|
||||
table := createInitTable()
|
||||
// UPDATE...LIMIT
|
||||
gtest.Case(t, func() {
|
||||
result, err := db.Table(table).Data("nickname", "T100").OrderBy("id desc").Limit(2).Update()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 2)
|
||||
|
||||
v1, err := db.Table(table).Fields("nickname").Where("id", 10).Value()
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(v1.String(), "T100")
|
||||
|
||||
v2, err := db.Table(table).Fields("nickname").Where("id", 8).Value()
|
||||
gtest.Assert(err, nil)
|
||||
gtest.Assert(v2.String(), "T8")
|
||||
})
|
||||
|
||||
gtest.Case(t, func() {
|
||||
result, err := db.Table("user").Data("passport", "t22").Where("passport=?", "t2").Update()
|
||||
if err != nil {
|
||||
@ -644,10 +664,22 @@ func TestModel_Where(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModel_Delete(t *testing.T) {
|
||||
result, err := db.Table("user").Delete()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 3)
|
||||
// DELETE...LIMIT
|
||||
gtest.Case(t, func() {
|
||||
result, err := db.Table("user").Limit(2).Delete()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 2)
|
||||
})
|
||||
|
||||
gtest.Case(t, func() {
|
||||
result, err := db.Table("user").Delete()
|
||||
if err != nil {
|
||||
gtest.Fatal(err)
|
||||
}
|
||||
n, _ := result.RowsAffected()
|
||||
gtest.Assert(n, 1)
|
||||
})
|
||||
}
|
||||
|
||||
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/g"
|
||||
"github.com/gogf/gf/g/database/gdb"
|
||||
"github.com/gogf/gf/g/os/glog"
|
||||
@ -22,8 +23,11 @@ func main() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
glog.SetPath("/tmp")
|
||||
db.SetDebug(true)
|
||||
db.Table("user").Limit(2).Delete()
|
||||
return
|
||||
glog.SetPath("/tmp")
|
||||
|
||||
// 执行3条SQL查询
|
||||
for i := 1; i <= 3; i++ {
|
||||
db.Table("user").Where("uid=?", i).One()
|
||||
Reference in New Issue
Block a user