From 2198f0cefe994dab3dae302bcb2ab3d38b5296e9 Mon Sep 17 00:00:00 2001 From: John Guo Date: Thu, 27 Apr 2023 11:35:46 +0800 Subject: [PATCH] fix issue #2561 #2431 (#2598) --- contrib/drivers/mssql/mssql.go | 2 +- .../mysql/mysql_feature_model_do_test.go | 2 +- contrib/drivers/mysql/mysql_issue_test.go | 62 +++++++++++++ database/gdb/gdb_core.go | 87 ++++++++++++++++++- database/gdb/gdb_core_underlying.go | 3 +- 5 files changed, 149 insertions(+), 7 deletions(-) diff --git a/contrib/drivers/mssql/mssql.go b/contrib/drivers/mssql/mssql.go index 04952f4d2..fb9987e2e 100644 --- a/contrib/drivers/mssql/mssql.go +++ b/contrib/drivers/mssql/mssql.go @@ -20,6 +20,7 @@ import ( "strings" _ "github.com/denisenkom/go-mssqldb" + "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" @@ -285,7 +286,6 @@ ORDER BY a.id,a.colorder`, } fields = make(map[string]*gdb.TableField) for i, m := range result { - fields[m["Field"].String()] = &gdb.TableField{ Index: i, Name: m["Field"].String(), diff --git a/contrib/drivers/mysql/mysql_feature_model_do_test.go b/contrib/drivers/mysql/mysql_feature_model_do_test.go index d96af2d5c..2b367e5cf 100644 --- a/contrib/drivers/mysql/mysql_feature_model_do_test.go +++ b/contrib/drivers/mysql/mysql_feature_model_do_test.go @@ -128,7 +128,7 @@ func Test_Model_Update_Data_DO(t *testing.T) { func Test_Model_Update_Pointer_Data_DO(t *testing.T) { table := createInitTable() defer dropTable(table) - db.SetDebug(true) + gtest.C(t, func(t *gtest.T) { type NN string type Req struct { diff --git a/contrib/drivers/mysql/mysql_issue_test.go b/contrib/drivers/mysql/mysql_issue_test.go index 380ab909f..a75f9fb6b 100644 --- a/contrib/drivers/mysql/mysql_issue_test.go +++ b/contrib/drivers/mysql/mysql_issue_test.go @@ -664,3 +664,65 @@ CREATE TABLE %s ( t.AssertNil(err3) }) } + +// https://github.com/gogf/gf/issues/2561 +func Test_Issue2561(t *testing.T) { + table := createTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + type User struct { + g.Meta `orm:"do:true"` + Id interface{} + Passport interface{} + Password interface{} + Nickname interface{} + CreateTime interface{} + } + data := g.Slice{ + User{ + Id: 1, + Passport: "user_1", + }, + User{ + Id: 2, + Password: "pass_2", + }, + User{ + Id: 3, + Password: "pass_3", + }, + } + result, err := db.Model(table).Data(data).Insert() + t.AssertNil(err) + m, _ := result.LastInsertId() + t.Assert(m, 3) + + n, _ := result.RowsAffected() + t.Assert(n, 3) + + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one[`id`], `1`) + t.Assert(one[`passport`], `user_1`) + t.Assert(one[`password`], ``) + t.Assert(one[`nickname`], ``) + t.Assert(one[`create_time`], ``) + + one, err = db.Model(table).WherePri(2).One() + t.AssertNil(err) + t.Assert(one[`id`], `2`) + t.Assert(one[`passport`], ``) + t.Assert(one[`password`], `pass_2`) + t.Assert(one[`nickname`], ``) + t.Assert(one[`create_time`], ``) + + one, err = db.Model(table).WherePri(3).One() + t.AssertNil(err) + t.Assert(one[`id`], `3`) + t.Assert(one[`passport`], ``) + t.Assert(one[`password`], `pass_3`) + t.Assert(one[`nickname`], ``) + t.Assert(one[`create_time`], ``) + }) +} diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index b86dd9f61..7e2173678 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -14,6 +14,8 @@ import ( "reflect" "strings" + "github.com/gogf/gf/v2/container/gmap" + "github.com/gogf/gf/v2/container/gset" "github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" @@ -389,6 +391,29 @@ func (c *Core) Save(ctx context.Context, table string, data interface{}, batch . return c.Model(table).Ctx(ctx).Data(data).Save() } +func (c *Core) fieldsToSequence(ctx context.Context, table string, fields []string) ([]string, error) { + var ( + fieldSet = gset.NewStrSetFrom(fields) + fieldsResultInSequence = make([]string, 0) + tableFields, err = c.db.TableFields(ctx, table) + ) + if err != nil { + return nil, err + } + // Sort the fields in order. + var fieldsOfTableInSequence = make([]string, len(tableFields)) + for _, field := range tableFields { + fieldsOfTableInSequence[field.Index] = field.Name + } + // Sort the input fields. + for _, fieldName := range fieldsOfTableInSequence { + if fieldSet.Contains(fieldName) { + fieldsResultInSequence = append(fieldsResultInSequence, fieldName) + } + } + return fieldsResultInSequence, nil +} + // DoInsert inserts or updates data forF given table. // This function is usually used for custom interface definition, you do not need call it manually. // The parameter `data` can be type of map/gmap/struct/*struct/[]map/[]struct, etc. @@ -408,9 +433,50 @@ func (c *Core) DoInsert(ctx context.Context, link Link, table string, list List, params []interface{} // Values that will be committed to underlying database driver. onDuplicateStr string // onDuplicateStr is used in "ON DUPLICATE KEY UPDATE" statement. ) - // Handle the field names and placeholders. - for k := range list[0] { - keys = append(keys, k) + // Group the list by fields. Different fields to different list. + // It here uses ListMap to keep sequence for data inserting. + var keyListMap = gmap.NewListMap() + for _, item := range list { + var ( + tmpKeys = make([]string, 0) + tmpKeysInSequenceStr string + ) + for k := range item { + tmpKeys = append(tmpKeys, k) + } + keys, err = c.fieldsToSequence(ctx, table, tmpKeys) + if err != nil { + return nil, err + } + tmpKeysInSequenceStr = gstr.Join(keys, ",") + + if !keyListMap.Contains(tmpKeysInSequenceStr) { + keyListMap.Set(tmpKeysInSequenceStr, make(List, 0)) + } + tmpKeysInSequenceList := keyListMap.Get(tmpKeysInSequenceStr).(List) + tmpKeysInSequenceList = append(tmpKeysInSequenceList, item) + keyListMap.Set(tmpKeysInSequenceStr, tmpKeysInSequenceList) + } + if keyListMap.Size() > 1 { + var ( + tmpResult sql.Result + sqlResult SqlResult + rowsAffected int64 + ) + keyListMap.Iterator(func(key, value interface{}) bool { + tmpResult, err = c.DoInsert(ctx, link, table, value.(List), option) + if err != nil { + return false + } + rowsAffected, err = tmpResult.RowsAffected() + if err != nil { + return false + } + sqlResult.Result = tmpResult + sqlResult.Affected += rowsAffected + return true + }) + return &sqlResult, nil } // Prepare the batch result pointer. var ( @@ -571,7 +637,20 @@ func (c *Core) DoUpdate(ctx context.Context, link Link, table string, data inter if err != nil { return nil, err } - for k, v := range dataMap { + // Sort the data keys in sequence of table fields. + var ( + dataKeys = make([]string, 0) + keysInSequence = make([]string, 0) + ) + for k := range dataMap { + dataKeys = append(dataKeys, k) + } + keysInSequence, err = c.fieldsToSequence(ctx, table, dataKeys) + if err != nil { + return nil, err + } + for _, k := range keysInSequence { + v := dataMap[k] switch value := v.(type) { case *Counter: counterHandler(k, *value) diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index a008ce9dc..1930270d9 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -10,11 +10,12 @@ package gdb import ( "context" "database/sql" - "github.com/gogf/gf/v2/util/gconv" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" "reflect" + "github.com/gogf/gf/v2/util/gconv" + "github.com/gogf/gf/v2" "github.com/gogf/gf/v2/container/gvar" "github.com/gogf/gf/v2/errors/gcode"