add hook feature for model of package gdb

This commit is contained in:
John Guo
2022-03-14 23:47:55 +08:00
parent d58186372f
commit dccfc1c8cd
13 changed files with 415 additions and 52 deletions

View File

@ -20,6 +20,7 @@ import (
"github.com/gogf/gf/v2/internal/intlog"
"github.com/gogf/gf/v2/os/gcache"
"github.com/gogf/gf/v2/os/gcmd"
"github.com/gogf/gf/v2/os/gctx"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/util/grand"
)
@ -215,9 +216,6 @@ type Driver interface {
// Link is a common database function wrapper interface.
type Link interface {
Query(sql string, args ...interface{}) (*sql.Rows, error)
Exec(sql string, args ...interface{}) (sql.Result, error)
Prepare(sql string) (*sql.Stmt, error)
QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error)
ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error)
PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error)
@ -287,9 +285,10 @@ const (
ctxTimeoutTypeExec = iota
ctxTimeoutTypeQuery
ctxTimeoutTypePrepare
commandEnvKeyForDryRun = "gf.gdb.dryrun"
modelForDaoSuffix = `ForDao`
dbRoleSlave = `slave`
commandEnvKeyForDryRun = "gf.gdb.dryrun"
modelForDaoSuffix = `ForDao`
dbRoleSlave = `slave`
contextKeyForDB gctx.StrKey = `DBInContext`
)
const (

View File

@ -45,7 +45,6 @@ func (c *Core) Ctx(ctx context.Context) DB {
configNode = c.db.GetConfig()
)
*newCore = *c
newCore.ctx = ctx
// It creates a new DB object, which is commonly a wrapper for object `Core`.
newCore.db, err = driverMap[configNode.Type].New(newCore, configNode)
if err != nil {
@ -53,6 +52,7 @@ func (c *Core) Ctx(ctx context.Context) DB {
// Do not let it continue.
panic(err)
}
newCore.ctx = WithDB(ctx, newCore.db)
return newCore.db
}

View File

@ -8,12 +8,39 @@
package gdb
import (
"context"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gregex"
"github.com/gogf/gf/v2/text/gstr"
)
// WithDB injects given db object into context and returns a new context.
func WithDB(ctx context.Context, db DB) context.Context {
if db == nil {
return ctx
}
dbCtx := db.GetCtx()
if ctxDb := DBFromCtx(dbCtx); ctxDb != nil {
return dbCtx
}
ctx = context.WithValue(ctx, contextKeyForDB, db)
return ctx
}
// DBFromCtx retrieves and returns DB object from context.
func DBFromCtx(ctx context.Context) DB {
if ctx == nil {
return nil
}
v := ctx.Value(contextKeyForDB)
if v != nil {
return v.(DB)
}
return nil
}
// MasterLink acts like function Master but with additional `schema` parameter specifying
// the schema for the connection. It is defined for internal usage.
// Also see Master.

View File

@ -44,6 +44,7 @@ type Model struct {
lockInfo string // Lock for update or in shared lock.
cacheEnabled bool // Enable sql result cache feature, which is mainly for indicating cache duration(especially 0) usage.
cacheOption CacheOption // Cache option for query statement.
hook HookHandler // Hook functions for model hook feature.
unscoped bool // Disables soft deleting features when select/delete operations.
safe bool // If true, it clones and returns a new model object whenever operation done; or else it changes the attribute of current model.
onDuplicate interface{} // onDuplicate is used for ON "DUPLICATE KEY UPDATE" statement.

View File

@ -34,18 +34,40 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) {
)
// Soft deleting.
if !m.unscoped && fieldNameDelete != "" {
return m.db.DoUpdate(
m.GetCtx(),
m.getLink(true),
m.tables,
fmt.Sprintf(`%s=?`, m.db.GetCore().QuoteString(fieldNameDelete)),
conditionWhere+conditionExtra,
append([]interface{}{gtime.Now().String()}, conditionArgs...),
)
in := &HookUpdateInput{
internalParamHookUpdate: internalParamHookUpdate{
internalParamHook: internalParamHook{
db: m.db,
link: m.getLink(true),
},
handler: m.hook.Update,
},
Table: m.tables,
Data: fmt.Sprintf(`%s=?`, m.db.GetCore().QuoteString(fieldNameDelete)),
Condition: conditionWhere + conditionExtra,
Args: append([]interface{}{gtime.Now().String()}, conditionArgs...),
}
return in.Next(m.GetCtx())
}
conditionStr := conditionWhere + conditionExtra
if !gstr.ContainsI(conditionStr, " WHERE ") {
return nil, gerror.NewCode(gcode.CodeMissingParameter, "there should be WHERE condition statement for DELETE operation")
return nil, gerror.NewCode(
gcode.CodeMissingParameter,
"there should be WHERE condition statement for DELETE operation",
)
}
return m.db.DoDelete(m.GetCtx(), m.getLink(true), m.tables, conditionStr, conditionArgs...)
in := &HookDeleteInput{
internalParamHookDelete: internalParamHookDelete{
internalParamHook: internalParamHook{
db: m.db,
link: m.getLink(true),
},
handler: m.hook.Delete,
},
Table: m.tables,
Condition: conditionStr,
Args: conditionArgs,
}
return in.Next(m.GetCtx())
}

