diff --git a/g/database/gdb/gdb_base.go b/g/database/gdb/gdb_base.go index a74e20627..e9c50cf33 100644 --- a/g/database/gdb/gdb_base.go +++ b/g/database/gdb/gdb_base.go @@ -128,6 +128,15 @@ func (db *Db) GetValue(query string, args ...interface{}) (interface{}, error) { return nil, nil } +// 数据库查询,获取查询数量 +func (db *Db) GetCount(query string, args ...interface{}) (int, error) { + val, err := db.GetValue(query, args ...) + if err != nil { + return 0, err + } + return gconv.Int(val), nil +} + // 数据表查询,其中tables可以是多个联表查询语句,这种查询方式较复杂,建议使用链式操作 func (db *Db) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (List, error) { s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables) diff --git a/g/database/gdb/gdb_model.go b/g/database/gdb/gdb_model.go index 3cd911602..9d98e0b18 100644 --- a/g/database/gdb/gdb_model.go +++ b/g/database/gdb/gdb_model.go @@ -251,26 +251,10 @@ func (md *Model) Batch(batch int) *Model { // 链式操作,select func (md *Model) Select() (List, error) { - if md.fields == "" { - md.fields = "*" - } - s := fmt.Sprintf("SELECT %s FROM %s", md.fields, md.tables) - if md.where != "" { - s += " WHERE " + md.where - } - if md.groupBy != "" { - s += " GROUP BY " + md.groupBy - } - if md.orderBy != "" { - s += " ORDER BY " + md.orderBy - } - if md.limit != 0 { - s += fmt.Sprintf(" LIMIT %d, %d", md.start, md.limit) - } if md.tx == nil { - return md.db.GetAll(s, md.whereArgs...) + return md.db.GetAll(md.getFormattedSql(), md.whereArgs...) } else { - return md.tx.GetAll(s, md.whereArgs...) + return md.tx.GetAll(md.getFormattedSql(), md.whereArgs...) } } @@ -303,3 +287,40 @@ func (md *Model) Value() (interface{}, error) { return "", nil } +// 链式操作,查询数量,fields可以为空,也可以自定义查询字段, +// 当给定自定义查询字段时,该字段必须为数量结果,否则会引起歧义,如:Fields("COUNT(id)") +func (md *Model) Count() (int, error) { + if md.fields == "" || md.fields == "*" { + md.fields = "COUNT(1)" + } + s := md.getFormattedSql() + if len(md.groupBy) > 0 { + s = fmt.Sprintf("SELECT COUNT(1) FROM (%s) count_alias", s) + } + if md.tx == nil { + return md.db.GetCount(s, md.whereArgs...) + } else { + return md.tx.GetCount(s, md.whereArgs...) + } +} + +// 格式化当前输入参数,返回可执行的SQL语句(不带参数) +func (md *Model) getFormattedSql() string { + if md.fields == "" { + md.fields = "*" + } + s := fmt.Sprintf("SELECT %s FROM %s", md.fields, md.tables) + if md.where != "" { + s += " WHERE " + md.where + } + if md.groupBy != "" { + s += " GROUP BY " + md.groupBy + } + if md.orderBy != "" { + s += " ORDER BY " + md.orderBy + } + if md.limit != 0 { + s += fmt.Sprintf(" LIMIT %d, %d", md.start, md.limit) + } + return s +} \ No newline at end of file diff --git a/g/database/gdb/gdb_transaction.go b/g/database/gdb/gdb_transaction.go index 315242ca7..06448fd1d 100644 --- a/g/database/gdb/gdb_transaction.go +++ b/g/database/gdb/gdb_transaction.go @@ -13,6 +13,7 @@ import ( "database/sql" _ "github.com/lib/pq" _ "github.com/go-sql-driver/mysql" + "gitee.com/johng/gf/g/util/gconv" ) // 数据库事务对象 @@ -122,6 +123,15 @@ func (tx *Tx) GetValue(query string, args ...interface{}) (interface{}, error) { return nil, nil } +// 数据库查询,获取查询数量 +func (tx *Tx) GetCount(query string, args ...interface{}) (int, error) { + val, err := tx.GetValue(query, args ...) + if err != nil { + return 0, err + } + return gconv.Int(val), nil +} + // (事务)sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作 // 记得调用sql.Stmt.Close关闭操作对象 func (tx *Tx) Prepare(query string) (*sql.Stmt, error) { diff --git a/geg/database/mysql/mysql.go b/geg/database/mysql/mysql.go index 02fdfd0b1..3b49c1775 100644 --- a/geg/database/mysql/mysql.go +++ b/geg/database/mysql/mysql.go @@ -17,7 +17,7 @@ func init () { Host : "127.0.0.1", Port : "3306", User : "root", - Pass : "123456", + Pass : "8692651", Name : "test", Type : "mysql", Role : "master", @@ -285,6 +285,19 @@ func linkopSelect3() { fmt.Println() } +// 链式查询数量1 +func linkopCount1() { + fmt.Println("linkopCount1:") + r, err := db.Table("user u").LeftJoin("user_detail ud", "u.uid=ud.uid").Where("u.uid=?", 1).Count() + if err == nil { + fmt.Println(r) + } else { + fmt.Println(err) + } + fmt.Println() +} + + // 错误操作 func linkopUpdate1() { fmt.Println("linkopUpdate1:") @@ -462,10 +475,11 @@ func main() { //linkopSelect1() //linkopSelect2() //linkopSelect3() + linkopCount1() //linkopUpdate1() //linkopUpdate2() //linkopUpdate3() - linkopUpdate4() + //linkopUpdate4() //keepPing() //transaction1() //transaction2()