From 852c3dda623a4ef0047530271df4eed7c34bd679 Mon Sep 17 00:00:00 2001 From: John Guo Date: Tue, 9 Dec 2025 15:46:41 +0800 Subject: [PATCH] feat(contrib/drivers/dm&pgsql&mssql&oracle): add Replace/LastInsertId features support for dm/pgsql/mssql/oracle (#4547) This pull request introduces significant improvements to the handling of the `Replace` and `Save` operations for multiple database drivers, especially for MSSQL and PostgreSQL. The changes ensure that these operations now auto-detect primary keys when conflict columns are not explicitly provided, improving usability and aligning behavior across drivers. Additionally, the pull request updates related tests to reflect these enhancements and includes some minor documentation and code cleanup. **Key changes:** ### Enhanced Replace/Save Logic for Database Drivers * **MSSQL Driver:** - `Replace` and `Save` operations now auto-detect primary keys if `OnConflict` is not specified, using the `MERGE` statement for upsert functionality. If no primary key is found in the data, a detailed error is returned. [[1]](diffhunk://#diff-87815aa559a927e2de09bd05148f9841dfc06a1b5f3ecc5e3d5fcb80323a87f8L23-R61) [[2]](diffhunk://#diff-87815aa559a927e2de09bd05148f9841dfc06a1b5f3ecc5e3d5fcb80323a87f8L43-L59) - Updated tests to verify that `Replace` correctly updates or inserts records, and that missing conflict columns are properly handled. [[1]](diffhunk://#diff-bdbde9d7d6ee14c795343767b414740c4396f4dd3e97788b1f9d4e615405a42dL141-R151) [[2]](diffhunk://#diff-26338e93e473300b1313936eb0f6826546473793442f24715fa294b595f7a805L2661-R2707) * **PostgreSQL Driver:** - Similar to MSSQL, `Replace` and `Save` now auto-detect primary keys for conflict resolution if `OnConflict` is not set, and treat `Replace` as a `Save` operation. - Adjusted tests to ensure `Save` and `Replace` work as expected, including verifying data replacement and insertion. [[1]](diffhunk://#diff-c22703c37ebb6836c332f7cd2ada570577ba4564fe39886db02f7c2d0e7a2048L93-R93) [[2]](diffhunk://#diff-c22703c37ebb6836c332f7cd2ada570577ba4564fe39886db02f7c2d0e7a2048R102) [[3]](diffhunk://#diff-c22703c37ebb6836c332f7cd2ada570577ba4564fe39886db02f7c2d0e7a2048L110-R130) * **DM Driver:** - Improved conflict detection: now checks that at least one primary key exists in the provided data when `OnConflict` is not specified, and provides clearer error messages. - Refactored to use the core method for primary key detection and removed redundant code. ### Minor Improvements and Documentation * Added clarifying comments to `DoInsert` methods for ClickHouse, DM, MSSQL, Oracle, and PostgreSQL drivers, specifying that the input list must have at least one validated record. [[1]](diffhunk://#diff-f2e003895041ed3c52b91bb8c270696adc3528d77c39d2f7137af3396267444cR19) [[2]](diffhunk://#diff-f51b30e3f0b0f1284b905385a89992efd0de2fe9ff8c5a4062344dfab17d428eR23) [[3]](diffhunk://#diff-87815aa559a927e2de09bd05148f9841dfc06a1b5f3ecc5e3d5fcb80323a87f8L23-R61) [[4]](diffhunk://#diff-f61dac3fcfd5df4a3936cd8743499c8c0fc45f4f5d0f5398ed84a0cb1603202cR24) [[5]](diffhunk://#diff-c1dfed79aaa3a432057d2bd74d270e4b4094ebcf72984f1161d4972bea009410R16-R72) * Minor code and comment cleanups, including improved formatting and error handling. [[1]](diffhunk://#diff-f61dac3fcfd5df4a3936cd8743499c8c0fc45f4f5d0f5398ed84a0cb1603202cR37) [[2]](diffhunk://#diff-f61dac3fcfd5df4a3936cd8743499c8c0fc45f4f5d0f5398ed84a0cb1603202cL96-R98) [[3]](diffhunk://#diff-f61dac3fcfd5df4a3936cd8743499c8c0fc45f4f5d0f5398ed84a0cb1603202cL106-L116) [[4]](diffhunk://#diff-a17b44c76aaac53d1f164a2bb9440a5531659f4355e7ccfabdadff8dc8633c09L170-R171) [[5]](diffhunk://#diff-56189fa9ae1df51716b50d34d7fe56bfe67a330e8ac2c6b0de7b958db6817ed5R83-R98) ### Workflow and Documentation Updates * Updated example Docker commands in the CI workflow for consistency and clarity. [[1]](diffhunk://#diff-a1a3cb9bdeb5541d148091d973cf266aa3b317e6415a86630e816cbe27cf8b9cL57-R57) [[2]](diffhunk://#diff-a1a3cb9bdeb5541d148091d973cf266aa3b317e6415a86630e816cbe27cf8b9cL78-R78) [[3]](diffhunk://#diff-a1a3cb9bdeb5541d148091d973cf266aa3b317e6415a86630e816cbe27cf8b9cL92-R92) [[4]](diffhunk://#diff-a1a3cb9bdeb5541d148091d973cf266aa3b317e6415a86630e816cbe27cf8b9cL106-R106) [[5]](diffhunk://#diff-a1a3cb9bdeb5541d148091d973cf266aa3b317e6415a86630e816cbe27cf8b9cL153-R153) [[6]](diffhunk://#diff-a1a3cb9bdeb5541d148091d973cf266aa3b317e6415a86630e816cbe27cf8b9cL164-R164) * Removed outdated note about `Replace` support from the SQLite driver documentation. These changes improve the consistency, reliability, and developer experience when performing upsert operations across different database backends. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Lance Add <1196661499@qq.com> --- .github/workflows/ci-main.yml | 12 +- cmd/gf/internal/cmd/gendao/gendao.go | 4 + contrib/drivers/README.MD | 21 +- .../clickhouse/clickhouse_do_insert.go | 1 + contrib/drivers/dm/dm_do_insert.go | 72 +++--- contrib/drivers/dm/dm_z_unit_basic_test.go | 122 ----------- contrib/drivers/dm/dm_z_unit_init_test.go | 30 +++ contrib/drivers/dm/dm_z_unit_model_test.go | 185 ++++++++++++++++ contrib/drivers/mssql/mssql.go | 6 +- contrib/drivers/mssql/mssql_do_exec.go | 51 ++--- contrib/drivers/mssql/mssql_do_insert.go | 151 ++++++++----- contrib/drivers/mssql/mssql_result.go | 22 ++ .../drivers/mssql/mssql_z_unit_basic_test.go | 4 +- .../drivers/mssql/mssql_z_unit_model_test.go | 85 ++++++- contrib/drivers/oracle/oracle.go | 4 - contrib/drivers/oracle/oracle_do_exec.go | 120 ++++++++++ contrib/drivers/oracle/oracle_do_insert.go | 207 ++++++++++++------ contrib/drivers/oracle/oracle_result.go | 24 ++ contrib/drivers/oracle/oracle_table_fields.go | 27 ++- .../oracle/oracle_z_unit_basic_test.go | 7 +- .../drivers/oracle/oracle_z_unit_init_test.go | 60 ++++- .../oracle/oracle_z_unit_model_test.go | 124 ++++++++++- contrib/drivers/pgsql/pgsql.go | 4 - contrib/drivers/pgsql/pgsql_do_insert.go | 60 ++++- contrib/drivers/pgsql/pgsql_table_fields.go | 14 +- contrib/drivers/pgsql/pgsql_z_unit_db_test.go | 37 +++- .../drivers/pgsql/pgsql_z_unit_field_test.go | 16 +- .../drivers/pgsql/pgsql_z_unit_model_test.go | 109 ++++++++- .../drivers/pgsql/pgsql_z_unit_upsert_test.go | 6 +- ...t_upsert.go => sqlitecgo_format_upsert.go} | 0 database/gdb/gdb_core_utility.go | 20 ++ database/gdb/gdb_driver_wrapper_db.go | 12 +- 32 files changed, 1224 insertions(+), 393 deletions(-) create mode 100644 contrib/drivers/dm/dm_z_unit_model_test.go create mode 100644 contrib/drivers/mssql/mssql_result.go create mode 100644 contrib/drivers/oracle/oracle_do_exec.go create mode 100644 contrib/drivers/oracle/oracle_result.go rename contrib/drivers/sqlitecgo/{sqlite_format_upsert.go => sqlitecgo_format_upsert.go} (100%) diff --git a/.github/workflows/ci-main.yml b/.github/workflows/ci-main.yml index e3493de60..b468c81c2 100644 --- a/.github/workflows/ci-main.yml +++ b/.github/workflows/ci-main.yml @@ -54,7 +54,7 @@ jobs: # Service containers to run with `code-test` services: # Etcd service. - # docker run -d --name etcd -p 2379:2379 -e ALLOW_NONE_AUTHENTICATION=yes bitnamilegacy/etcd:3.4.24 + # docker run -p 2379:2379 -e ALLOW_NONE_AUTHENTICATION=yes bitnamilegacy/etcd:3.4.24 etcd: image: bitnamilegacy/etcd:3.4.24 env: @@ -75,7 +75,7 @@ jobs: - 6379:6379 # MySQL backend server. - # docker run -d --name mysql \ + # docker run \ # -p 3306:3306 \ # -e MYSQL_DATABASE=test \ # -e MYSQL_ROOT_PASSWORD=12345678 \ @@ -89,7 +89,7 @@ jobs: - 3306:3306 # MariaDb backend server. - # docker run -d --name mariadb \ + # docker run \ # -p 3307:3306 \ # -e MYSQL_DATABASE=test \ # -e MYSQL_ROOT_PASSWORD=12345678 \ @@ -103,7 +103,7 @@ jobs: - 3307:3306 # PostgreSQL backend server. - # docker run -d --name postgres \ + # docker run \ # -p 5432:5432 \ # -e POSTGRES_PASSWORD=12345678 \ # -e POSTGRES_USER=postgres \ @@ -150,7 +150,7 @@ jobs: --health-retries 10 # ClickHouse backend server. - # docker run -d --name clickhouse \ + # docker run \ # -p 9000:9000 -p 8123:8123 -p 9001:9001 \ # clickhouse/clickhouse-server:24.11.1.2557-alpine clickhouse-server: @@ -161,7 +161,7 @@ jobs: - 9001:9001 # Polaris backend server. - # docker run -d --name polaris \ + # docker run \ # -p 8090:8090 -p 8091:8091 -p 8093:8093 -p 9090:9090 -p 9091:9091 \ # polarismesh/polaris-standalone:v1.17.2 polaris: diff --git a/cmd/gf/internal/cmd/gendao/gendao.go b/cmd/gf/internal/cmd/gendao/gendao.go index 54204755e..bc6ad943a 100644 --- a/cmd/gf/internal/cmd/gendao/gendao.go +++ b/cmd/gf/internal/cmd/gendao/gendao.go @@ -104,6 +104,10 @@ var ( "smallmoney": { Type: "float64", }, + "uuid": { + Type: "uuid.UUID", + Import: "github.com/google/uuid", + }, } // tablewriter Options diff --git a/contrib/drivers/README.MD b/contrib/drivers/README.MD index c10d58d66..1adc943c0 100644 --- a/contrib/drivers/README.MD +++ b/contrib/drivers/README.MD @@ -9,7 +9,7 @@ Let's take `mysql` for example. ```shell go get github.com/gogf/gf/contrib/drivers/mysql/v2@latest -# Easy to copy +# Easy for copying: go get github.com/gogf/gf/contrib/drivers/clickhouse/v2@latest go get github.com/gogf/gf/contrib/drivers/dm/v2@latest go get github.com/gogf/gf/contrib/drivers/mssql/v2@latest @@ -57,7 +57,7 @@ import _ "github.com/gogf/gf/contrib/drivers/sqlite/v2" #### cgo version -When the target is a 32-bit Windows system, the cgo version needs to be used. +When the target is a `32-bit` Windows system, the `cgo` version needs to be used. ```go import _ "github.com/gogf/gf/contrib/drivers/sqlitecgo/v2" @@ -69,10 +69,6 @@ import _ "github.com/gogf/gf/contrib/drivers/sqlitecgo/v2" import _ "github.com/gogf/gf/contrib/drivers/pgsql/v2" ``` -Note: - -- It does not support `Replace` features. - ### SQL Server ```go @@ -81,9 +77,10 @@ import _ "github.com/gogf/gf/contrib/drivers/mssql/v2" Note: -- It does not support `Replace` features. +- `InsertIgnore` returns error if there is no primary key or unique index submitted with record. - It supports server version >= `SQL Server2005` -- It ONLY supports datetime2 and datetimeoffset types for auto handling created_at/updated_at/deleted_at columns, because datetime type does not support microseconds precision when column value is passed as string. +- It ONLY supports `datetime2` and `datetimeoffset` types for auto handling created_at/updated_at/deleted_at columns, + because datetime type does not support microseconds precision when column value is passed as string. ### Oracle @@ -93,8 +90,8 @@ import _ "github.com/gogf/gf/contrib/drivers/oracle/v2" Note: -- It does not support `Replace` features. - It does not support `LastInsertId`. +- `InsertIgnore` returns error if there is no primary key or unique index submitted with record. ### ClickHouse @@ -104,7 +101,7 @@ import _ "github.com/gogf/gf/contrib/drivers/clickhouse/v2" Note: -- It does not support `InsertIgnore/InsertGetId` features. +- It does not support `InsertIgnore/InsertAndGetId` features. - It does not support `Save/Replace` features. - It does not support `Transaction` feature. - It does not support `RowsAffected` feature. @@ -115,6 +112,10 @@ Note: import _ "github.com/gogf/gf/contrib/drivers/dm/v2" ``` +Note: + +- `InsertIgnore` returns error if there is no primary key or unique index submitted with record. + ## Custom Drivers It's quick and easy, please refer to current driver source. diff --git a/contrib/drivers/clickhouse/clickhouse_do_insert.go b/contrib/drivers/clickhouse/clickhouse_do_insert.go index a6c397ae3..a71276913 100644 --- a/contrib/drivers/clickhouse/clickhouse_do_insert.go +++ b/contrib/drivers/clickhouse/clickhouse_do_insert.go @@ -16,6 +16,7 @@ import ( ) // DoInsert inserts or updates data for given table. +// The list parameter must contain at least one record, which was previously validated. func (d *Driver) DoInsert( ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, ) (result sql.Result, err error) { diff --git a/contrib/drivers/dm/dm_do_insert.go b/contrib/drivers/dm/dm_do_insert.go index 538340ffa..72fc540b4 100644 --- a/contrib/drivers/dm/dm_do_insert.go +++ b/contrib/drivers/dm/dm_do_insert.go @@ -20,6 +20,7 @@ import ( ) // DoInsert inserts or updates data for given table. +// The list parameter must contain at least one record, which was previously validated. func (d *Driver) DoInsert( ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, ) (result sql.Result, err error) { @@ -36,6 +37,12 @@ func (d *Driver) DoInsert( return d.doInsertIgnore(ctx, link, table, list, option) default: + // DM database supports IDENTITY auto-increment columns natively. + // The driver automatically returns LastInsertId through sql.Result. + // + // Note: DM IDENTITY columns cannot accept explicit ID values unless + // IDENTITY_INSERT is enabled. When using tables with IDENTITY columns, + // avoid providing explicit ID values in the data. return d.Core.DoInsert(ctx, link, table, list, option) } } @@ -60,16 +67,12 @@ func (d *Driver) doInsertIgnore(ctx context.Context, // When withUpdate is false, it performs insert ignore (insert only when no conflict). func (d *Driver) doMergeInsert( ctx context.Context, - link gdb.Link, - table string, - list gdb.List, - option gdb.DoInsertOption, - withUpdate bool, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, withUpdate bool, ) (result sql.Result, err error) { // If OnConflict is not specified, automatically get the primary key of the table conflictKeys := option.OnConflict if len(conflictKeys) == 0 { - conflictKeys, err = d.getPrimaryKeys(ctx, table) + primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table) if err != nil { return nil, gerror.WrapCode( gcode.CodeInternalError, @@ -77,29 +80,34 @@ func (d *Driver) doMergeInsert( `failed to get primary keys for table`, ) } - if len(conflictKeys) == 0 { - return nil, gerror.NewCode( + foundPrimaryKey := false + for _, primaryKey := range primaryKeys { + for dataKey := range list[0] { + if strings.EqualFold(dataKey, primaryKey) { + foundPrimaryKey = true + break + } + } + if foundPrimaryKey { + break + } + } + if !foundPrimaryKey { + return nil, gerror.NewCodef( gcode.CodeMissingParameter, - `Please specify conflict columns or ensure the table has a primary key`, + `Replace/Save/InsertIgnore operation requires conflict detection: `+ + `either specify OnConflict() columns or ensure table '%s' has a primary key in the data`, + table, ) } - } - - if len(list) == 0 { - opName := "Save" - if !withUpdate { - opName = "InsertIgnore" - } - return nil, gerror.NewCodef( - gcode.CodeInvalidRequest, `%s operation list is empty by dm driver`, opName, - ) + // TODO consider composite primary keys. + conflictKeys = primaryKeys } var ( - one = list[0] - oneLen = len(one) - charL, charR = d.GetChars() - + one = list[0] + oneLen = len(one) + charL, charR = d.GetChars() conflictKeySet = gset.New(false) // queryHolders: Handle data with Holder that need to be merged @@ -155,24 +163,6 @@ func (d *Driver) doMergeInsert( return batchResult, nil } -// getPrimaryKeys retrieves the primary key field names of the table as a slice of strings. -// This method extracts primary key information from TableFields. -func (d *Driver) getPrimaryKeys(ctx context.Context, table string) ([]string, error) { - tableFields, err := d.TableFields(ctx, table) - if err != nil { - return nil, err - } - - var primaryKeys []string - for _, field := range tableFields { - if field.Key == "PRI" { - primaryKeys = append(primaryKeys, field.Name) - } - } - - return primaryKeys, nil -} - // parseSqlForMerge generates MERGE statement for DM database. // When updateValues is empty, it only inserts (INSERT IGNORE behavior). // When updateValues is provided, it performs upsert (INSERT or UPDATE). diff --git a/contrib/drivers/dm/dm_z_unit_basic_test.go b/contrib/drivers/dm/dm_z_unit_basic_test.go index da9ab215b..be8a53945 100644 --- a/contrib/drivers/dm/dm_z_unit_basic_test.go +++ b/contrib/drivers/dm/dm_z_unit_basic_test.go @@ -7,7 +7,6 @@ package dm_test import ( - "database/sql" "fmt" "strings" "testing" @@ -509,124 +508,3 @@ func Test_Empty_Slice_Argument(t *testing.T) { t.Assert(len(result), 0) }) } - -func TestModelSave(t *testing.T) { - table := createTable() - defer dropTable(table) - gtest.C(t, func(t *gtest.T) { - type User struct { - Id int - AccountName string - AttrIndex int - } - var ( - user User - count int - result sql.Result - err error - ) - - result, err = db.Model(table).Data(g.Map{ - "id": 1, - "accountName": "ac1", - "attrIndex": 100, - }).OnConflict("id").Save() - - t.AssertNil(err) - n, _ := result.RowsAffected() - t.Assert(n, 1) - - err = db.Model(table).Scan(&user) - t.AssertNil(err) - t.Assert(user.Id, 1) - t.Assert(user.AccountName, "ac1") - t.Assert(user.AttrIndex, 100) - - _, err = db.Model(table).Data(g.Map{ - "id": 1, - "accountName": "ac2", - "attrIndex": 200, - }).OnConflict("id").Save() - t.AssertNil(err) - - err = db.Model(table).Scan(&user) - t.AssertNil(err) - t.Assert(user.AccountName, "ac2") - t.Assert(user.AttrIndex, 200) - - count, err = db.Model(table).Count() - t.AssertNil(err) - t.Assert(count, 1) - }) -} - -func TestModelInsert(t *testing.T) { - // g.Model.insert not lost default not null column - table := "A_tables" - createInitTable(table) - gtest.C(t, func(t *gtest.T) { - i := 200 - data := User{ - ID: int64(i), - AccountName: fmt.Sprintf(`A%dtwo`, i), - PwdReset: 0, - AttrIndex: 99, - CreatedTime: time.Now(), - UpdatedTime: time.Now(), - } - // _, err := db.Schema(TestDBName).Model(table).Data(data).Insert() - _, err := db.Model(table).Insert(&data) - gtest.AssertNil(err) - }) - - gtest.C(t, func(t *gtest.T) { - i := 201 - data := User{ - ID: int64(i), - AccountName: fmt.Sprintf(`A%dtwoONE`, i), - PwdReset: 1, - CreatedTime: time.Now(), - AttrIndex: 98, - UpdatedTime: time.Now(), - } - // _, err := db.Schema(TestDBName).Model(table).Data(data).Insert() - _, err := db.Model(table).Data(&data).Insert() - gtest.AssertNil(err) - }) -} - -func Test_Model_InsertIgnore(t *testing.T) { - table := createInitTable() - defer dropTable(table) - - // db.SetDebug(true) - - gtest.C(t, func(t *gtest.T) { - data := User{ - ID: int64(666), - AccountName: fmt.Sprintf(`name_%d`, 666), - PwdReset: 0, - AttrIndex: 99, - CreatedTime: time.Now(), - UpdatedTime: time.Now(), - } - _, err := db.Model(table).Data(data).Insert() - t.AssertNil(err) - }) - gtest.C(t, func(t *gtest.T) { - data := User{ - ID: int64(666), - AccountName: fmt.Sprintf(`name_%d`, 777), - PwdReset: 0, - AttrIndex: 99, - CreatedTime: time.Now(), - UpdatedTime: time.Now(), - } - _, err := db.Model(table).Data(data).InsertIgnore() - t.AssertNil(err) - - one, err := db.Model(table).Where("id", 666).One() - t.AssertNil(err) - t.Assert(one["ACCOUNT_NAME"].String(), "name_666") - }) -} diff --git a/contrib/drivers/dm/dm_z_unit_init_test.go b/contrib/drivers/dm/dm_z_unit_init_test.go index 100329a70..30c8aca85 100644 --- a/contrib/drivers/dm/dm_z_unit_init_test.go +++ b/contrib/drivers/dm/dm_z_unit_init_test.go @@ -220,3 +220,33 @@ func createInitTables(len int) []string { } return tables } + +// createTableWithIdentity creates a table with IDENTITY column for LastInsertId testing +func createTableWithIdentity(table ...string) (name string) { + if len(table) > 0 { + name = table[0] + } else { + name = fmt.Sprintf("random_%d", gtime.Timestamp()) + } + + dropTable(name) + + if _, err := db.Exec(ctx, fmt.Sprintf(` + CREATE TABLE "%s" +( +"ID" BIGINT IDENTITY(1, 1) NOT NULL, +"ACCOUNT_NAME" VARCHAR(128) DEFAULT '' NOT NULL COMMENT 'Account Name', +"PWD_RESET" TINYINT DEFAULT 0 NOT NULL, +"ENABLED" INT DEFAULT 1 NOT NULL, +"DELETED" INT DEFAULT 0 NOT NULL, +"ATTR_INDEX" INT DEFAULT 0 , +"CREATED_BY" VARCHAR(32) DEFAULT '' NOT NULL, +"CREATED_TIME" TIMESTAMP(6) DEFAULT CURRENT_TIMESTAMP() NOT NULL, +"UPDATED_BY" VARCHAR(32) DEFAULT '' NOT NULL, +"UPDATED_TIME" TIMESTAMP(6) DEFAULT CURRENT_TIMESTAMP() NOT NULL, +NOT CLUSTER PRIMARY KEY("ID")) STORAGE(ON "MAIN", CLUSTERBTR) ; + `, name)); err != nil { + gtest.Fatal(err) + } + return +} diff --git a/contrib/drivers/dm/dm_z_unit_model_test.go b/contrib/drivers/dm/dm_z_unit_model_test.go new file mode 100644 index 000000000..80f04010b --- /dev/null +++ b/contrib/drivers/dm/dm_z_unit_model_test.go @@ -0,0 +1,185 @@ +// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package dm_test + +import ( + "database/sql" + "fmt" + "testing" + "time" + + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/os/gtime" + "github.com/gogf/gf/v2/test/gtest" +) + +func Test_Model_Save(t *testing.T) { + table := createTableWithIdentity() + defer dropTable(table) + gtest.C(t, func(t *gtest.T) { + type User struct { + Id int + AccountName string + AttrIndex int + } + var ( + user User + count int + result sql.Result + err error + ) + + // First insert: let IDENTITY auto-generate ID - use Insert() instead of Save() + // because Save() requires a primary key in the data for conflict detection + result, err = db.Model(table).Data(g.Map{ + "accountName": "ac1", + "attrIndex": 100, + }).Insert() + + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + err = db.Model(table).Scan(&user) + t.AssertNil(err) + t.AssertGT(user.Id, 0) // ID should be auto-generated + t.Assert(user.AccountName, "ac1") + t.Assert(user.AttrIndex, 100) + + // Second save: update the existing record using the generated ID + _, err = db.Model(table).Data(g.Map{ + "id": user.Id, + "accountName": "ac2", + "attrIndex": 200, + }).OnConflict("id").Save() + t.AssertNil(err) + + err = db.Model(table).Scan(&user) + t.AssertNil(err) + t.Assert(user.AccountName, "ac2") + t.Assert(user.AttrIndex, 200) + + _, err = db.Model(table).Data(g.Map{ + "id": user.Id, + "accountName": "ac2", + "attrIndex": 2000, + }).Save() + t.AssertNil(err) + + err = db.Model(table).Scan(&user) + t.AssertNil(err) + t.Assert(user.AccountName, "ac2") + t.Assert(user.AttrIndex, 2000) + + count, err = db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, 1) + }) +} + +func Test_Model_Insert(t *testing.T) { + // g.Model.insert not lost default not null column + table := "A_tables" + createInitTable(table) + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + i := 200 + data := User{ + ID: int64(i), + AccountName: fmt.Sprintf(`A%dtwo`, i), + PwdReset: 0, + AttrIndex: 99, + CreatedTime: time.Now(), + UpdatedTime: time.Now(), + } + result, err := db.Model(table).Insert(&data) + gtest.AssertNil(err) + n, err := result.RowsAffected() + gtest.AssertNil(err) + gtest.Assert(n, 1) + }) + + gtest.C(t, func(t *gtest.T) { + i := 201 + data := User{ + ID: int64(i), + AccountName: fmt.Sprintf(`A%dtwoONE`, i), + PwdReset: 1, + CreatedTime: time.Now(), + AttrIndex: 98, + UpdatedTime: time.Now(), + } + result, err := db.Model(table).Data(&data).Insert() + gtest.AssertNil(err) + n, err := result.RowsAffected() + gtest.AssertNil(err) + gtest.Assert(n, 1) + }) +} + +func Test_Model_InsertIgnore(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + // db.SetDebug(true) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "account_name": fmt.Sprintf(`name_%d`, 777), + "pwd_reset": 0, + "attr_index": 777, + "created_time": gtime.Now(), + } + _, err := db.Model(table).Data(data).InsertIgnore() + t.AssertNil(err) + + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["ACCOUNT_NAME"].String(), "name_1") + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + // "id": 1, + "account_name": fmt.Sprintf(`name_%d`, 777), + "pwd_reset": 0, + "attr_index": 777, + "created_time": gtime.Now(), + } + _, err := db.Model(table).Data(data).InsertIgnore() + t.AssertNE(err, nil) + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) +} + +func Test_Model_InsertAndGetId(t *testing.T) { + table := createTableWithIdentity() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + // "id": 1, + "account_name": fmt.Sprintf(`name_%d`, 1), + "pwd_reset": 0, + "attr_index": 1, + "created_time": gtime.Now(), + } + lastId, err := db.Model(table).Data(data).InsertAndGetId() + t.AssertNil(err) + t.AssertGT(lastId, 0) + }) + +} diff --git a/contrib/drivers/mssql/mssql.go b/contrib/drivers/mssql/mssql.go index 5be217ce0..a8f443e7d 100644 --- a/contrib/drivers/mssql/mssql.go +++ b/contrib/drivers/mssql/mssql.go @@ -4,11 +4,7 @@ // If a copy of the MIT was not distributed with this file, // You can obtain one at https://github.com/gogf/gf. -// Package mssql implements gdb.Driver, which supports operations for database MSSql. -// -// Note: -// 1. It does not support Replace features. -// 2. It does not support LastInsertId. +// Package mssql implements gdb.Driver, which supports operations for MSSQL. package mssql import ( diff --git a/contrib/drivers/mssql/mssql_do_exec.go b/contrib/drivers/mssql/mssql_do_exec.go index b8b014244..8cf90c484 100644 --- a/contrib/drivers/mssql/mssql_do_exec.go +++ b/contrib/drivers/mssql/mssql_do_exec.go @@ -1,3 +1,9 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + package mssql import ( @@ -87,12 +93,16 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sqlStr string, args IsTransaction: link.IsTransaction(), }) if err != nil { - return &InsertResult{lastInsertId: 0, rowsAffected: 0, err: err}, err + return &Result{lastInsertId: 0, rowsAffected: 0, err: err}, err } stdSqlResult := out.Records if len(stdSqlResult) == 0 { - err = gerror.WrapCode(gcode.CodeDbOperationError, gerror.New("affectcount is zero"), `sql.Result.RowsAffected failed`) - return &InsertResult{lastInsertId: 0, rowsAffected: 0, err: err}, err + err = gerror.WrapCode( + gcode.CodeDbOperationError, + gerror.New("affected count is zero"), + `sql.Result.RowsAffected failed`, + ) + return &Result{lastInsertId: 0, rowsAffected: 0, err: err}, err } // For batch insert, OUTPUT clause returns one row per inserted row. // So the rowsAffected should be the count of returned records. @@ -100,7 +110,7 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sqlStr string, args // get last_insert_id from the first returned row lastInsertId := stdSqlResult[0].GMap().GetVar(lastInsertIdFieldAlias).Int64() - return &InsertResult{lastInsertId: lastInsertId, rowsAffected: rowsAffected}, err + return &Result{lastInsertId: lastInsertId, rowsAffected: rowsAffected}, err } // GetTableNameFromSql get table name from sql statement @@ -111,17 +121,19 @@ func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sqlStr string, args // "user as u". func (d *Driver) GetTableNameFromSql(sqlStr string) (table string) { // INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?) - leftChars, rightChars := d.GetChars() - trimStr := leftChars + rightChars + "[] " - pattern := "INTO(.+?)\\(" - regCompile := regexp.MustCompile(pattern) - tableInfo := regCompile.FindStringSubmatch(sqlStr) + var ( + leftChars, rightChars = d.GetChars() + trimStr = leftChars + rightChars + "[] " + pattern = "INTO(.+?)\\(" + regCompile = regexp.MustCompile(pattern) + tableInfo = regCompile.FindStringSubmatch(sqlStr) + ) // get the first one. after the first it may be content of the value, it's not table name. table = tableInfo[1] table = strings.Trim(table, " ") if strings.Contains(table, ".") { tmpAry := strings.Split(table, ".") - // the last one is tablename + // the last one is table name table = tmpAry[len(tmpAry)-1] } else if strings.Contains(table, "as") || strings.Contains(table, " ") { tmpAry := strings.Split(table, "as") @@ -151,24 +163,9 @@ func (l *txLinkMssql) IsOnMaster() bool { return true } -// InsertResult instance of sql.Result -type InsertResult struct { - lastInsertId int64 - rowsAffected int64 - err error -} - -func (r *InsertResult) LastInsertId() (int64, error) { - return r.lastInsertId, r.err -} - -func (r *InsertResult) RowsAffected() (int64, error) { - return r.rowsAffected, r.err -} - // GetInsertOutputSql gen get last_insert_id code -func (m *Driver) GetInsertOutputSql(ctx context.Context, table string) string { - fds, errFd := m.GetDB().TableFields(ctx, table) +func (d *Driver) GetInsertOutputSql(ctx context.Context, table string) string { + fds, errFd := d.GetDB().TableFields(ctx, table) if errFd != nil { return "" } diff --git a/contrib/drivers/mssql/mssql_do_insert.go b/contrib/drivers/mssql/mssql_do_insert.go index 5e467d730..93bc17cfa 100644 --- a/contrib/drivers/mssql/mssql_do_insert.go +++ b/contrib/drivers/mssql/mssql_do_insert.go @@ -20,51 +20,95 @@ import ( ) // DoInsert inserts or updates data for given table. -func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption) (result sql.Result, err error) { +// The list parameter must contain at least one record, which was previously validated. +func (d *Driver) DoInsert( + ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, +) (result sql.Result, err error) { switch option.InsertOption { case gdb.InsertOptionSave: return d.doSave(ctx, link, table, list, option) case gdb.InsertOptionReplace: - return nil, gerror.NewCode( - gcode.CodeNotSupported, - `Replace operation is not supported by mssql driver`, - ) + // MSSQL does not support REPLACE INTO syntax, use SAVE instead. + return d.doSave(ctx, link, table, list, option) + + case gdb.InsertOptionIgnore: + // MSSQL does not support INSERT IGNORE syntax, use MERGE instead. + return d.doInsertIgnore(ctx, link, table, list, option) default: return d.Core.DoInsert(ctx, link, table, list, option) } } -// doSave support upsert for SQL server +// doSave support upsert for MSSQL func (d *Driver) doSave(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, ) (result sql.Result, err error) { - if len(option.OnConflict) == 0 { - return nil, gerror.NewCode( - gcode.CodeMissingParameter, `Please specify conflict columns`, - ) - } + return d.doMergeInsert(ctx, link, table, list, option, true) +} - if len(list) == 0 { - return nil, gerror.NewCode( - gcode.CodeInvalidRequest, `Save operation list is empty by mssql driver`, - ) +// doInsertIgnore implements INSERT IGNORE operation using MERGE statement for MSSQL database. +// It only inserts records when there's no conflict on primary/unique keys. +func (d *Driver) doInsertIgnore(ctx context.Context, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, +) (result sql.Result, err error) { + return d.doMergeInsert(ctx, link, table, list, option, false) +} + +// doMergeInsert implements MERGE-based insert operations for MSSQL database. +// When withUpdate is true, it performs upsert (insert or update). +// When withUpdate is false, it performs insert ignore (insert only when no conflict). +func (d *Driver) doMergeInsert( + ctx context.Context, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, withUpdate bool, +) (result sql.Result, err error) { + // If OnConflict is not specified, automatically get the primary key of the table + conflictKeys := option.OnConflict + if len(conflictKeys) == 0 { + primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table) + if err != nil { + return nil, gerror.WrapCode( + gcode.CodeInternalError, + err, + `failed to get primary keys for table`, + ) + } + foundPrimaryKey := false + for _, primaryKey := range primaryKeys { + for dataKey := range list[0] { + if strings.EqualFold(dataKey, primaryKey) { + foundPrimaryKey = true + break + } + } + if foundPrimaryKey { + break + } + } + if !foundPrimaryKey { + return nil, gerror.NewCodef( + gcode.CodeMissingParameter, + `Replace/Save/InsertIgnore operation requires conflict detection: `+ + `either specify OnConflict() columns or ensure table '%s' has a primary key in the data`, + table, + ) + } + // TODO consider composite primary keys. + conflictKeys = primaryKeys } var ( - one = list[0] - oneLen = len(one) - charL, charR = d.GetChars() - - conflictKeys = option.OnConflict + one = list[0] + oneLen = len(one) + charL, charR = d.GetChars() conflictKeySet = gset.New(false) - // queryHolders: Handle data with Holder that need to be upsert - // queryValues: Handle data that need to be upsert + // queryHolders: Handle data with Holder that need to be merged + // queryValues: Handle data that need to be merged // insertKeys: Handle valid keys that need to be inserted // insertValues: Handle values that need to be inserted - // updateValues: Handle values that need to be updated + // updateValues: Handle values that need to be updated (only when withUpdate=true) queryHolders = make([]string, oneLen) queryValues = make([]any, oneLen) insertKeys = make([]string, oneLen) @@ -84,9 +128,9 @@ func (d *Driver) doSave(ctx context.Context, insertKeys[index] = charL + key + charR insertValues[index] = "T2." + charL + key + charR - // filter conflict keys in updateValues. - // And the key is not a soft created field. - if !(conflictKeySet.Contains(key) || d.Core.IsSoftCreatedFieldName(key)) { + // Build updateValues only when withUpdate is true + // Filter conflict keys and soft created fields from updateValues + if withUpdate && !(conflictKeySet.Contains(key) || d.Core.IsSoftCreatedFieldName(key)) { updateValues = append( updateValues, fmt.Sprintf(`T1.%s = T2.%s`, charL+key+charR, charL+key+charR), @@ -95,8 +139,10 @@ func (d *Driver) doSave(ctx context.Context, index++ } - batchResult := new(gdb.SqlResult) - sqlStr := parseSqlForUpsert(table, queryHolders, insertKeys, insertValues, updateValues, conflictKeys) + var ( + batchResult = new(gdb.SqlResult) + sqlStr = parseSqlForMerge(table, queryHolders, insertKeys, insertValues, updateValues, conflictKeys) + ) r, err := d.DoExec(ctx, link, sqlStr, queryValues...) if err != nil { return r, err @@ -110,41 +156,48 @@ func (d *Driver) doSave(ctx context.Context, return batchResult, nil } -// parseSqlForUpsert -// MERGE INTO {{table}} T1 -// USING ( VALUES( {{queryHolders}}) T2 ({{insertKeyStr}}) -// ON (T1.{{duplicateKey}} = T2.{{duplicateKey}} AND ...) -// WHEN NOT MATCHED THEN -// INSERT {{insertKeys}} VALUES {{insertValues}} -// WHEN MATCHED THEN -// UPDATE SET {{updateValues}} -func parseSqlForUpsert(table string, +// parseSqlForMerge generates MERGE statement for MSSQL database. +// When updateValues is empty, it only inserts (INSERT IGNORE behavior). +// When updateValues is provided, it performs upsert (INSERT or UPDATE). +// Examples: +// - INSERT IGNORE: MERGE INTO table T1 USING (...) T2 ON (...) WHEN NOT MATCHED THEN INSERT(...) VALUES (...) +// - UPSERT: MERGE INTO table T1 USING (...) T2 ON (...) WHEN NOT MATCHED THEN INSERT(...) VALUES (...) WHEN MATCHED THEN UPDATE SET ... +func parseSqlForMerge(table string, queryHolders, insertKeys, insertValues, updateValues, duplicateKey []string, ) (sqlStr string) { var ( queryHolderStr = strings.Join(queryHolders, ",") insertKeyStr = strings.Join(insertKeys, ",") insertValueStr = strings.Join(insertValues, ",") - updateValueStr = strings.Join(updateValues, ",") duplicateKeyStr string - pattern = gstr.Trim(`MERGE INTO %s T1 USING (VALUES(%s)) T2 (%s) ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s;`) ) + // Build ON condition for index, keys := range duplicateKey { if index != 0 { duplicateKeyStr += " AND " } - duplicateTmp := fmt.Sprintf("T1.%s = T2.%s", keys, keys) - duplicateKeyStr += duplicateTmp + duplicateKeyStr += fmt.Sprintf("T1.%s = T2.%s", keys, keys) } - return fmt.Sprintf(pattern, - table, - queryHolderStr, - insertKeyStr, - duplicateKeyStr, - insertKeyStr, - insertValueStr, - updateValueStr, + // Build SQL based on whether UPDATE is needed + pattern := gstr.Trim( + `MERGE INTO %s T1 USING (VALUES(%s)) T2 (%s) ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s)`, ) + if len(updateValues) > 0 { + // Upsert: INSERT or UPDATE + pattern += gstr.Trim(` WHEN MATCHED THEN UPDATE SET %s`) + return fmt.Sprintf( + pattern+";", + table, + queryHolderStr, + insertKeyStr, + duplicateKeyStr, + insertKeyStr, + insertValueStr, + strings.Join(updateValues, ","), + ) + } + // Insert Ignore: INSERT only + return fmt.Sprintf(pattern+";", table, queryHolderStr, insertKeyStr, duplicateKeyStr, insertKeyStr, insertValueStr) } diff --git a/contrib/drivers/mssql/mssql_result.go b/contrib/drivers/mssql/mssql_result.go new file mode 100644 index 000000000..57f3f41a9 --- /dev/null +++ b/contrib/drivers/mssql/mssql_result.go @@ -0,0 +1,22 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package mssql + +// Result instance of sql.Result +type Result struct { + lastInsertId int64 + rowsAffected int64 + err error +} + +func (r *Result) LastInsertId() (int64, error) { + return r.lastInsertId, r.err +} + +func (r *Result) RowsAffected() (int64, error) { + return r.rowsAffected, r.err +} diff --git a/contrib/drivers/mssql/mssql_z_unit_basic_test.go b/contrib/drivers/mssql/mssql_z_unit_basic_test.go index 0999a1eee..49635776d 100644 --- a/contrib/drivers/mssql/mssql_z_unit_basic_test.go +++ b/contrib/drivers/mssql/mssql_z_unit_basic_test.go @@ -138,15 +138,17 @@ func TestDoInsert(t *testing.T) { i := 10 data := g.Map{ - "id": i, + // "id": i, "passport": fmt.Sprintf(`t%d`, i), "password": fmt.Sprintf(`p%d`, i), "nickname": fmt.Sprintf(`T%d`, i), "create_time": gtime.Now(), } + // Save without OnConflict should fail (missing conflict columns) _, err := db.Save(context.Background(), "t_user", data, 10) gtest.AssertNE(err, nil) + // Replace should fail because primary key 'id' is not in the data _, err = db.Replace(context.Background(), "t_user", data, 10) gtest.AssertNE(err, nil) }) diff --git a/contrib/drivers/mssql/mssql_z_unit_model_test.go b/contrib/drivers/mssql/mssql_z_unit_model_test.go index b3a1daa81..1e49e5f5c 100644 --- a/contrib/drivers/mssql/mssql_z_unit_model_test.go +++ b/contrib/drivers/mssql/mssql_z_unit_model_test.go @@ -117,6 +117,48 @@ func Test_Model_Insert(t *testing.T) { }) } +func Test_Model_InsertIgnore(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + // db.SetDebug(true) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": fmt.Sprintf(`t%d`, 777), + "password": fmt.Sprintf(`p%d`, 777), + "nickname": fmt.Sprintf(`T%d`, 777), + "create_time": gtime.Now(), + } + _, err := db.Model(table).Data(data).InsertIgnore() + t.AssertNil(err) + + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["PASSPORT"].String(), "user_1") + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "passport": fmt.Sprintf(`t%d`, 777), + "password": fmt.Sprintf(`p%d`, 777), + "nickname": fmt.Sprintf(`T%d`, 777), + "create_time": gtime.Now(), + } + _, err := db.Model(table).Data(data).InsertIgnore() + t.AssertNE(err, nil) + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) +} + func Test_Model_Insert_KeyFieldNameMapping(t *testing.T) { table := createTable() defer dropTable(table) @@ -2658,14 +2700,53 @@ func Test_Model_Replace(t *testing.T) { defer dropTable(table) gtest.C(t, func(t *gtest.T) { - _, err := db.Model(table).Data(g.Map{ + // Insert initial record + result, err := db.Model(table).Data(g.Map{ + "id": 1, + "passport": "t1", + "password": "pass1", + "nickname": "T1", + "create_time": "2018-10-24 10:00:00", + }).Insert() + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + // Replace with new data (should update existing record using MERGE) + result, err = db.Model(table).Data(g.Map{ "id": 1, "passport": "t11", "password": "25d55ad283aa400af464c76d713c07ad", "nickname": "T11", "create_time": "2018-10-24 10:00:00", }).Replace() - t.Assert(err, "Replace operation is not supported by mssql driver") + t.AssertNil(err) + n, _ = result.RowsAffected() + t.Assert(n, 1) + + // Verify the data was replaced + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["PASSPORT"].String(), "t11") + t.Assert(one["NICKNAME"].String(), "T11") + + // Replace with non-existing record (should insert new record) + result, err = db.Model(table).Data(g.Map{ + "id": 2, + "passport": "t222", + "password": "pass2", + "nickname": "T222", + "create_time": "2018-10-24 11:00:00", + }).Replace() + t.AssertNil(err) + n, _ = result.RowsAffected() + t.Assert(n, 1) // MERGE reports: 1 for insert + + // Verify the new record was inserted + one, err = db.Model(table).WherePri(2).One() + t.AssertNil(err) + t.Assert(one["PASSPORT"].String(), "t222") + t.Assert(one["NICKNAME"].String(), "T222") }) } diff --git a/contrib/drivers/oracle/oracle.go b/contrib/drivers/oracle/oracle.go index ae6a1c5dd..51aa4df25 100644 --- a/contrib/drivers/oracle/oracle.go +++ b/contrib/drivers/oracle/oracle.go @@ -5,10 +5,6 @@ // You can obtain one at https://github.com/gogf/gf. // Package oracle implements gdb.Driver, which supports operations for database Oracle. -// -// Note: -// 1. It does not support Save/Replace features. -// 2. It does not support LastInsertId. package oracle import ( diff --git a/contrib/drivers/oracle/oracle_do_exec.go b/contrib/drivers/oracle/oracle_do_exec.go new file mode 100644 index 000000000..d7fe4a39d --- /dev/null +++ b/contrib/drivers/oracle/oracle_do_exec.go @@ -0,0 +1,120 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package oracle + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" +) + +const ( + returningClause = " RETURNING %s INTO ?" +) + +// DoExec commits the sql string and its arguments to underlying driver +// through given link object and returns the execution result. +// It handles INSERT statements specially to support LastInsertId. +func (d *Driver) DoExec( + ctx context.Context, link gdb.Link, sql string, args ...interface{}, +) (result sql.Result, err error) { + var ( + isUseCoreDoExec = true + primaryKey string + pkField gdb.TableField + ) + + // Transaction checks. + if link == nil { + if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil { + link = tx + } else if link, err = d.MasterLink(); err != nil { + return nil, err + } + } else if !link.IsTransaction() { + if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil { + link = tx + } + } + + // Check if it is an insert operation with primary key from context. + if value := ctx.Value(internalPrimaryKeyInCtx); value != nil { + if field, ok := value.(gdb.TableField); ok { + pkField = field + isUseCoreDoExec = false + } + } + + // Check if it is an INSERT statement with primary key. + if !isUseCoreDoExec && pkField.Name != "" && strings.Contains(strings.ToUpper(sql), "INSERT INTO") { + primaryKey = pkField.Name + // Oracle supports RETURNING clause to get the last inserted id + sql += fmt.Sprintf(returningClause, d.QuoteWord(primaryKey)) + } else { + // Use default DoExec for non-INSERT or no primary key scenarios + return d.Core.DoExec(ctx, link, sql, args...) + } + + // Only the insert operation with primary key can execute the following code + + // SQL filtering. + sql, args = d.FormatSqlBeforeExecuting(sql, args) + sql, args, err = d.DoFilter(ctx, link, sql, args) + if err != nil { + return nil, err + } + + // Prepare output variable for RETURNING clause + var lastInsertId int64 + // Append the output parameter for the RETURNING clause + args = append(args, &lastInsertId) + + // Link execution. + _, err = d.DoCommit(ctx, gdb.DoCommitInput{ + Link: link, + Sql: sql, + Args: args, + Stmt: nil, + Type: gdb.SqlTypeExecContext, + IsTransaction: link.IsTransaction(), + }) + + if err != nil { + return &Result{ + lastInsertId: 0, + rowsAffected: 0, + lastInsertIdError: err, + }, err + } + + // Get rows affected from the result + // For single insert with RETURNING clause, affected is always 1 + var affected int64 = 1 + + // Check if the primary key field type supports LastInsertId + if !strings.Contains(strings.ToLower(pkField.Type), "int") { + return &Result{ + lastInsertId: 0, + rowsAffected: affected, + lastInsertIdError: gerror.NewCodef( + gcode.CodeNotSupported, + "LastInsertId is not supported by primary key type: %s", + pkField.Type, + ), + }, nil + } + + return &Result{ + lastInsertId: lastInsertId, + rowsAffected: affected, + }, nil +} diff --git a/contrib/drivers/oracle/oracle_do_insert.go b/contrib/drivers/oracle/oracle_do_insert.go index d59bdf95b..82f8373d5 100644 --- a/contrib/drivers/oracle/oracle_do_insert.go +++ b/contrib/drivers/oracle/oracle_do_insert.go @@ -16,11 +16,17 @@ import ( "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/os/gctx" "github.com/gogf/gf/v2/text/gstr" "github.com/gogf/gf/v2/util/gconv" ) +const ( + internalPrimaryKeyInCtx gctx.StrKey = "primary_key_field" +) + // DoInsert inserts or updates data for given table. +// The list parameter must contain at least one record, which was previously validated. func (d *Driver) DoInsert( ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, ) (result sql.Result, err error) { @@ -29,10 +35,39 @@ func (d *Driver) DoInsert( return d.doSave(ctx, link, table, list, option) case gdb.InsertOptionReplace: - return nil, gerror.NewCode( - gcode.CodeNotSupported, - `Replace operation is not supported by oracle driver`, - ) + // Oracle does not support REPLACE INTO syntax, use SAVE instead. + return d.doSave(ctx, link, table, list, option) + + case gdb.InsertOptionIgnore: + // Oracle does not support INSERT IGNORE syntax, use MERGE instead. + return d.doInsertIgnore(ctx, link, table, list, option) + + case gdb.InsertOptionDefault: + // For default insert, set primary key field in context to support LastInsertId. + // Only set it when the primary key is not provided in the data, for performance reason. + tableFields, err := d.GetCore().GetDB().TableFields(ctx, table) + if err == nil && len(list) > 0 { + for _, field := range tableFields { + if strings.EqualFold(field.Key, "pri") { + // Check if primary key is provided in the data. + pkProvided := false + for key := range list[0] { + if strings.EqualFold(key, field.Name) { + pkProvided = true + break + } + } + // Only use RETURNING when primary key is not provided, for performance reason. + if !pkProvided { + pkField := *field + ctx = context.WithValue(ctx, internalPrimaryKeyInCtx, pkField) + } + break + } + } + } + + default: } var ( keys []string @@ -55,8 +90,8 @@ func (d *Driver) DoInsert( valueHolderStr = strings.Join(valueHolder, ",") ) // Format "INSERT...INTO..." statement. - intoStrArray := make([]string, 0) - for i := 0; i < len(list); i++ { + // Note: Use standard INSERT INTO syntax instead of INSERT ALL to ensure triggers fire + for i := 0; i < listLength; i++ { for _, k := range keys { if s, ok := list[i][k].(gdb.Raw); ok { params = append(params, gconv.String(s)) @@ -65,30 +100,22 @@ func (d *Driver) DoInsert( } } values = append(values, valueHolderStr) - intoStrArray = append( - intoStrArray, - fmt.Sprintf( - "INTO %s(%s) VALUES(%s)", - table, keyStr, valueHolderStr, - ), - ) - if len(intoStrArray) == option.BatchCount || (i == listLength-1 && len(valueHolder) > 0) { - r, err := d.DoExec(ctx, link, fmt.Sprintf( - "INSERT ALL %s SELECT * FROM DUAL", - strings.Join(intoStrArray, " "), - ), params...) - if err != nil { - return r, err - } - if n, err := r.RowsAffected(); err != nil { - return r, err - } else { - batchResult.Result = r - batchResult.Affected += n - } - params = params[:0] - intoStrArray = intoStrArray[:0] + + // Execute individual INSERT for each record to trigger row-level triggers + r, err := d.DoExec(ctx, link, fmt.Sprintf( + "INSERT INTO %s(%s) VALUES(%s)", + table, keyStr, valueHolderStr, + ), params...) + if err != nil { + return r, err } + if n, err := r.RowsAffected(); err != nil { + return r, err + } else { + batchResult.Result = r + batchResult.Affected += n + } + params = params[:0] } return batchResult, nil } @@ -97,24 +124,63 @@ func (d *Driver) DoInsert( func (d *Driver) doSave(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, ) (result sql.Result, err error) { - if len(option.OnConflict) == 0 { - return nil, gerror.NewCode( - gcode.CodeMissingParameter, `Please specify conflict columns`, - ) - } + return d.doMergeInsert(ctx, link, table, list, option, true) +} - if len(list) == 0 { - return nil, gerror.NewCode( - gcode.CodeInvalidRequest, `Save operation list is empty by oracle driver`, - ) +// doInsertIgnore implements INSERT IGNORE operation using MERGE statement for Oracle database. +// It only inserts records when there's no conflict on primary/unique keys. +func (d *Driver) doInsertIgnore(ctx context.Context, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, +) (result sql.Result, err error) { + return d.doMergeInsert(ctx, link, table, list, option, false) +} + +// doMergeInsert implements MERGE-based insert operations for Oracle database. +// When withUpdate is true, it performs upsert (insert or update). +// When withUpdate is false, it performs insert ignore (insert only when no conflict). +func (d *Driver) doMergeInsert( + ctx context.Context, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, withUpdate bool, +) (result sql.Result, err error) { + // If OnConflict is not specified, automatically get the primary key of the table + conflictKeys := option.OnConflict + if len(conflictKeys) == 0 { + primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table) + if err != nil { + return nil, gerror.WrapCode( + gcode.CodeInternalError, + err, + `failed to get primary keys for table`, + ) + } + foundPrimaryKey := false + for _, primaryKey := range primaryKeys { + for dataKey := range list[0] { + if strings.EqualFold(dataKey, primaryKey) { + foundPrimaryKey = true + break + } + } + if foundPrimaryKey { + break + } + } + if !foundPrimaryKey { + return nil, gerror.NewCodef( + gcode.CodeMissingParameter, + `Replace/Save/InsertIgnore operation requires conflict detection: `+ + `either specify OnConflict() columns or ensure table '%s' has a primary key in the data`, + table, + ) + } + // TODO consider composite primary keys. + conflictKeys = primaryKeys } var ( - one = list[0] - oneLen = len(one) - charL, charR = d.GetChars() - - conflictKeys = option.OnConflict + one = list[0] + oneLen = len(one) + charL, charR = d.GetChars() conflictKeySet = gset.New(false) // queryHolders: Handle data with Holder that need to be upsert @@ -142,9 +208,9 @@ func (d *Driver) doSave(ctx context.Context, insertKeys[index] = keyWithChar insertValues[index] = fmt.Sprintf("T2.%s", keyWithChar) - // filter conflict keys in updateValues. - // And the key is not a soft created field. - if !(conflictKeySet.Contains(key) || d.Core.IsSoftCreatedFieldName(key)) { + // Build updateValues only when withUpdate is true + // Filter conflict keys and soft created fields from updateValues + if withUpdate && !(conflictKeySet.Contains(key) || d.Core.IsSoftCreatedFieldName(key)) { updateValues = append( updateValues, fmt.Sprintf(`T1.%s = T2.%s`, keyWithChar, keyWithChar), @@ -153,8 +219,10 @@ func (d *Driver) doSave(ctx context.Context, index++ } - batchResult := new(gdb.SqlResult) - sqlStr := parseSqlForUpsert(table, queryHolders, insertKeys, insertValues, updateValues, conflictKeys) + var ( + batchResult = new(gdb.SqlResult) + sqlStr = parseSqlForMerge(table, queryHolders, insertKeys, insertValues, updateValues, conflictKeys) + ) r, err := d.DoExec(ctx, link, sqlStr, queryValues...) if err != nil { return r, err @@ -168,40 +236,43 @@ func (d *Driver) doSave(ctx context.Context, return batchResult, nil } -// parseSqlForUpsert -// MERGE INTO {{table}} T1 -// USING ( SELECT {{queryHolders}} FROM DUAL T2 -// ON (T1.{{duplicateKey}} = T2.{{duplicateKey}} AND ...) -// WHEN NOT MATCHED THEN -// INSERT {{insertKeys}} VALUES {{insertValues}} -// WHEN MATCHED THEN -// UPDATE SET {{updateValues}} -func parseSqlForUpsert(table string, +// parseSqlForMerge generates MERGE statement for Oracle database. +// When updateValues is empty, it only inserts (INSERT IGNORE behavior). +// When updateValues is provided, it performs upsert (INSERT or UPDATE). +// Examples: +// - INSERT IGNORE: MERGE INTO table T1 USING (...) T2 ON (...) WHEN NOT MATCHED THEN INSERT(...) VALUES (...) +// - UPSERT: MERGE INTO table T1 USING (...) T2 ON (...) WHEN NOT MATCHED THEN INSERT(...) VALUES (...) WHEN MATCHED THEN UPDATE SET ... +func parseSqlForMerge(table string, queryHolders, insertKeys, insertValues, updateValues, duplicateKey []string, ) (sqlStr string) { var ( queryHolderStr = strings.Join(queryHolders, ",") insertKeyStr = strings.Join(insertKeys, ",") insertValueStr = strings.Join(insertValues, ",") - updateValueStr = strings.Join(updateValues, ",") duplicateKeyStr string - pattern = gstr.Trim(`MERGE INTO %s T1 USING (SELECT %s FROM DUAL) T2 ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s`) ) + // Build ON condition for index, keys := range duplicateKey { if index != 0 { duplicateKeyStr += " AND " } - duplicateTmp := fmt.Sprintf("T1.%s = T2.%s", keys, keys) - duplicateKeyStr += duplicateTmp + duplicateKeyStr += fmt.Sprintf("T1.%s = T2.%s", keys, keys) } - return fmt.Sprintf(pattern, - table, - queryHolderStr, - duplicateKeyStr, - insertKeyStr, - insertValueStr, - updateValueStr, + // Build SQL based on whether UPDATE is needed + pattern := gstr.Trim( + `MERGE INTO %s T1 USING (SELECT %s FROM DUAL) T2 ON (%s) WHEN ` + + `NOT MATCHED THEN INSERT(%s) VALUES (%s)`, ) + if len(updateValues) > 0 { + // Upsert: INSERT or UPDATE + pattern += gstr.Trim(` WHEN MATCHED THEN UPDATE SET %s`) + return fmt.Sprintf( + pattern, table, queryHolderStr, duplicateKeyStr, insertKeyStr, insertValueStr, + strings.Join(updateValues, ","), + ) + } + // Insert Ignore: INSERT only + return fmt.Sprintf(pattern, table, queryHolderStr, duplicateKeyStr, insertKeyStr, insertValueStr) } diff --git a/contrib/drivers/oracle/oracle_result.go b/contrib/drivers/oracle/oracle_result.go new file mode 100644 index 000000000..a4795530b --- /dev/null +++ b/contrib/drivers/oracle/oracle_result.go @@ -0,0 +1,24 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package oracle + +// Result implements sql.Result interface for Oracle database. +type Result struct { + lastInsertId int64 + rowsAffected int64 + lastInsertIdError error +} + +// LastInsertId returns the last insert id. +func (r *Result) LastInsertId() (int64, error) { + return r.lastInsertId, r.lastInsertIdError +} + +// RowsAffected returns the rows affected. +func (r *Result) RowsAffected() (int64, error) { + return r.rowsAffected, nil +} diff --git a/contrib/drivers/oracle/oracle_table_fields.go b/contrib/drivers/oracle/oracle_table_fields.go index aa20858dc..8efb9c110 100644 --- a/contrib/drivers/oracle/oracle_table_fields.go +++ b/contrib/drivers/oracle/oracle_table_fields.go @@ -18,13 +18,23 @@ import ( var ( tableFieldsSqlTmp = ` SELECT - COLUMN_NAME AS FIELD, + c.COLUMN_NAME AS FIELD, CASE - WHEN (DATA_TYPE='NUMBER' AND NVL(DATA_SCALE,0)=0) THEN 'INT'||'('||DATA_PRECISION||','||DATA_SCALE||')' - WHEN (DATA_TYPE='NUMBER' AND NVL(DATA_SCALE,0)>0) THEN 'FLOAT'||'('||DATA_PRECISION||','||DATA_SCALE||')' - WHEN DATA_TYPE='FLOAT' THEN DATA_TYPE||'('||DATA_PRECISION||','||DATA_SCALE||')' - ELSE DATA_TYPE||'('||DATA_LENGTH||')' END AS TYPE,NULLABLE -FROM USER_TAB_COLUMNS WHERE TABLE_NAME = '%s' ORDER BY COLUMN_ID + WHEN (c.DATA_TYPE='NUMBER' AND NVL(c.DATA_SCALE,0)=0) THEN 'INT'||'('||c.DATA_PRECISION||','||c.DATA_SCALE||')' + WHEN (c.DATA_TYPE='NUMBER' AND NVL(c.DATA_SCALE,0)>0) THEN 'FLOAT'||'('||c.DATA_PRECISION||','||c.DATA_SCALE||')' + WHEN c.DATA_TYPE='FLOAT' THEN c.DATA_TYPE||'('||c.DATA_PRECISION||','||c.DATA_SCALE||')' + ELSE c.DATA_TYPE||'('||c.DATA_LENGTH||')' END AS TYPE, + c.NULLABLE, + CASE WHEN pk.COLUMN_NAME IS NOT NULL THEN 'PRI' ELSE '' END AS KEY +FROM USER_TAB_COLUMNS c +LEFT JOIN ( + SELECT cols.COLUMN_NAME + FROM USER_CONSTRAINTS cons + JOIN USER_CONS_COLUMNS cols ON cons.CONSTRAINT_NAME = cols.CONSTRAINT_NAME + WHERE cons.TABLE_NAME = '%s' AND cons.CONSTRAINT_TYPE = 'P' +) pk ON c.COLUMN_NAME = pk.COLUMN_NAME +WHERE c.TABLE_NAME = '%s' +ORDER BY c.COLUMN_ID ` ) @@ -44,7 +54,8 @@ func (d *Driver) TableFields(ctx context.Context, table string, schema ...string result gdb.Result link gdb.Link usedSchema = gutil.GetOrDefaultStr(d.GetSchema(), schema...) - structureSql = fmt.Sprintf(tableFieldsSqlTmp, strings.ToUpper(table)) + upperTable = strings.ToUpper(table) + structureSql = fmt.Sprintf(tableFieldsSqlTmp, upperTable, upperTable) ) if link, err = d.SlaveLink(usedSchema); err != nil { return nil, err @@ -53,6 +64,7 @@ func (d *Driver) TableFields(ctx context.Context, table string, schema ...string if err != nil { return nil, err } + fields = make(map[string]*gdb.TableField) for i, m := range result { isNull := false @@ -65,6 +77,7 @@ func (d *Driver) TableFields(ctx context.Context, table string, schema ...string Name: m["FIELD"].String(), Type: m["TYPE"].String(), Null: isNull, + Key: m["KEY"].String(), } } return fields, nil diff --git a/contrib/drivers/oracle/oracle_z_unit_basic_test.go b/contrib/drivers/oracle/oracle_z_unit_basic_test.go index 7a0af2624..24836455a 100644 --- a/contrib/drivers/oracle/oracle_z_unit_basic_test.go +++ b/contrib/drivers/oracle/oracle_z_unit_basic_test.go @@ -139,10 +139,10 @@ func Test_Do_Insert(t *testing.T) { "CREATE_TIME": gtime.Now().String(), } _, err := db.Save(ctx, "t_user", data, 10) - gtest.AssertNE(err, nil) + gtest.AssertNil(err) _, err = db.Replace(ctx, "t_user", data, 10) - gtest.AssertNE(err, nil) + gtest.AssertNil(err) }) } @@ -185,6 +185,7 @@ func Test_DB_Insert(t *testing.T) { table := createTable() defer dropTable(table) + // db.SetDebug(true) gtest.C(t, func(t *gtest.T) { _, err := db.Insert(ctx, table, g.Map{ "ID": 1, @@ -233,7 +234,7 @@ func Test_DB_Insert(t *testing.T) { one, err := db.Model(table).Where("ID", 3).One() t.AssertNil(err) - fmt.Println(one) + // fmt.Println(one) t.Assert(one["ID"].Int(), 3) t.Assert(one["PASSPORT"].String(), "user_3") t.Assert(one["PASSWORD"].String(), "25d55ad283aa400af464c76d713c07ad") diff --git a/contrib/drivers/oracle/oracle_z_unit_init_test.go b/contrib/drivers/oracle/oracle_z_unit_init_test.go index 3e549d4b2..74bf9f26e 100644 --- a/contrib/drivers/oracle/oracle_z_unit_init_test.go +++ b/contrib/drivers/oracle/oracle_z_unit_init_test.go @@ -113,16 +113,48 @@ func createTable(table ...string) (name string) { dropTable(name) - if _, err := db.Exec(ctx, fmt.Sprintf(` - CREATE TABLE %s ( - ID NUMBER(10) NOT NULL, - PASSPORT VARCHAR(45) NOT NULL, - PASSWORD CHAR(32) NOT NULL, - NICKNAME VARCHAR(45) NOT NULL, - CREATE_TIME varchar(45), - SALARY NUMBER(18,2), - PRIMARY KEY (ID)) - `, name)); err != nil { + // Step 1: Create table + createTableSQL := fmt.Sprintf(` + CREATE TABLE %s ( + ID NUMBER(10) NOT NULL, + PASSPORT VARCHAR(45) NOT NULL, + PASSWORD CHAR(32) NOT NULL, + NICKNAME VARCHAR(45) NOT NULL, + CREATE_TIME VARCHAR(45), + SALARY NUMBER(18,2), + PRIMARY KEY (ID) + )`, name) + + if _, err := db.Exec(ctx, createTableSQL); err != nil { + gtest.Fatal(err) + } + + // Step 2: Create sequence + createSeqSQL := fmt.Sprintf(` + CREATE SEQUENCE %s_ID_SEQ + START WITH 1 + INCREMENT BY 1 + MINVALUE 1 + MAXVALUE 9999999999 + NOCYCLE + NOCACHE`, name) + + if _, err := db.Exec(ctx, createSeqSQL); err != nil { + gtest.Fatal(err) + } + + // Step 3: Create trigger - only set ID from sequence when it's NULL + createTriggerSQL := fmt.Sprintf(` +CREATE OR REPLACE TRIGGER %s_ID_TRG +BEFORE INSERT ON %s +FOR EACH ROW +BEGIN + IF :NEW.ID IS NULL THEN + :NEW.ID := %s_ID_SEQ.NEXTVAL; + END IF; +END;`, name, name, name) + + if _, err := db.Exec(ctx, createTriggerSQL); err != nil { gtest.Fatal(err) } @@ -160,7 +192,15 @@ func dropTable(table string) { if count == 0 { return } + + // Drop table if _, err = db.Exec(ctx, fmt.Sprintf("DROP TABLE %s", table)); err != nil { gtest.Fatal(err) } + + // Drop sequence if exists + seqCount, err := db.GetCount(ctx, "SELECT COUNT(*) FROM USER_SEQUENCES WHERE SEQUENCE_NAME = ?", strings.ToUpper(table+"_ID_SEQ")) + if err == nil && seqCount > 0 { + db.Exec(ctx, fmt.Sprintf("DROP SEQUENCE %s_ID_SEQ", table)) + } } diff --git a/contrib/drivers/oracle/oracle_z_unit_model_test.go b/contrib/drivers/oracle/oracle_z_unit_model_test.go index 26031615e..185446b24 100644 --- a/contrib/drivers/oracle/oracle_z_unit_model_test.go +++ b/contrib/drivers/oracle/oracle_z_unit_model_test.go @@ -233,6 +233,67 @@ func Test_Model_Insert(t *testing.T) { }) } +func Test_Model_InsertIgnore(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + // db.SetDebug(true) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "id": 1, + "passport": fmt.Sprintf(`t%d`, 777), + "password": fmt.Sprintf(`p%d`, 777), + "nickname": fmt.Sprintf(`T%d`, 777), + "create_time": gtime.Now(), + } + _, err := db.Model(table).Data(data).InsertIgnore() + t.AssertNil(err) + + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["PASSPORT"].String(), "user_1") + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + "passport": fmt.Sprintf(`t%d`, 777), + "password": fmt.Sprintf(`p%d`, 777), + "nickname": fmt.Sprintf(`T%d`, 777), + "create_time": gtime.Now(), + } + _, err := db.Model(table).Data(data).InsertIgnore() + t.AssertNE(err, nil) + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, TableSize) + }) +} + +func Test_Model_InsertAndGetId(t *testing.T) { + table := createTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + data := g.Map{ + // "id": 1, + "passport": fmt.Sprintf(`t%d`, 1), + "password": fmt.Sprintf(`p%d`, 1), + "nickname": fmt.Sprintf(`T%d`, 1), + "create_time": gtime.Now(), + } + lastId, err := db.Model(table).Data(data).InsertAndGetId() + t.AssertNil(err) + t.AssertGT(lastId, 0) + }) + +} + // https://github.com/gogf/gf/issues/3286 func Test_Model_Insert_Raw(t *testing.T) { table := createTable() @@ -1179,14 +1240,73 @@ func Test_Model_Replace(t *testing.T) { defer dropTable(table) gtest.C(t, func(t *gtest.T) { - _, err := db.Model(table).Data(g.Map{ + // Insert initial record + result, err := db.Model(table).Data(g.Map{ + "id": 1, + "passport": "t1", + "password": "pass1", + "nickname": "T1", + "create_time": "2018-10-24 10:00:00", + }).Insert() + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + // Replace with new data (should update existing record using MERGE) + result, err = db.Model(table).Data(g.Map{ "id": 1, "passport": "t11", "password": "25d55ad283aa400af464c76d713c07ad", "nickname": "T11", "create_time": "2018-10-24 10:00:00", + }).OnConflict("id").Replace() + t.AssertNil(err) + n, _ = result.RowsAffected() + t.Assert(n, 1) + + // Verify the data was replaced + one, err := db.Model(table).WherePri(1).One() + t.AssertNil(err) + t.Assert(one["PASSPORT"].String(), "t11") + t.Assert(one["PASSWORD"].String(), "25d55ad283aa400af464c76d713c07ad") + t.Assert(one["NICKNAME"].String(), "T11") + + // Replace with new ID (insert new record) + result, err = db.Model(table).Data(g.Map{ + "id": 2, + "passport": "t222", + "password": "pass2", + "nickname": "T222", + "create_time": "2018-10-24 11:00:00", + }).OnConflict("id").Replace() + t.AssertNil(err) + n, _ = result.RowsAffected() + t.Assert(n, 1) + + // Verify new record was inserted + one, err = db.Model(table).Where("id", 2).One() + t.AssertNil(err) + t.Assert(one["PASSPORT"].String(), "t222") + t.Assert(one["NICKNAME"].String(), "T222") + + // Replace without OnConflict (primary key auto-detection is implemented) + _, err = db.Model(table).Data(g.Map{ + "id": 3, + "passport": "t3", + "password": "pass3", + "nickname": "T3", + "create_time": "2018-10-24 12:00:00", }).Replace() - t.Assert(err, "Replace operation is not supported by oracle driver") + t.AssertNil(err) + + _, err = db.Model(table).Data(g.Map{ + // "id": 3, + "passport": "t3", + "password": "pass3", + "nickname": "T3", + "create_time": "2018-10-24 12:00:00", + }).Replace() + t.AssertNE(err, nil) }) } diff --git a/contrib/drivers/pgsql/pgsql.go b/contrib/drivers/pgsql/pgsql.go index b7a18a810..3c3272240 100644 --- a/contrib/drivers/pgsql/pgsql.go +++ b/contrib/drivers/pgsql/pgsql.go @@ -5,10 +5,6 @@ // You can obtain one at https://github.com/gogf/gf. // Package pgsql implements gdb.Driver, which supports operations for database PostgreSQL. -// -// Note: -// 1. It does not support Replace features. -// 2. It does not support Insert Ignore features. package pgsql import ( diff --git a/contrib/drivers/pgsql/pgsql_do_insert.go b/contrib/drivers/pgsql/pgsql_do_insert.go index ec24edde3..6bd0ba142 100644 --- a/contrib/drivers/pgsql/pgsql_do_insert.go +++ b/contrib/drivers/pgsql/pgsql_do_insert.go @@ -9,6 +9,7 @@ package pgsql import ( "context" "database/sql" + "strings" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/errors/gcode" @@ -16,25 +17,68 @@ import ( ) // DoInsert inserts or updates data for given table. -func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption) (result sql.Result, err error) { +// The list parameter must contain at least one record, which was previously validated. +func (d *Driver) DoInsert( + ctx context.Context, + link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption, +) (result sql.Result, err error) { switch option.InsertOption { - case gdb.InsertOptionReplace: - return nil, gerror.NewCode( - gcode.CodeNotSupported, - `Replace operation is not supported by pgsql driver`, - ) + case + gdb.InsertOptionSave, + gdb.InsertOptionReplace: + // PostgreSQL does not support REPLACE INTO syntax, use Save (ON CONFLICT ... DO UPDATE) instead. + // Automatically detect primary keys if OnConflict is not specified. + if len(option.OnConflict) == 0 { + primaryKeys, err := d.Core.GetPrimaryKeys(ctx, table) + if err != nil { + return nil, gerror.WrapCode( + gcode.CodeInternalError, + err, + `failed to get primary keys for Save/Replace operation`, + ) + } + foundPrimaryKey := false + for _, primaryKey := range primaryKeys { + for dataKey := range list[0] { + if strings.EqualFold(dataKey, primaryKey) { + foundPrimaryKey = true + break + } + } + if foundPrimaryKey { + break + } + } + if !foundPrimaryKey { + return nil, gerror.NewCodef( + gcode.CodeMissingParameter, + `Replace/Save operation requires conflict detection: `+ + `either specify OnConflict() columns or ensure table '%s' has a primary key in the data`, + table, + ) + } + // TODO consider composite primary keys. + option.OnConflict = primaryKeys + } + // Treat Replace as Save operation + option.InsertOption = gdb.InsertOptionSave - case gdb.InsertOptionDefault: + // pgsql support InsertIgnore natively, so no need to set primary key in context. + case gdb.InsertOptionIgnore, gdb.InsertOptionDefault: + // Get table fields to retrieve the primary key TableField object (not just the name) + // because DoExec needs the `TableField.Type` to determine if LastInsertId is supported. tableFields, err := d.GetCore().GetDB().TableFields(ctx, table) if err == nil { for _, field := range tableFields { - if field.Key == "pri" { + if strings.EqualFold(field.Key, "pri") { pkField := *field ctx = context.WithValue(ctx, internalPrimaryKeyInCtx, pkField) break } } } + + default: } return d.Core.DoInsert(ctx, link, table, list, option) } diff --git a/contrib/drivers/pgsql/pgsql_table_fields.go b/contrib/drivers/pgsql/pgsql_table_fields.go index 07f3a4e43..8573648f6 100644 --- a/contrib/drivers/pgsql/pgsql_table_fields.go +++ b/contrib/drivers/pgsql/pgsql_table_fields.go @@ -80,10 +80,22 @@ func (d *Driver) TableFields(ctx context.Context, table string, schema ...string } continue } + + var ( + fieldType string + dataType = m["type"].String() + dataLength = m["length"].Int() + ) + if dataLength > 0 { + fieldType = fmt.Sprintf("%s(%d)", dataType, dataLength) + } else { + fieldType = dataType + } + fields[name] = &gdb.TableField{ Index: index, Name: name, - Type: m["type"].String(), + Type: fieldType, Null: !m["null"].Bool(), Key: m["key"].String(), Default: m["default_value"].Val(), diff --git a/contrib/drivers/pgsql/pgsql_z_unit_db_test.go b/contrib/drivers/pgsql/pgsql_z_unit_db_test.go index 67b9d7978..a83cb38bf 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_db_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_db_test.go @@ -90,7 +90,7 @@ func Test_DB_Save(t *testing.T) { "create_time": gtime.Now().String(), } _, err := db.Save(ctx, "t_user", data, 10) - gtest.AssertNE(err, nil) + gtest.AssertNil(err) }) } @@ -99,6 +99,7 @@ func Test_DB_Replace(t *testing.T) { createTable("t_user") defer dropTable("t_user") + // Insert initial record i := 10 data := g.Map{ "id": i, @@ -107,8 +108,26 @@ func Test_DB_Replace(t *testing.T) { "nickname": fmt.Sprintf(`T%d`, i), "create_time": gtime.Now().String(), } - _, err := db.Replace(ctx, "t_user", data, 10) - gtest.AssertNE(err, nil) + _, err := db.Insert(ctx, "t_user", data) + gtest.AssertNil(err) + + // Replace with new data + data2 := g.Map{ + "id": i, + "passport": fmt.Sprintf(`t%d_new`, i), + "password": fmt.Sprintf(`p%d_new`, i), + "nickname": fmt.Sprintf(`T%d_new`, i), + "create_time": gtime.Now().String(), + } + _, err = db.Replace(ctx, "t_user", data2) + gtest.AssertNil(err) + + // Verify the data was replaced + one, err := db.GetOne(ctx, fmt.Sprintf("SELECT * FROM t_user WHERE id=?"), i) + gtest.AssertNil(err) + gtest.Assert(one["passport"].String(), fmt.Sprintf(`t%d_new`, i)) + gtest.Assert(one["password"].String(), fmt.Sprintf(`p%d_new`, i)) + gtest.Assert(one["nickname"].String(), fmt.Sprintf(`T%d_new`, i)) }) } @@ -304,10 +323,10 @@ func Test_DB_TableFields(t *testing.T) { var expect = map[string][]any{ // []string: Index Type Null Key Default Comment // id is bigserial so the default is a pgsql function - "id": {0, "int8", false, "pri", fmt.Sprintf("nextval('%s_id_seq'::regclass)", table), ""}, - "passport": {1, "varchar", false, "", nil, ""}, - "password": {2, "varchar", false, "", nil, ""}, - "nickname": {3, "varchar", false, "", nil, ""}, + "id": {0, "int8(64)", false, "pri", fmt.Sprintf("nextval('%s_id_seq'::regclass)", table), ""}, + "passport": {1, "varchar(45)", false, "", nil, ""}, + "password": {2, "varchar(32)", false, "", nil, ""}, + "nickname": {3, "varchar(45)", false, "", nil, ""}, "create_time": {4, "timestamp", false, "", nil, ""}, } @@ -410,13 +429,13 @@ func Test_DB_TableFields_DuplicateConstraints(t *testing.T) { t.AssertNE(fields["id"], nil) t.Assert(fields["id"].Key, "pri") t.Assert(fields["id"].Name, "id") - t.Assert(fields["id"].Type, "int8") + t.Assert(fields["id"].Type, "int8(64)") // Verify email field has unique constraint t.AssertNE(fields["email"], nil) t.Assert(fields["email"].Key, "uni") t.Assert(fields["email"].Name, "email") - t.Assert(fields["email"].Type, "varchar") + t.Assert(fields["email"].Type, "varchar(100)") // Verify username field has no constraint t.AssertNE(fields["username"], nil) diff --git a/contrib/drivers/pgsql/pgsql_z_unit_field_test.go b/contrib/drivers/pgsql/pgsql_z_unit_field_test.go index 7c3df4ab0..22a8f25a2 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_field_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_field_test.go @@ -73,18 +73,18 @@ func Test_TableFields_Types(t *testing.T) { t.AssertNil(err) // Test integer type names - t.Assert(fields["col_int2"].Type, "int2") - t.Assert(fields["col_int4"].Type, "int4") - t.Assert(fields["col_int8"].Type, "int8") + t.Assert(fields["col_int2"].Type, "int2(16)") + t.Assert(fields["col_int4"].Type, "int4(32)") + t.Assert(fields["col_int8"].Type, "int8(64)") // Test float type names - t.Assert(fields["col_float4"].Type, "float4") - t.Assert(fields["col_float8"].Type, "float8") - t.Assert(fields["col_numeric"].Type, "numeric") + t.Assert(fields["col_float4"].Type, "float4(24)") + t.Assert(fields["col_float8"].Type, "float8(53)") + t.Assert(fields["col_numeric"].Type, "numeric(10)") // Test character type names - t.Assert(fields["col_char"].Type, "bpchar") - t.Assert(fields["col_varchar"].Type, "varchar") + t.Assert(fields["col_char"].Type, "bpchar(10)") + t.Assert(fields["col_varchar"].Type, "varchar(100)") t.Assert(fields["col_text"].Type, "text") // Test boolean type name diff --git a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go index 56db22a85..4ca729891 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_model_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_model_test.go @@ -334,14 +334,53 @@ func Test_Model_Replace(t *testing.T) { defer dropTable(table) gtest.C(t, func(t *gtest.T) { - _, err := db.Model(table).Data(g.Map{ + // Insert initial record + result, err := db.Model(table).Data(g.Map{ + "id": 1, + "passport": "t1", + "password": "pass1", + "nickname": "T1", + "create_time": "2018-10-24 10:00:00", + }).Insert() + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + // Replace with new data + result, err = db.Model(table).Data(g.Map{ "id": 1, "passport": "t11", "password": "25d55ad283aa400af464c76d713c07ad", "nickname": "T11", "create_time": "2018-10-24 10:00:00", }).Replace() - t.Assert(err, "Replace operation is not supported by pgsql driver") + t.AssertNil(err) + n, _ = result.RowsAffected() + t.Assert(n, 1) + + // Verify the data was replaced + one, err := db.Model(table).Where("id", 1).One() + t.AssertNil(err) + t.Assert(one["passport"].String(), "t11") + t.Assert(one["password"].String(), "25d55ad283aa400af464c76d713c07ad") + t.Assert(one["nickname"].String(), "T11") + + // Replace with new ID (insert new record) + result, err = db.Model(table).Data(g.Map{ + "id": 2, + "passport": "t22", + "password": "pass22", + "nickname": "T22", + "create_time": "2018-10-24 11:00:00", + }).Replace() + t.AssertNil(err) + n, _ = result.RowsAffected() + t.Assert(n, 1) + + // Verify new record was inserted + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, 2) }) } @@ -757,3 +796,69 @@ func Test_ConvertSliceFloat64(t *testing.T) { }) } } + +func Test_Model_InsertIgnore(t *testing.T) { + table := createTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + user := db.Model(table) + result, err := user.Data(g.Map{ + "id": 1, + "uid": 1, + "passport": "t1", + "password": "25d55ad283aa400af464c76d713c07ad", + "nickname": "name_1", + "create_time": gtime.Now().String(), + }).Insert() + t.AssertNil(err) + n, _ := result.RowsAffected() + t.Assert(n, 1) + + result, err = db.Model(table).Data(g.Map{ + "id": 1, + "uid": 1, + "passport": "t1", + "password": "25d55ad283aa400af464c76d713c07ad", + "nickname": "name_1", + "create_time": gtime.Now().String(), + }).Insert() + t.AssertNE(err, nil) + + result, err = db.Model(table).Data(g.Map{ + "id": 1, + "uid": 1, + "passport": "t2", + "password": "25d55ad283aa400af464c76d713c07ad", + "nickname": "name_2", + "create_time": gtime.Now().String(), + }).InsertIgnore() + t.AssertNil(err) + + n, _ = result.RowsAffected() + t.Assert(n, 0) + + value, err := db.Model(table).Fields("passport").WherePri(1).Value() + t.AssertNil(err) + t.Assert(value.String(), "t1") + + count, err := db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, 1) + + // pgsql support ignore without primary key + result, err = db.Model(table).Data(g.Map{ + // "id": 1, + "uid": 1, + "passport": "t2", + "password": "25d55ad283aa400af464c76d713c07ad", + "nickname": "name_2", + "create_time": gtime.Now().String(), + }).InsertIgnore() + t.AssertNil(err) + + count, err = db.Model(table).Count() + t.AssertNil(err) + t.Assert(count, 1) + }) +} diff --git a/contrib/drivers/pgsql/pgsql_z_unit_upsert_test.go b/contrib/drivers/pgsql/pgsql_z_unit_upsert_test.go index a93017e97..edefab632 100644 --- a/contrib/drivers/pgsql/pgsql_z_unit_upsert_test.go +++ b/contrib/drivers/pgsql/pgsql_z_unit_upsert_test.go @@ -219,10 +219,10 @@ func Test_FormatUpsert_NoOnConflict(t *testing.T) { }).Insert() t.AssertNil(err) - // Try Save without OnConflict - should fail for pgsql - // PostgreSQL requires OnConflict() for Save() operations, unlike MySQL + // Try Save without OnConflict and without primary key in data - should fail + // because driver cannot auto-detect conflict columns when primary key is missing _, err = db.Model(table).Data(g.Map{ - "id": 1, + // "id": 1, "passport": "no_conflict_user", "password": "newpwd", "nickname": "newnick", diff --git a/contrib/drivers/sqlitecgo/sqlite_format_upsert.go b/contrib/drivers/sqlitecgo/sqlitecgo_format_upsert.go similarity index 100% rename from contrib/drivers/sqlitecgo/sqlite_format_upsert.go rename to contrib/drivers/sqlitecgo/sqlitecgo_format_upsert.go diff --git a/database/gdb/gdb_core_utility.go b/database/gdb/gdb_core_utility.go index 4872f13a1..b97d7431e 100644 --- a/database/gdb/gdb_core_utility.go +++ b/database/gdb/gdb_core_utility.go @@ -10,6 +10,7 @@ package gdb import ( "context" "fmt" + "strings" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" @@ -251,3 +252,22 @@ func (c *Core) guessPrimaryTableName(tableStr string) string { } return guessedTableName } + +// GetPrimaryKeys retrieves and returns the primary key field names of the specified table. +// This method extracts primary key information from TableFields. +// The parameter `schema` is optional, if not specified it uses the default schema. +func (c *Core) GetPrimaryKeys(ctx context.Context, table string, schema ...string) ([]string, error) { + tableFields, err := c.db.TableFields(ctx, table, schema...) + if err != nil { + return nil, err + } + + var primaryKeys []string + for _, field := range tableFields { + if strings.EqualFold(field.Key, "pri") { + primaryKeys = append(primaryKeys, field.Name) + } + } + + return primaryKeys, nil +} diff --git a/database/gdb/gdb_driver_wrapper_db.go b/database/gdb/gdb_driver_wrapper_db.go index 7dbc1d0ce..81c5b729c 100644 --- a/database/gdb/gdb_driver_wrapper_db.go +++ b/database/gdb/gdb_driver_wrapper_db.go @@ -109,7 +109,17 @@ func (d *DriverWrapperDB) TableFields( // InsertOptionReplace: if there's unique/primary key in the data, it deletes it from table and inserts a new one; // InsertOptionSave: if there's unique/primary key in the data, it updates it or else inserts a new one; // InsertOptionIgnore: if there's unique/primary key in the data, it ignores the inserting; -func (d *DriverWrapperDB) DoInsert(ctx context.Context, link Link, table string, list List, option DoInsertOption) (result sql.Result, err error) { +func (d *DriverWrapperDB) DoInsert( + ctx context.Context, link Link, table string, list List, option DoInsertOption, +) (result sql.Result, err error) { + if len(list) == 0 { + return nil, gerror.NewCodef( + gcode.CodeInvalidRequest, + `data list is empty for %s operation`, + GetInsertOperationByOption(option.InsertOption), + ) + } + // Convert data type before commit it to underlying db driver. for i, item := range list { list[i], err = d.GetCore().ConvertDataForRecord(ctx, item, table)