diff --git a/g/database/gdb/gdb_base.go b/g/database/gdb/gdb_base.go index 28a852655..6f8b41232 100644 --- a/g/database/gdb/gdb_base.go +++ b/g/database/gdb/gdb_base.go @@ -16,6 +16,7 @@ import ( "github.com/gogf/gf/g/os/gtime" "github.com/gogf/gf/g/text/gregex" "github.com/gogf/gf/g/util/gconv" + "reflect" "strings" ) @@ -404,15 +405,24 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio params := ([]interface{})(nil) updates := "" charL, charR := bs.db.getChars() - if s, ok := data.(string); ok { - updates = s - } else { - var fields []string - for k, v := range gconv.Map(data) { - fields = append(fields, fmt.Sprintf("%s%s%s=?", charL, k, charR)) - params = append(params, gconv.String(v)) - } - updates = strings.Join(fields, ",") + // 使用反射进行类型判断 + rv := reflect.ValueOf(data) + kind := rv.Kind() + if kind == reflect.Ptr { + rv = rv.Elem() + kind = rv.Kind() + } + switch kind { + case reflect.Map: fallthrough + case reflect.Struct: + var fields []string + for k, v := range gconv.Map(data) { + fields = append(fields, fmt.Sprintf("%s%s%s=?", charL, k, charR)) + params = append(params, gconv.String(v)) + } + updates = strings.Join(fields, ",") + default: + updates = gconv.String(data) } for _, v := range args { params = append(params, gconv.String(v)) diff --git a/g/database/gdb/gdb_func.go b/g/database/gdb/gdb_func.go index fead48dba..7cd203bd5 100644 --- a/g/database/gdb/gdb_func.go +++ b/g/database/gdb/gdb_func.go @@ -24,21 +24,30 @@ import ( func formatCondition(where interface{}, args []interface{}) (string, []interface{}) { // 条件字符串处理 buffer := bytes.NewBuffer(nil) - if s, ok := where.(string); ok { - buffer.WriteString(s) - } else { - for k, v := range gconv.Map(where) { - key := gconv.String(k) - value := gconv.String(v) - if buffer.Len() > 0 { - buffer.WriteString(" AND ") + // 使用反射进行类型判断 + rv := reflect.ValueOf(where) + kind := rv.Kind() + if kind == reflect.Ptr { + rv = rv.Elem() + kind = rv.Kind() + } + switch kind { + case reflect.Map: fallthrough + case reflect.Struct: + for k, v := range gconv.Map(where) { + key := gconv.String(k) + value := gconv.String(v) + if buffer.Len() > 0 { + buffer.WriteString(" AND ") + } + if gstr.IsNumeric(value) || value == "?" { + buffer.WriteString(key + "=" + value) + } else { + buffer.WriteString(key + "='" + value + "'") + } } - if gstr.IsNumeric(value) || value == "?" { - buffer.WriteString(key + "=" + value) - } else { - buffer.WriteString(key + "='" + value + "'") - } - } + default: + buffer.WriteString(gconv.String(where)) } if buffer.Len() == 0 { buffer.WriteString("1=1") diff --git a/g/database/gdb/gdb_unit_0_test.go b/g/database/gdb/gdb_unit_0_test.go index 0300fa8f0..182382170 100644 --- a/g/database/gdb/gdb_unit_0_test.go +++ b/g/database/gdb/gdb_unit_0_test.go @@ -17,7 +17,7 @@ func init() { Host: "127.0.0.1", Port: "3306", User: "root", - Pass: "", + Pass: "12345678", Name: "", Type: "mysql", Role: "master", diff --git a/g/database/gdb/gdb_unit_1_test.go b/g/database/gdb/gdb_unit_1_test.go index e333d0dfc..064faa67b 100644 --- a/g/database/gdb/gdb_unit_1_test.go +++ b/g/database/gdb/gdb_unit_1_test.go @@ -56,6 +56,67 @@ func TestDbBase_Insert(t *testing.T) { }); err != nil { gtest.Fatal(err) } + + result, err := db.Insert("user", map[interface{}]interface{} { + "id" : "2", + "passport" : "t2", + "password" : "25d55ad283aa400af464c76d713c07ad", + "nickname" : "T2", + "create_time" : gtime.Now().String(), + }) + if err != nil { + gtest.Fatal(err) + } + n, _ := result.RowsAffected() + gtest.Assert(n, 1) + + type User struct { + Id int `gconv:"id"` + Passport string `json:"passport"` + Password string `gconv:"password"` + Nickname string `gconv:"nickname"` + CreateTime string `json:"create_time"` + } + result, err = db.Insert("user", User{ + Id : 3, + Uid : 3, + Passport : "t3", + Password : "25d55ad283aa400af464c76d713c07ad", + Nickname : "T3", + CreateTime : gtime.Now().String(), + }) + if err != nil { + gtest.Fatal(err) + } + n, _ = result.RowsAffected() + gtest.Assert(n, 1) + value, err := db.Table("user").Fields("passport").Where("id=3").Value() + gtest.Assert(err, nil) + gtest.Assert(value.String(), "t3") + + result, err = db.Insert("user", &User{ + Id : 4, + Uid : 4, + Passport : "t4", + Password : "25d55ad283aa400af464c76d713c07ad", + Nickname : "T4", + CreateTime : gtime.Now().String(), + }) + if err != nil { + gtest.Fatal(err) + } + n, _ = result.RowsAffected() + gtest.Assert(n, 1) + value, err = db.Table("user").Fields("passport").Where("id=4").Value() + gtest.Assert(err, nil) + gtest.Assert(value.String(), "t4") + + result, err = db.Table("user").Where("id>?", 1).Delete() + if err != nil { + gtest.Fatal(err) + } + n, _ = result.RowsAffected() + gtest.Assert(n, 3) } func TestDbBase_BatchInsert(t *testing.T) {