From facb9d93c00f0b7bfe63c8fd6f1d3cc1c8b34d23 Mon Sep 17 00:00:00 2001 From: John Date: Wed, 3 Jun 2020 21:36:16 +0800 Subject: [PATCH] fix issue of multiple slice arguments handling in function where --- database/gdb/gdb_func.go | 22 +++++++++--- database/gdb/gdb_unit_z_mysql_model_test.go | 37 +++++++++++++++------ 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index 7b8b52b09..559baa9d3 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -411,11 +411,15 @@ func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key // underlying driver. func handleArguments(sql string, args []interface{}) (newSql string, newArgs []interface{}) { newSql = sql + // insertHolderCount is used to calculate the inserting position for the '?' holder. + insertHolderCount := 0 // Handles the slice arguments. if len(args) > 0 { for index, arg := range args { - rv := reflect.ValueOf(arg) - kind := rv.Kind() + var ( + rv = reflect.ValueOf(arg) + kind = rv.Kind() + ) if kind == reflect.Ptr { rv = rv.Elem() kind = rv.Kind() @@ -431,17 +435,25 @@ func handleArguments(sql string, args []interface{}) (newSql string, newArgs []i for i := 0; i < rv.Len(); i++ { newArgs = append(newArgs, rv.Index(i).Interface()) } - // It the '?' holder count equals the length of the slice, + // If the '?' holder count equals the length of the slice, // it does not implement the arguments splitting logic. // Eg: db.Query("SELECT ?+?", g.Slice{1, 2}) if len(args) == 1 && gstr.Count(newSql, "?") == rv.Len() { break } // counter is used to finding the inserting position for the '?' holder. - counter := 0 + var ( + counter = 0 + replaced = false + ) newSql, _ = gregex.ReplaceStringFunc(`\?`, newSql, func(s string) string { + if replaced { + return s + } counter++ - if counter == index+1 { + if counter == index+insertHolderCount+1 { + replaced = true + insertHolderCount += rv.Len() - 1 return "?" + strings.Repeat(",?", rv.Len()-1) } return s diff --git a/database/gdb/gdb_unit_z_mysql_model_test.go b/database/gdb/gdb_unit_z_mysql_model_test.go index ddc277027..e1d0df458 100644 --- a/database/gdb/gdb_unit_z_mysql_model_test.go +++ b/database/gdb/gdb_unit_z_mysql_model_test.go @@ -1023,15 +1023,6 @@ func Test_Model_Where(t *testing.T) { t.AssertGT(len(result), 0) t.Assert(result["id"].Int(), 3) }) - gtest.C(t, func(t *gtest.T) { - result, err := db.Table(table).Where(g.Map{ - "id": g.Slice{1, 2, 3}, - "passport": g.Slice{"user_2", "user_3"}, - }).Or("nickname=?", g.Slice{"name_4"}).And("id", 3).One() - t.Assert(err, nil) - t.AssertGT(len(result), 0) - t.Assert(result["id"].Int(), 3) - }) gtest.C(t, func(t *gtest.T) { result, err := db.Table(table).Where("id=3", g.Slice{}).One() t.Assert(err, nil) @@ -1343,7 +1334,7 @@ func Test_Model_WherePri(t *testing.T) { }).Or("nickname=?", g.Slice{"name_4"}).And("id", 3).One() t.Assert(err, nil) t.AssertGT(len(result), 0) - t.Assert(result["id"].Int(), 3) + t.Assert(result["id"].Int(), 2) }) gtest.C(t, func(t *gtest.T) { result, err := db.Table(table).WherePri("id=3", g.Slice{}).One() @@ -1833,6 +1824,32 @@ func Test_Model_Option_Where(t *testing.T) { }) } +func Test_Model_Where_MultiSliceArguments(t *testing.T) { + table := createInitTable() + defer dropTable(table) + gtest.C(t, func(t *gtest.T) { + r, err := db.Table(table).Where(g.Map{ + "id": g.Slice{1, 2, 3, 4}, + "passport": g.Slice{"user_2", "user_3", "user_4"}, + "nickname": g.Slice{"name_2", "name_4"}, + "id >= 4": nil, + }).All() + t.Assert(err, nil) + t.Assert(len(r), 1) + t.Assert(r[0]["id"], 4) + }) + + gtest.C(t, func(t *gtest.T) { + result, err := db.Table(table).Where(g.Map{ + "id": g.Slice{1, 2, 3}, + "passport": g.Slice{"user_2", "user_3"}, + }).Or("nickname=?", g.Slice{"name_4"}).And("id", 3).One() + t.Assert(err, nil) + t.AssertGT(len(result), 0) + t.Assert(result["id"].Int(), 2) + }) +} + func Test_Model_FieldsEx(t *testing.T) { table := createInitTable() defer dropTable(table)