mirror of
https://gitee.com/johng/gf
synced 2026-06-06 16:21:40 +08:00
add last insert id support for pgsql (#1994)
This commit is contained in:
@ -7,7 +7,6 @@
|
||||
// Note:
|
||||
// 1. It needs manually import: _ "github.com/lib/pq"
|
||||
// 2. It does not support Save/Replace features.
|
||||
// 3. It does not support LastInsertId.
|
||||
|
||||
// Package pgsql implements gdb.Driver, which supports operations for PostgreSql.
|
||||
package pgsql
|
||||
@ -18,15 +17,15 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gmap"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/errors/gcode"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/os/gctx"
|
||||
"github.com/gogf/gf/v2/text/gregex"
|
||||
"github.com/gogf/gf/v2/text/gstr"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// Driver is the driver for postgresql database.
|
||||
@ -39,6 +38,10 @@ var (
|
||||
tableFieldsMap = gmap.New(true)
|
||||
)
|
||||
|
||||
const (
|
||||
internalPrimaryKeyInCtx gctx.StrKey = "primary_key"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if err := gdb.Register(`pgsql`, New()); err != nil {
|
||||
panic(err)
|
||||
@ -310,7 +313,111 @@ func (d *Driver) DoInsert(ctx context.Context, link gdb.Link, table string, list
|
||||
`Replace operation is not supported by pgsql driver`,
|
||||
)
|
||||
|
||||
default:
|
||||
return d.Core.DoInsert(ctx, link, table, list, option)
|
||||
case gdb.InsertOptionIgnore:
|
||||
return nil, gerror.NewCode(
|
||||
gcode.CodeNotSupported,
|
||||
`Insert ignore operation is not supported by pgsql driver`,
|
||||
)
|
||||
|
||||
case gdb.InsertOptionDefault:
|
||||
tableFields, err := d.TableFields(ctx, table)
|
||||
if err == nil {
|
||||
for _, field := range tableFields {
|
||||
if field.Key == "pri" {
|
||||
pkField := *field
|
||||
ctx = context.WithValue(ctx, internalPrimaryKeyInCtx, pkField)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return d.Core.DoInsert(ctx, link, table, list, option)
|
||||
}
|
||||
|
||||
func (d *Driver) DoExec(ctx context.Context, link gdb.Link, sql string, args ...interface{}) (result sql.Result, err error) {
|
||||
var (
|
||||
isUseCoreDoExec bool = false // Check whether the default method needs to be used
|
||||
primaryKey string = ""
|
||||
pkField gdb.TableField
|
||||
)
|
||||
|
||||
// Transaction checks.
|
||||
if link == nil {
|
||||
if tx := gdb.TXFromCtx(ctx, d.GetGroup()); tx != nil {
|
||||
isUseCoreDoExec = true
|
||||
}
|
||||
} else if link.IsTransaction() {
|
||||
isUseCoreDoExec = true
|
||||
}
|
||||
|
||||
if value := ctx.Value(internalPrimaryKeyInCtx); value != nil {
|
||||
var ok bool
|
||||
pkField, ok = value.(gdb.TableField)
|
||||
if !ok {
|
||||
isUseCoreDoExec = true
|
||||
}
|
||||
} else {
|
||||
isUseCoreDoExec = true
|
||||
}
|
||||
|
||||
// check if it is a insert operation.
|
||||
if !isUseCoreDoExec && pkField.Name != "" && strings.Contains(sql, "INSERT INTO") {
|
||||
primaryKey = pkField.Name
|
||||
sql += " RETURNING " + primaryKey
|
||||
} else {
|
||||
// use default DoExec
|
||||
return d.Core.DoExec(ctx, link, sql, args...)
|
||||
}
|
||||
|
||||
// Only the insert operation with primary key can execute the following code
|
||||
|
||||
if d.GetConfig().ExecTimeout > 0 {
|
||||
var cancelFunc context.CancelFunc
|
||||
ctx, cancelFunc = context.WithTimeout(ctx, d.GetConfig().ExecTimeout)
|
||||
defer cancelFunc()
|
||||
}
|
||||
|
||||
// Sql filtering.
|
||||
// TODO: internal function formatSql
|
||||
// sql, args = formatSql(sql, args)
|
||||
sql, args, err = d.DoFilter(ctx, link, sql, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Link execution.
|
||||
var out gdb.DoCommitOutput
|
||||
out, err = d.DoCommit(ctx, gdb.DoCommitInput{
|
||||
Link: link,
|
||||
Sql: sql,
|
||||
Args: args,
|
||||
Stmt: nil,
|
||||
Type: gdb.SqlTypeQueryContext,
|
||||
IsTransaction: link.IsTransaction(),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
affected := len(out.Records)
|
||||
if affected > 0 {
|
||||
if !strings.Contains(pkField.Type, "int") {
|
||||
return Result{
|
||||
affected: int64(affected),
|
||||
lastInsertId: 0,
|
||||
lastInsertIdError: gerror.NewCodef(
|
||||
gcode.CodeNotSupported,
|
||||
"LastInsertId is not supported by primary key type: %s", pkField.Type),
|
||||
}, nil
|
||||
}
|
||||
|
||||
if out.Records[affected-1][primaryKey] != nil {
|
||||
lastInsertId := out.Records[affected-1][primaryKey].Int()
|
||||
return Result{
|
||||
affected: int64(affected),
|
||||
lastInsertId: int64(lastInsertId),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return Result{}, nil
|
||||
}
|
||||
|
||||
18
contrib/drivers/pgsql/pgsql_result.go
Normal file
18
contrib/drivers/pgsql/pgsql_result.go
Normal file
@ -0,0 +1,18 @@
|
||||
package pgsql
|
||||
|
||||
import "database/sql"
|
||||
|
||||
type Result struct {
|
||||
sql.Result
|
||||
affected int64
|
||||
lastInsertId int64
|
||||
lastInsertIdError error
|
||||
}
|
||||
|
||||
func (pgr Result) RowsAffected() (int64, error) {
|
||||
return pgr.affected, nil
|
||||
}
|
||||
|
||||
func (pgr Result) LastInsertId() (int64, error) {
|
||||
return pgr.lastInsertId, pgr.lastInsertIdError
|
||||
}
|
||||
@ -15,6 +15,36 @@ import (
|
||||
"github.com/gogf/gf/v2/test/gtest"
|
||||
)
|
||||
|
||||
func Test_LastInsertId(t *testing.T) {
|
||||
|
||||
// err not nil
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
_, err := db.Model("notexist").Insert(g.List{
|
||||
{"name": "user1"},
|
||||
{"name": "user2"},
|
||||
{"name": "user3"},
|
||||
})
|
||||
t.AssertNE(err, nil)
|
||||
})
|
||||
|
||||
gtest.C(t, func(t *gtest.T) {
|
||||
tableName := createTable()
|
||||
defer dropTable(tableName)
|
||||
res, err := db.Model(tableName).Insert(g.List{
|
||||
{"passport": "user1", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
|
||||
{"passport": "user2", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
|
||||
{"passport": "user3", "password": "pwd", "nickname": "nickname", "create_time": CreateTime},
|
||||
})
|
||||
t.Assert(err, nil)
|
||||
lastInsertId, err := res.LastInsertId()
|
||||
t.Assert(err, nil)
|
||||
t.Assert(lastInsertId, int64(3))
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
t.Assert(err, nil)
|
||||
t.Assert(rowsAffected, int64(3))
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Driver_DoFilter(t *testing.T) {
|
||||
var (
|
||||
ctx = gctx.New()
|
||||
|
||||
Reference in New Issue
Block a user