View File

@ -0,0 +1,148 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package gdb
import (
"context"
"database/sql"
"github.com/gogf/gf/v2/text/gstr"
)
type (
HookFuncSelect func(ctx context.Context, in *HookSelectInput) (result Result, err error)
HookFuncInsert func(ctx context.Context, in *HookInsertInput) (result sql.Result, err error)
HookFuncUpdate func(ctx context.Context, in *HookUpdateInput) (result sql.Result, err error)
HookFuncDelete func(ctx context.Context, in *HookDeleteInput) (result sql.Result, err error)
)
// HookHandler manages all supported hook functions for Model.
type HookHandler struct {
Select HookFuncSelect
Insert HookFuncInsert
Update HookFuncUpdate
Delete HookFuncDelete
}
// internalParamHook manages all internal parameters for hook operations.
// The `internal` obviously means you cannot access these parameters outside this package.
type internalParamHook struct {
db DB // Underlying DB object.
link Link // Connection object from third party sql driver.
handlerCalled bool // Simple mark for custom handler called, in case of recursive calling.
removedWhere bool // Removed mark for condition string that was removed `WHERE` prefix.
}
type internalParamHookSelect struct {
internalParamHook
handler HookFuncSelect
}
type internalParamHookInsert struct {
internalParamHook
handler HookFuncInsert
}
type internalParamHookUpdate struct {
internalParamHook
handler HookFuncUpdate
}
type internalParamHookDelete struct {
internalParamHook
handler HookFuncDelete
}
// HookSelectInput holds the parameters for select hook operation.
type HookSelectInput struct {
internalParamHookSelect
Table string
Sql string
Args []interface{}
}
// HookInsertInput holds the parameters for insert hook operation.
type HookInsertInput struct {
internalParamHookInsert
Table string
Data List
Option DoInsertOption
}
// HookUpdateInput holds the parameters for update hook operation.
type HookUpdateInput struct {
internalParamHookUpdate
Table string
Data interface{}
Condition string
Args []interface{}
}
// HookDeleteInput holds the parameters for delete hook operation.
type HookDeleteInput struct {
internalParamHookDelete
Table string
Condition string
Args []interface{}
}
// Next calls the next hook handler.
func (h *HookSelectInput) Next(ctx context.Context) (result Result, err error) {
if h.handler != nil && !h.handlerCalled {
h.handlerCalled = true
return h.handler(ctx, h)
}
return h.db.DoSelect(ctx, h.link, h.Sql, h.Args...)
}
// Next calls the next hook handler.
func (h *HookInsertInput) Next(ctx context.Context) (result sql.Result, err error) {
if h.handler != nil && !h.handlerCalled {
h.handlerCalled = true
return h.handler(ctx, h)
}
return h.db.DoInsert(ctx, h.link, h.Table, h.Data, h.Option)
}
// Next calls the next hook handler.
func (h *HookUpdateInput) Next(ctx context.Context) (result sql.Result, err error) {
if h.handler != nil && !h.handlerCalled {
h.handlerCalled = true
if gstr.HasPrefix(h.Condition, " WHERE ") {
h.removedWhere = true
h.Condition = gstr.TrimLeftStr(h.Condition, " WHERE ")
}
return h.handler(ctx, h)
}
if h.removedWhere {
h.Condition = " WHERE " + h.Condition
}
return h.db.DoUpdate(ctx, h.link, h.Table, h.Data, h.Condition, h.Args...)
}
// Next calls the next hook handler.
func (h *HookDeleteInput) Next(ctx context.Context) (result sql.Result, err error) {
if h.handler != nil && !h.handlerCalled {
h.handlerCalled = true
if gstr.HasPrefix(h.Condition, " WHERE ") {
h.removedWhere = true
h.Condition = gstr.TrimLeftStr(h.Condition, " WHERE ")
}
return h.handler(ctx, h)
}
if h.removedWhere {
h.Condition = " WHERE " + h.Condition
}
return h.db.DoDelete(ctx, h.link, h.Table, h.Condition, h.Args...)
}
// Hook sets the hook functions for current model.
func (m *Model) Hook(hook HookHandler) *Model {
model := m.getModel()
model.hook = hook
return model
}

