diff --git a/contrib/drivers/pgsql/pgsql.go b/contrib/drivers/pgsql/pgsql.go index 11552ebf2..3befe50d0 100644 --- a/contrib/drivers/pgsql/pgsql.go +++ b/contrib/drivers/pgsql/pgsql.go @@ -7,7 +7,6 @@ // Note: // 1. It needs manually import: _ "github.com/lib/pq" // 2. It does not support Save/Replace features. -// 3. It does not support LastInsertId. // Package pgsql implements gdb.Driver, which supports operations for PostgreSql. package pgsql @@ -18,15 +17,15 @@ import ( "fmt" "strings" - "github.com/gogf/gf/v2/util/gconv" - _ "github.com/lib/pq" - "github.com/gogf/gf/v2/container/gmap" "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" + "github.com/gogf/gf/v2/os/gctx" "github.com/gogf/gf/v2/text/gregex" "github.com/gogf/gf/v2/text/gstr" + "github.com/gogf/gf/v2/util/gconv" + _ "github.com/lib/pq" ) // Driver is the driver for postgresql database. @@ -39,6 +38,10 @@ var ( tableFieldsMap = gmap.New(true) ) +const ( + internalPrimaryKeyInCtx gctx.StrKey = "primary_key" +) + func init() { if err := gdb.Register(`pgsql`, New()); err != nil { panic(err) @@ -310,7 +313,111 @@ func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list `Replace operation is not supported by pgsql driver`, ) - default: - return d.Core.DoInsert(ctx, link, table, list, option) + case gdb.InsertOptionIgnore: + return nil, gerror.NewCode( + gcode.CodeNotSupported, + `Insert ignore operation is not supported by pgsql driver`, + ) + + case gdb.InsertOptionDefault: + tableFields, err := d.TableFields(ctx, table) + if err == nil { + for _, field := range tableFields { + if field.Key == "pri" { + pkField := *field + ctx = context.WithValue(ctx, internalPrimaryKeyInCtx, pkField) + } + } + } } + return d.Core.DoInsert(ctx, link, table, list, option) +} + +func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...interface{}) (result sql.Result, err error) { + var ( + isUseCoreDoExec bool = false // Check whether the default method needs to be used + primaryKey string = "" + pkField gdb.TableField + ) + + // Transaction checks. + if link == nil { + if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil { + isUseCoreDoExec = true + } + } else if link.IsTransaction() { + isUseCoreDoExec = true + } + + if value := ctx.Value(internalPrimaryKeyInCtx); value != nil { + var ok bool + pkField, ok = value.(gdb.TableField) + if !ok { + isUseCoreDoExec = true + } + } else { + isUseCoreDoExec = true + } + + // check if it is a insert operation. + if !isUseCoreDoExec && pkField.Name != "" && strings.Contains(sql, "INSERT INTO") { + primaryKey = pkField.Name + sql += " RETURNING " + primaryKey + } else { + // use default DoExec + return d.Core.DoExec(ctx, link, sql, args...) + } + + // Only the insert operation with primary key can execute the following code + + if d.GetConfig().ExecTimeout > 0 { + var cancelFunc context.CancelFunc + ctx, cancelFunc = context.WithTimeout(ctx, d.GetConfig().ExecTimeout) + defer cancelFunc() + } + + // Sql filtering. + // TODO: internal function formatSql + // sql, args = formatSql(sql, args) + sql, args, err = d.DoFilter(ctx, link, sql, args) + if err != nil { + return nil, err + } + + // Link execution. + var out gdb.DoCommitOutput + out, err = d.DoCommit(ctx, gdb.DoCommitInput{ + Link: link, + Sql: sql, + Args: args, + Stmt: nil, + Type: gdb.SqlTypeQueryContext, + IsTransaction: link.IsTransaction(), + }) + + if err != nil { + return nil, err + } + affected := len(out.Records) + if affected > 0 { + if !strings.Contains(pkField.Type, "int") { + return Result{ + affected: int64(affected), + lastInsertId: 0, + lastInsertIdError: gerror.NewCodef( + gcode.CodeNotSupported, + "LastInsertId is not supported by primary key type: %s", pkField.Type), + }, nil + } + + if out.Records[affected-1][primaryKey] != nil { + lastInsertId := out.Records[affected-1][primaryKey].Int() + return Result{ + affected: int64(affected), + lastInsertId: int64(lastInsertId), + }, nil + } + } + + return Result{}, nil } diff --git a/contrib/drivers/pgsql/pgsql_result.go b/contrib/drivers/pgsql/pgsql_result.go new file mode 100644 index 000000000..a290d45f9 --- /dev/null +++ b/contrib/drivers/pgsql/pgsql_result.go @@ -0,0 +1,18 @@ +package pgsql + +import "database/sql" + +type Result struct { + sql.Result + affected int64 + lastInsertId int64 + lastInsertIdError error +} + +func (pgr Result) RowsAffected() (int64, error) { + return pgr.affected, nil +} + +func (pgr Result) LastInsertId() (int64, error) { + return pgr.lastInsertId, pgr.lastInsertIdError +} diff --git a/contrib/drivers/pgsql/pgsql_test.go b/contrib/drivers/pgsql/pgsql_test.go index 481d666cf..40a5f329f 100644 --- a/contrib/drivers/pgsql/pgsql_test.go +++ b/contrib/drivers/pgsql/pgsql_test.go @@ -15,6 +15,36 @@ import ( "github.com/gogf/gf/v2/test/gtest" ) +func Test_LastInsertId(t *testing.T) { + + // err not nil + gtest.C(t, func(t *gtest.T) { + _, err := db.Model("notexist").Insert(g.List{ + {"name": "user1"}, + {"name": "user2"}, + {"name": "user3"}, + }) + t.AssertNE(err, nil) + }) + + gtest.C(t, func(t *gtest.T) { + tableName := createTable() + defer dropTable(tableName) + res, err := db.Model(tableName).Insert(g.List{ + {"passport": "user1", "password": "pwd", "nickname": "nickname", "create_time": CreateTime}, + {"passport": "user2", "password": "pwd", "nickname": "nickname", "create_time": CreateTime}, + {"passport": "user3", "password": "pwd", "nickname": "nickname", "create_time": CreateTime}, + }) + t.Assert(err, nil) + lastInsertId, err := res.LastInsertId() + t.Assert(err, nil) + t.Assert(lastInsertId, int64(3)) + rowsAffected, err := res.RowsAffected() + t.Assert(err, nil) + t.Assert(rowsAffected, int64(3)) + }) +} + func Test_Driver_DoFilter(t *testing.T) { var ( ctx = gctx.New()