diff --git a/g/database/gdb/gdb_base.go b/g/database/gdb/gdb_base.go index 27a582b9e..f1675dafb 100644 --- a/g/database/gdb/gdb_base.go +++ b/g/database/gdb/gdb_base.go @@ -340,19 +340,17 @@ func (db *Db) BatchSave(table string, list List, batch int) (sql.Result, error) func (db *Db) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { var params []interface{} var updates string - switch value := data.(type) { - case string: - updates = value - case Map: - var keys []string - for k, v := range value { - keys = append(keys, fmt.Sprintf("%s%s%s=?", db.charl, k, db.charr)) - params = append(params, v) - } - updates = strings.Join(keys, ",") - - default: - return nil, errors.New("invalid data type for 'data' field, string or Map expected") + refValue := reflect.ValueOf(data) + if refValue.Kind() == reflect.Map { + var fields []string + keys := refValue.MapKeys() + for _, k := range keys { + fields = append(fields, fmt.Sprintf("%s%s%s=?", db.charl, k, db.charr)) + params = append(params, gconv.String(refValue.MapIndex(k).Interface())) + updates = strings.Join(fields, ",") + } + } else { + updates = gconv.String(data) } for _, v := range args { params = append(params, gconv.String(v)) diff --git a/g/database/gdb/gdb_model.go b/g/database/gdb/gdb_model.go index 8bfd9a480..3fa2a6401 100644 --- a/g/database/gdb/gdb_model.go +++ b/g/database/gdb/gdb_model.go @@ -84,7 +84,7 @@ func (md *Model) Fields(fields string) (*Model) { // 链式操作,condition,支持string & gdb.Map func (md *Model) Where(where interface{}, args...interface{}) (*Model) { md.where = md.db.formatCondition(where) - md.whereArgs = args + md.whereArgs = append(md.whereArgs, args...) return md } diff --git a/g/database/gdb/gdb_transaction.go b/g/database/gdb/gdb_transaction.go index f04236f01..4cba7163e 100644 --- a/g/database/gdb/gdb_transaction.go +++ b/g/database/gdb/gdb_transaction.go @@ -14,6 +14,7 @@ import ( _ "github.com/lib/pq" _ "github.com/go-sql-driver/mysql" "gitee.com/johng/gf/g/util/gconv" + "reflect" ) // 数据库事务对象 @@ -51,25 +52,7 @@ func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { return r, err } -// (事务)数据表查询,其中tables可以是多个联表查询语句,这种查询方式较复杂,建议使用链式操作 -func (tx *Tx) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) { - s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables) - if condition != nil { - s += fmt.Sprintf("WHERE %s ", tx.db.formatCondition(condition)) - } - if len(groupBy) > 0 { - s += fmt.Sprintf("GROUP BY %s ", groupBy) - } - if len(orderBy) > 0 { - s += fmt.Sprintf("ORDER BY %s ", orderBy) - } - if limit > 0 { - s += fmt.Sprintf("LIMIT %d,%d ", first, limit) - } - return tx.GetAll(s, args ... ) -} - -// (事务)数据库查询,获取查询结果集,以列表结构返回 +// 数据库查询,获取查询结果集,以列表结构返回 func (tx *Tx) GetAll(query string, args ...interface{}) (Result, error) { // 执行sql rows, err := tx.Query(query, args ...) @@ -94,24 +77,42 @@ func (tx *Tx) GetAll(query string, args ...interface{}) (Result, error) { return records, err } row := make(Record) + // 注意col字段是一个[]byte类型(slice类型本身是一个指针),多个记录循环时该变量指向的是同一个内存地址 for i, col := range values { - row[columns[i]] = Value(col) + k := columns[i] + v := make([]byte, len(col)) + copy(v, col) + row[k] = v } + //fmt.Printf("%p\n", row["typeid"]) records = append(records, row) } return records, nil } -// (事务)数据库查询,获取查询结果集,以关联数组结构返回 +// 数据库查询,获取查询结果记录,以关联数组结构返回 func (tx *Tx) GetOne(query string, args ...interface{}) (Record, error) { list, err := tx.GetAll(query, args ...) if err != nil { return nil, err } - return list[0], nil + if len(list) > 0 { + return list[0], nil + } + return nil, nil } -// (事务)数据库查询,获取查询字段值 +// 数据库查询,获取查询结果记录,自动映射数据到给定的struct对象中 +func (tx *Tx) GetStruct(obj interface{}, query string, args ...interface{}) error { + one, err := tx.GetOne(query, args...) + if err != nil { + return err + } + return one.ToStruct(obj) +} + + +// 数据库查询,获取查询字段值 func (tx *Tx) GetValue(query string, args ...interface{}) (Value, error) { one, err := tx.GetOne(query, args ...) if err != nil { @@ -132,13 +133,31 @@ func (tx *Tx) GetCount(query string, args ...interface{}) (int, error) { return gconv.Int(val), nil } -// (事务)sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作 -// 记得调用sql.Stmt.Close关闭操作对象 -func (tx *Tx) Prepare(query string) (*sql.Stmt, error) { - return tx.Prepare(query) +// 数据表查询,其中tables可以是多个联表查询语句,这种查询方式较复杂,建议使用链式操作 +func (tx *Tx) Select(tables, fields string, condition interface{}, groupBy, orderBy string, first, limit int, args ... interface{}) (Result, error) { + s := fmt.Sprintf("SELECT %s FROM %s ", fields, tables) + if condition != nil { + s += fmt.Sprintf("WHERE %s ", tx.db.formatCondition(condition)) + } + if len(groupBy) > 0 { + s += fmt.Sprintf("GROUP BY %s ", groupBy) + } + if len(orderBy) > 0 { + s += fmt.Sprintf("ORDER BY %s ", orderBy) + } + if limit > 0 { + s += fmt.Sprintf("LIMIT %d,%d ", first, limit) + } + return tx.GetAll(s, args ... ) } -// (事务)insert、replace, save, ignore操作 +// sql预处理,执行完成后调用返回值sql.Stmt.Exec完成sql操作 +// 记得调用sql.Stmt.Close关闭操作对象 +func (tx *Tx) Prepare(query string) (*sql.Stmt, error) { + return tx.tx.Prepare(query) +} + +// insert、replace, save, ignore操作 // 0: insert: 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 // 1: replace: 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 // 2: save: 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 @@ -167,22 +186,22 @@ func (tx *Tx) insert(table string, data Map, option uint8) (sql.Result, error) { ) } -// (事务)CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 +// CURD操作:单条数据写入, 仅仅执行写入操作,如果存在冲突的主键或者唯一索引,那么报错返回 func (tx *Tx) Insert(table string, data Map) (sql.Result, error) { return tx.insert(table, data, OPTION_INSERT) } -// (事务)CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 +// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 func (tx *Tx) Replace(table string, data Map) (sql.Result, error) { return tx.insert(table, data, OPTION_REPLACE) } -// (事务)CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 +// CURD操作:单条数据写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 func (tx *Tx) Save(table string, data Map) (sql.Result, error) { return tx.insert(table, data, OPTION_SAVE) } -// (事务)批量写入数据 +// 批量写入数据 func (tx *Tx) batchInsert(table string, list List, batch int, option uint8) (sql.Result, error) { var keys []string var values []string @@ -217,9 +236,7 @@ func (tx *Tx) batchInsert(table string, list List, batch int, option uint8) (sql } bvalues = append(bvalues, "(" + strings.Join(values, ",") + ")") if len(bvalues) == batch { - r, err := tx.Exec(fmt.Sprintf("%s INTO %s%s%s(%s) VALUES%s %s", - operation, tx.db.charl, table, tx.db.charr, kstr, strings.Join(bvalues, ","), updatestr), - params...) + r, err := tx.Exec(fmt.Sprintf("%s INTO %s%s%s(%s) VALUES%s %s", operation, tx.db.charl, table, tx.db.charr, kstr, strings.Join(bvalues, ","), updatestr), params...) if err != nil { return result, err } @@ -229,9 +246,7 @@ func (tx *Tx) batchInsert(table string, list List, batch int, option uint8) (sql } // 处理最后不构成指定批量的数据 if len(bvalues) > 0 { - r, err := tx.Exec(fmt.Sprintf("%s INTO %s%s%s(%s) VALUES%s %s", - operation, tx.db.charl, table, tx.db.charr, kstr, strings.Join(bvalues, ","), updatestr), - params...) + r, err := tx.Exec(fmt.Sprintf("%s INTO %s%s%s(%s) VALUES%s %s", operation, tx.db.charl, table, tx.db.charr, kstr, strings.Join(bvalues, ","), updatestr), params...) if err != nil { return result, err } @@ -240,53 +255,46 @@ func (tx *Tx) batchInsert(table string, list List, batch int, option uint8) (sql return result, nil } -// (事务)CURD操作:批量数据指定批次量写入 +// CURD操作:批量数据指定批次量写入 func (tx *Tx) BatchInsert(table string, list List, batch int) (sql.Result, error) { return tx.batchInsert(table, list, batch, OPTION_INSERT) } -// (事务)CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 +// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么删除后重新写入一条 func (tx *Tx) BatchReplace(table string, list List, batch int) (sql.Result, error) { return tx.batchInsert(table, list, batch, OPTION_REPLACE) } -// (事务)CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 +// CURD操作:批量数据指定批次量写入, 如果数据存在(主键或者唯一索引),那么更新,否则写入一条新数据 func (tx *Tx) BatchSave(table string, list List, batch int) (sql.Result, error) { return tx.batchInsert(table, list, batch, OPTION_SAVE) } -// (事务)CURD操作:数据更新,统一采用sql预处理 +// CURD操作:数据更新,统一采用sql预处理 // data参数支持字符串或者关联数组类型,内部会自行做判断处理 func (tx *Tx) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { var params []interface{} var updates string - switch value := data.(type) { - case string: - updates = value - case Map: - var keys []string - for k, v := range value { - keys = append(keys, fmt.Sprintf("%s%s%s=?", tx.db.charl, k, tx.db.charr)) - params = append(params, v) - } - updates = strings.Join(keys, ",") - - default: - return nil, errors.New("invalid data type for 'data' field, string or Map expected") + refValue := reflect.ValueOf(data) + if refValue.Kind() == reflect.Map { + var fields []string + keys := refValue.MapKeys() + for _, k := range keys { + fields = append(fields, fmt.Sprintf("%s%s%s=?", tx.db.charl, k, tx.db.charr)) + params = append(params, gconv.String(refValue.MapIndex(k).Interface())) + updates = strings.Join(fields, ",") + } + } else { + updates = gconv.String(data) } for _, v := range args { - if r, ok := v.(string); ok { - params = append(params, r) - } else if r, ok := v.(int); ok { - params = append(params, string(r)) - } else { - - } + params = append(params, gconv.String(v)) } - return tx.Exec(fmt.Sprintf("UPDATE %s%s%s SET %s WHERE %s", tx.db.charl, table, tx.db.charr, updates, condition), params...) + return tx.Exec(fmt.Sprintf("UPDATE %s%s%s SET %s WHERE %s", tx.db.charl, table, tx.db.charr, updates, tx.db.formatCondition(condition)), params...) } -// (事务)CURD操作:删除数据 +// CURD操作:删除数据 func (tx *Tx) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) { - return tx.Exec(fmt.Sprintf("DELETE FROM %s WHERE %s", tx.db.charl, table, tx.db.charr, condition), args...) + return tx.Exec(fmt.Sprintf("DELETE FROM %s%s%s WHERE %s", tx.db.charl, table, tx.db.charr, tx.db.formatCondition(condition)), args...) } + diff --git a/geg/database/mysql/mysql.go b/geg/database/mysql/mysql.go index afff92934..a6ba6c9ef 100644 --- a/geg/database/mysql/mysql.go +++ b/geg/database/mysql/mysql.go @@ -6,7 +6,6 @@ import ( "gitee.com/johng/gf/g/database/gdb" "gitee.com/johng/gf/g" "gitee.com/johng/gf/g/frame/gins" - "gitee.com/johng/gf/g/encoding/gparser" ) // 本文件用于gf框架的mysql数据库操作示例,不作为单元测试使用 @@ -478,10 +477,9 @@ func mapToStruct() { } func main() { - r, _ := db.Table("user").Fields("*").Where("typeid = ?", 1).And("uid=?", 1).Limit(0, 10).Select() - j, _ := gparser.VarToJson(r.ToList()) - fmt.Println(string(j)) - fmt.Println(r) + r, err := db.Table("user").Data(g.Map{"name": "john14"}).Where("uid = ?", 1).Update() + fmt.Println(r.RowsAffected()) + fmt.Println(err) //create() //create() //insert() diff --git a/geg/other/test.go b/geg/other/test.go index 48ef3ddd3..b37cf228b 100644 --- a/geg/other/test.go +++ b/geg/other/test.go @@ -2,17 +2,21 @@ package main import ( "fmt" - "gitee.com/johng/gf/g/util/gvalid" + "reflect" + "gitee.com/johng/gf/g/database/gdb" ) func main() { - data := map[string]interface{} { - "id" : "1", + var value interface{} + value = gdb.Map{"a":1} + + refValue := reflect.ValueOf(value) + + if refValue.Kind() == reflect.Map { + keys := refValue.MapKeys() + for _, k := range keys { + fmt.Println(k, refValue.MapIndex(k).Interface()) + } } - rules := map[string]string { - "id" : "required", - "name" : "length:4,16", - } - m := gvalid.CheckMap(data, rules) - fmt.Println(m) + }