View File

@ -310,7 +310,20 @@ func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err err
if err != nil {
return result, err
}
return m.db.DoInsert(m.GetCtx(), m.getLink(true), m.tables, list, doInsertOption)
in := &HookInsertInput{
internalParamHookInsert: internalParamHookInsert{
internalParamHook: internalParamHook{
db: m.db,
link: m.getLink(true),
},
handler: m.hook.Insert,
},
Table: m.tables,
Data: list,
Option: doInsertOption,
}
return in.Next(m.GetCtx())
}
func (m *Model) formatDoInsertOption(insertOption int, columnNames []string) (option DoInsertOption, err error) {

View File

@ -532,9 +532,21 @@ func (m *Model) doGetAllBySql(sql string, args ...interface{}) (result Result, e
}
}
}
result, err = m.db.DoSelect(
m.GetCtx(), m.getLink(false), sql, m.mergeArguments(args)...,
)
in := &HookSelectInput{
internalParamHookSelect: internalParamHookSelect{
internalParamHook: internalParamHook{
db: m.db,
link: m.getLink(false),
},
handler: m.hook.Select,
},
Table: m.tables,
Sql: sql,
Args: m.mergeArguments(args),
}
result, err = in.Next(m.GetCtx())
// Cache the result.
if cacheKey != "" && err == nil {
if m.cacheOption.Duration < 0 {

View File

@ -74,14 +74,21 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro
if !gstr.ContainsI(conditionStr, " WHERE ") {
return nil, gerror.NewCode(gcode.CodeMissingParameter, "there should be WHERE condition statement for UPDATE operation")
}
return m.db.DoUpdate(
m.GetCtx(),
m.getLink(true),
m.tables,
newData,
conditionStr,
m.mergeArguments(conditionArgs)...,
)
in := &HookUpdateInput{
internalParamHookUpdate: internalParamHookUpdate{
internalParamHook: internalParamHook{
db: m.db,
link: m.getLink(true),
},
handler: m.hook.Update,
},
Table: m.tables,
Data: newData,
Condition: conditionStr,
Args: m.mergeArguments(conditionArgs),
}
return in.Next(m.GetCtx())
}
// Increment increments a column's value by a given amount.

View File

@ -143,7 +143,7 @@ func (m *Model) doWithScanStruct(pointer interface{}) error {
}
// Recursively with feature checks.
model = m.db.With(field.Value)
model = m.db.With(field.Value).Hook(m.hook)
if m.withAll {
model = model.WithAll()
} else {
@ -258,7 +258,7 @@ func (m *Model) doWithScanStructs(pointer interface{}) error {
fieldKeys = structType.FieldKeys()
}
// Recursively with feature checks.
model = m.db.With(field.Value)
model = m.db.With(field.Value).Hook(m.hook)
if m.withAll {
model = model.WithAll()
} else {

View File

@ -0,0 +1,136 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package gdb_test
import (
"context"
"database/sql"
"fmt"
"testing"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/test/gtest"
)
func Test_Model_Hook_Select(t *testing.T) {
table := createInitTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
m := db.Model(table).Hook(gdb.HookHandler{
Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) {
result, err = in.Next(ctx)
if err != nil {
return
}
for i, record := range result {
record["test"] = gvar.New(100 + record["id"].Int())
result[i] = record
}
return
},
})
all, err := m.Where(`id > 6`).OrderAsc(`id`).All()
t.AssertNil(err)
t.Assert(len(all), 4)
t.Assert(all[0]["id"].Int(), 7)
t.Assert(all[0]["test"].Int(), 107)
t.Assert(all[1]["test"].Int(), 108)
t.Assert(all[2]["test"].Int(), 109)
t.Assert(all[3]["test"].Int(), 110)
})
}
func Test_Model_Hook_Insert(t *testing.T) {
table := createTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
m := db.Model(table).Hook(gdb.HookHandler{
Insert: func(ctx context.Context, in *gdb.HookInsertInput) (result sql.Result, err error) {
for i, item := range in.Data {
item["passport"] = fmt.Sprintf(`test_port_%d`, item["id"])
item["nickname"] = fmt.Sprintf(`test_name_%d`, item["id"])
in.Data[i] = item
}
return in.Next(ctx)
},
})
_, err := m.Insert(g.Map{
"id": 1,
"nickname": "name_1",
})
t.AssertNil(err)
one, err := m.One()
t.AssertNil(err)
t.Assert(one["id"].Int(), 1)
t.Assert(one["passport"], `test_port_1`)
t.Assert(one["nickname"], `test_name_1`)
})
}
func Test_Model_Hook_Update(t *testing.T) {
table := createInitTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
m := db.Model(table).Hook(gdb.HookHandler{
Update: func(ctx context.Context, in *gdb.HookUpdateInput) (result sql.Result, err error) {
switch value := in.Data.(type) {
case gdb.List:
for i, data := range value {
data["passport"] = `port`
data["nickname"] = `name`
value[i] = data
}
in.Data = value
case gdb.Map:
value["passport"] = `port`
value["nickname"] = `name`
in.Data = value
}
return in.Next(ctx)
},
})
_, err := m.Data(g.Map{
"nickname": "name_1",
}).WherePri(1).Update()
t.AssertNil(err)
one, err := m.One()
t.AssertNil(err)
t.Assert(one["id"].Int(), 1)
t.Assert(one["passport"], `port`)
t.Assert(one["nickname"], `name`)
})
}
func Test_Model_Hook_Delete(t *testing.T) {
table := createInitTable()
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
m := db.Model(table).Hook(gdb.HookHandler{
Delete: func(ctx context.Context, in *gdb.HookDeleteInput) (result sql.Result, err error) {
return db.Model(table).Data(g.Map{
"nickname": `deleted`,
}).Where(in.Condition).Update()
},
})
_, err := m.Where(1).Delete()
t.AssertNil(err)
all, err := m.All()
t.AssertNil(err)
for _, item := range all {
t.Assert(item["nickname"].String(), `deleted`)
}
})
}

