add last insert id support for pgsql (#1994)

This commit is contained in:
HaiLaz
2022-08-09 19:45:05 +08:00
committed by GitHub
parent 4ded89d453
commit 95888e0b77
3 changed files with 161 additions and 6 deletions

View File

@ -7,7 +7,6 @@
// Note:
// 1. It needs manually import: _ "github.com/lib/pq"
// 2. It does not support Save/Replace features.
// 3. It does not support LastInsertId.
// Package pgsql implements gdb.Driver, which supports operations for PostgreSql.
package pgsql
@ -18,15 +17,15 @@ import (
"fmt"
"strings"
"github.com/gogf/gf/v2/util/gconv"
_ "github.com/lib/pq"
"github.com/gogf/gf/v2/container/gmap"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
_ "github.com/lib/pq"
)
// Driver is the driver for postgresql database.
@ -39,6 +38,10 @@ var (
tableFieldsMap = gmap.New(true)
)
const (
internalPrimaryKeyInCtx gctx.StrKey = "primary_key"
)
func init() {
if err := gdb.Register(`pgsql`, New()); err != nil {
panic(err)
@ -310,7 +313,111 @@ func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list
`Replace operation is not supported by pgsql driver`,
)
default:
return d.Core.DoInsert(ctx, link, table, list, option)
case gdb.InsertOptionIgnore:
return nil, gerror.NewCode(
gcode.CodeNotSupported,
`Insert ignore operation is not supported by pgsql driver`,
)
case gdb.InsertOptionDefault:
tableFields, err := d.TableFields(ctx, table)
if err == nil {
for _, field := range tableFields {
if field.Key == "pri" {
pkField := *field
ctx = context.WithValue(ctx, internalPrimaryKeyInCtx, pkField)
}
}
}
}
return d.Core.DoInsert(ctx, link, table, list, option)
}
func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...interface{}) (result sql.Result, err error) {
var (
isUseCoreDoExec bool = false // Check whether the default method needs to be used
primaryKey string = ""
pkField gdb.TableField
)
// Transaction checks.
if link == nil {
if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil {
isUseCoreDoExec = true
}
} else if link.IsTransaction() {
isUseCoreDoExec = true
}
if value := ctx.Value(internalPrimaryKeyInCtx); value != nil {
var ok bool
pkField, ok = value.(gdb.TableField)
if !ok {
isUseCoreDoExec = true
}
} else {
isUseCoreDoExec = true
}
// check if it is a insert operation.
if !isUseCoreDoExec && pkField.Name != "" && strings.Contains(sql, "INSERT INTO") {
primaryKey = pkField.Name
sql += " RETURNING " + primaryKey
} else {
// use default DoExec
return d.Core.DoExec(ctx, link, sql, args...)
}
// Only the insert operation with primary key can execute the following code
if d.GetConfig().ExecTimeout > 0 {
var cancelFunc context.CancelFunc
ctx, cancelFunc = context.WithTimeout(ctx, d.GetConfig().ExecTimeout)
defer cancelFunc()
}
// Sql filtering.
// TODO: internal function formatSql
// sql, args = formatSql(sql, args)
sql, args, err = d.DoFilter(ctx, link, sql, args)
if err != nil {
return nil, err
}
// Link execution.
var out gdb.DoCommitOutput
out, err = d.DoCommit(ctx, gdb.DoCommitInput{
Link: link,
Sql: sql,
Args: args,
Stmt: nil,
Type: gdb.SqlTypeQueryContext,
IsTransaction: link.IsTransaction(),
})
if err != nil {
return nil, err
}
affected := len(out.Records)
if affected > 0 {
if !strings.Contains(pkField.Type, "int") {
return Result{
affected: int64(affected),
lastInsertId: 0,
lastInsertIdError: gerror.NewCodef(
gcode.CodeNotSupported,
"LastInsertId is not supported by primary key type: %s", pkField.Type),
}, nil
}
if out.Records[affected-1][primaryKey] != nil {
lastInsertId := out.Records[affected-1][primaryKey].Int()
return Result{
affected: int64(affected),
lastInsertId: int64(lastInsertId),
}, nil
}
}
return Result{}, nil
}

View File

@ -0,0 +1,18 @@
package pgsql
import "database/sql"
type Result struct {
sql.Result
affected int64
lastInsertId int64
lastInsertIdError error
}
func (pgr Result) RowsAffected() (int64, error) {
return pgr.affected, nil
}
func (pgr Result) LastInsertId() (int64, error) {
return pgr.lastInsertId, pgr.lastInsertIdError
}

View File

@ -15,6 +15,36 @@ import (
"github.com/gogf/gf/v2/test/gtest"
)
func Test_LastInsertId(t *testing.T) {
// err not nil
gtest.C(t, func(t *gtest.T) {
_, err := db.Model("notexist").Insert(g.List{
{"name": "user1"},
{"name": "user2"},
{"name": "user3"},
})
t.AssertNE(err, nil)
})
gtest.C(t, func(t *gtest.T) {
tableName := createTable()
defer dropTable(tableName)
res, err := db.Model(tableName).Insert(g.List{
{"passport": "user1", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
{"passport": "user2", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
{"passport": "user3", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
})
t.Assert(err, nil)
lastInsertId, err := res.LastInsertId()
t.Assert(err, nil)
t.Assert(lastInsertId, int64(3))
rowsAffected, err := res.RowsAffected()
t.Assert(err, nil)
t.Assert(rowsAffected, int64(3))
})
}
func Test_Driver_DoFilter(t *testing.T) {
var (
ctx = gctx.New()