View File

@ -117,7 +117,6 @@ func (l *Logger) print(ctx context.Context, level int, values ...interface{}) {
input = &HandlerInput{
Logger: l,
Buffer: bytes.NewBuffer(nil),
Ctx: ctx,
Time: now,
Color: defaultLevelColor[level],
Level: level,
@ -221,13 +220,13 @@ func (l *Logger) print(ctx context.Context, level int, values ...interface{}) {
if l.config.Flags&F_ASYNC > 0 {
input.IsAsync = true
err := asyncPool.Add(ctx, func(ctx context.Context) {
input.Next()
input.Next(ctx)
})
if err != nil {
intlog.Errorf(ctx, `%+v`, err)
}
} else {
input.Next()
input.Next(ctx)
}
}

View File

@ -17,30 +17,29 @@ type Handler func(ctx context.Context, in *HandlerInput)
// HandlerInput is the input parameter struct for logging Handler.
type HandlerInput struct {
Logger *Logger // Logger.
Ctx context.Context // Context.
Buffer *bytes.Buffer // Buffer for logging content outputs.
Time time.Time // Logging time, which is the time that logging triggers.
TimeFormat string // Formatted time string, like "2016-01-09 12:00:00".
Color int // Using color, like COLOR_RED, COLOR_BLUE, etc.
Level int // Using level, like LEVEL_INFO, LEVEL_ERRO, etc.
LevelFormat string // Formatted level string, like "DEBU", "ERRO", etc.
CallerFunc string // The source function name that calls logging.
CallerPath string // The source file path and its line number that calls logging.
CtxStr string // The retrieved context value string from context.
Prefix string // Custom prefix string for logging content.
Content string // Content is the main logging content that passed by you.
IsAsync bool // IsAsync marks it is in asynchronous logging.
handlerIndex int // Middleware handling index for internal usage.
Logger *Logger // Logger.
Buffer *bytes.Buffer // Buffer for logging content outputs.
Time time.Time // Logging time, which is the time that logging triggers.
TimeFormat string // Formatted time string, like "2016-01-09 12:00:00".
Color int // Using color, like COLOR_RED, COLOR_BLUE, etc.
Level int // Using level, like LEVEL_INFO, LEVEL_ERRO, etc.
LevelFormat string // Formatted level string, like "DEBU", "ERRO", etc.
CallerFunc string // The source function name that calls logging.
CallerPath string // The source file path and its line number that calls logging.
CtxStr string // The retrieved context value string from context.
Prefix string // Custom prefix string for logging content.
Content string // Content is the main logging content that passed by you.
IsAsync bool // IsAsync marks it is in asynchronous logging.
handlerIndex int // Middleware handling index for internal usage.
}
// Next calls the next logging handler in middleware way.
func (i *HandlerInput) Next() {
func (i *HandlerInput) Next(ctx context.Context) {
if len(i.Logger.config.Handlers)-1 > i.handlerIndex {
i.handlerIndex++
i.Logger.config.Handlers[i.handlerIndex](i.Ctx, i)
i.Logger.config.Handlers[i.handlerIndex](ctx, i)
} else {
defaultHandler(i.Ctx, i)
defaultHandler(ctx, i)
}
}