From dccfc1c8cd2a4564dab053ebea2a8e92f5929ae8 Mon Sep 17 00:00:00 2001 From: John Guo Date: Mon, 14 Mar 2022 23:47:55 +0800 Subject: [PATCH] add hook feature for model of package gdb --- database/gdb/gdb.go | 11 +- database/gdb/gdb_core.go | 2 +- database/gdb/gdb_core_utility.go | 27 ++++ database/gdb/gdb_model.go | 1 + database/gdb/gdb_model_delete.go | 42 +++-- database/gdb/gdb_model_hook.go | 148 ++++++++++++++++++ database/gdb/gdb_model_insert.go | 15 +- database/gdb/gdb_model_select.go | 18 ++- database/gdb/gdb_model_update.go | 23 ++- database/gdb/gdb_model_with.go | 4 +- database/gdb/gdb_z_mysql_feature_hook_test.go | 136 ++++++++++++++++ os/glog/glog_logger.go | 5 +- os/glog/glog_logger_handler.go | 35 ++--- 13 files changed, 415 insertions(+), 52 deletions(-) create mode 100644 database/gdb/gdb_model_hook.go create mode 100644 database/gdb/gdb_z_mysql_feature_hook_test.go diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 2b9034325..a112d16cd 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -20,6 +20,7 @@ import ( "github.com/gogf/gf/v2/internal/intlog" "github.com/gogf/gf/v2/os/gcache" "github.com/gogf/gf/v2/os/gcmd" + "github.com/gogf/gf/v2/os/gctx" "github.com/gogf/gf/v2/os/glog" "github.com/gogf/gf/v2/util/grand" ) @@ -215,9 +216,6 @@ type Driver interface { // Link is a common database function wrapper interface. type Link interface { - Query(sql string, args ...interface{}) (*sql.Rows, error) - Exec(sql string, args ...interface{}) (sql.Result, error) - Prepare(sql string) (*sql.Stmt, error) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) @@ -287,9 +285,10 @@ const ( ctxTimeoutTypeExec = iota ctxTimeoutTypeQuery ctxTimeoutTypePrepare - commandEnvKeyForDryRun = "gf.gdb.dryrun" - modelForDaoSuffix = `ForDao` - dbRoleSlave = `slave` + commandEnvKeyForDryRun = "gf.gdb.dryrun" + modelForDaoSuffix = `ForDao` + dbRoleSlave = `slave` + contextKeyForDB gctx.StrKey = `DBInContext` ) const ( diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 39d3b4794..1bb335d9f 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -45,7 +45,6 @@ func (c *Core) Ctx(ctx context.Context) DB { configNode = c.db.GetConfig() ) *newCore = *c - newCore.ctx = ctx // It creates a new DB object, which is commonly a wrapper for object `Core`. newCore.db, err = driverMap[configNode.Type].New(newCore, configNode) if err != nil { @@ -53,6 +52,7 @@ func (c *Core) Ctx(ctx context.Context) DB { // Do not let it continue. panic(err) } + newCore.ctx = WithDB(ctx, newCore.db) return newCore.db } diff --git a/database/gdb/gdb_core_utility.go b/database/gdb/gdb_core_utility.go index 0e0ff4403..9dd3beedd 100644 --- a/database/gdb/gdb_core_utility.go +++ b/database/gdb/gdb_core_utility.go @@ -8,12 +8,39 @@ package gdb import ( + "context" + "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/text/gregex" "github.com/gogf/gf/v2/text/gstr" ) +// WithDB injects given db object into context and returns a new context. +func WithDB(ctx context.Context, db DB) context.Context { + if db == nil { + return ctx + } + dbCtx := db.GetCtx() + if ctxDb := DBFromCtx(dbCtx); ctxDb != nil { + return dbCtx + } + ctx = context.WithValue(ctx, contextKeyForDB, db) + return ctx +} + +// DBFromCtx retrieves and returns DB object from context. +func DBFromCtx(ctx context.Context) DB { + if ctx == nil { + return nil + } + v := ctx.Value(contextKeyForDB) + if v != nil { + return v.(DB) + } + return nil +} + // MasterLink acts like function Master but with additional `schema` parameter specifying // the schema for the connection. It is defined for internal usage. // Also see Master. diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index b1c1e8196..94b92a8d9 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -44,6 +44,7 @@ type Model struct { lockInfo string // Lock for update or in shared lock. cacheEnabled bool // Enable sql result cache feature, which is mainly for indicating cache duration(especially 0) usage. cacheOption CacheOption // Cache option for query statement. + hook HookHandler // Hook functions for model hook feature. unscoped bool // Disables soft deleting features when select/delete operations. safe bool // If true, it clones and returns a new model object whenever operation done; or else it changes the attribute of current model. onDuplicate interface{} // onDuplicate is used for ON "DUPLICATE KEY UPDATE" statement. diff --git a/database/gdb/gdb_model_delete.go b/database/gdb/gdb_model_delete.go index 334338f6f..8b27cd945 100644 --- a/database/gdb/gdb_model_delete.go +++ b/database/gdb/gdb_model_delete.go @@ -34,18 +34,40 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) { ) // Soft deleting. if !m.unscoped && fieldNameDelete != "" { - return m.db.DoUpdate( - m.GetCtx(), - m.getLink(true), - m.tables, - fmt.Sprintf(`%s=?`, m.db.GetCore().QuoteString(fieldNameDelete)), - conditionWhere+conditionExtra, - append([]interface{}{gtime.Now().String()}, conditionArgs...), - ) + in := &HookUpdateInput{ + internalParamHookUpdate: internalParamHookUpdate{ + internalParamHook: internalParamHook{ + db: m.db, + link: m.getLink(true), + }, + handler: m.hook.Update, + }, + Table: m.tables, + Data: fmt.Sprintf(`%s=?`, m.db.GetCore().QuoteString(fieldNameDelete)), + Condition: conditionWhere + conditionExtra, + Args: append([]interface{}{gtime.Now().String()}, conditionArgs...), + } + return in.Next(m.GetCtx()) } conditionStr := conditionWhere + conditionExtra if !gstr.ContainsI(conditionStr, " WHERE ") { - return nil, gerror.NewCode(gcode.CodeMissingParameter, "there should be WHERE condition statement for DELETE operation") + return nil, gerror.NewCode( + gcode.CodeMissingParameter, + "there should be WHERE condition statement for DELETE operation", + ) } - return m.db.DoDelete(m.GetCtx(), m.getLink(true), m.tables, conditionStr, conditionArgs...) + + in := &HookDeleteInput{ + internalParamHookDelete: internalParamHookDelete{ + internalParamHook: internalParamHook{ + db: m.db, + link: m.getLink(true), + }, + handler: m.hook.Delete, + }, + Table: m.tables, + Condition: conditionStr, + Args: conditionArgs, + } + return in.Next(m.GetCtx()) } diff --git a/database/gdb/gdb_model_hook.go b/database/gdb/gdb_model_hook.go new file mode 100644 index 000000000..ca2702c40 --- /dev/null +++ b/database/gdb/gdb_model_hook.go @@ -0,0 +1,148 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gdb + +import ( + "context" + "database/sql" + + "github.com/gogf/gf/v2/text/gstr" +) + +type ( + HookFuncSelect func(ctx context.Context, in *HookSelectInput) (result Result, err error) + HookFuncInsert func(ctx context.Context, in *HookInsertInput) (result sql.Result, err error) + HookFuncUpdate func(ctx context.Context, in *HookUpdateInput) (result sql.Result, err error) + HookFuncDelete func(ctx context.Context, in *HookDeleteInput) (result sql.Result, err error) +) + +// HookHandler manages all supported hook functions for Model. +type HookHandler struct { + Select HookFuncSelect + Insert HookFuncInsert + Update HookFuncUpdate + Delete HookFuncDelete +} + +// internalParamHook manages all internal parameters for hook operations. +// The `internal` obviously means you cannot access these parameters outside this package. +type internalParamHook struct { + db DB // Underlying DB object. + link Link // Connection object from third party sql driver. + handlerCalled bool // Simple mark for custom handler called, in case of recursive calling. + removedWhere bool // Removed mark for condition string that was removed `WHERE` prefix. +} + +type internalParamHookSelect struct { + internalParamHook + handler HookFuncSelect +} + +type internalParamHookInsert struct { + internalParamHook + handler HookFuncInsert +} + +type internalParamHookUpdate struct { + internalParamHook + handler HookFuncUpdate +} + +type internalParamHookDelete struct { + internalParamHook + handler HookFuncDelete +} + +// HookSelectInput holds the parameters for select hook operation. +type HookSelectInput struct { + internalParamHookSelect + Table string + Sql string + Args []interface{} +} + +// HookInsertInput holds the parameters for insert hook operation. +type HookInsertInput struct { + internalParamHookInsert + Table string + Data List + Option DoInsertOption +} + +// HookUpdateInput holds the parameters for update hook operation. +type HookUpdateInput struct { + internalParamHookUpdate + Table string + Data interface{} + Condition string + Args []interface{} +} + +// HookDeleteInput holds the parameters for delete hook operation. +type HookDeleteInput struct { + internalParamHookDelete + Table string + Condition string + Args []interface{} +} + +// Next calls the next hook handler. +func (h *HookSelectInput) Next(ctx context.Context) (result Result, err error) { + if h.handler != nil && !h.handlerCalled { + h.handlerCalled = true + return h.handler(ctx, h) + } + return h.db.DoSelect(ctx, h.link, h.Sql, h.Args...) +} + +// Next calls the next hook handler. +func (h *HookInsertInput) Next(ctx context.Context) (result sql.Result, err error) { + if h.handler != nil && !h.handlerCalled { + h.handlerCalled = true + return h.handler(ctx, h) + } + return h.db.DoInsert(ctx, h.link, h.Table, h.Data, h.Option) +} + +// Next calls the next hook handler. +func (h *HookUpdateInput) Next(ctx context.Context) (result sql.Result, err error) { + if h.handler != nil && !h.handlerCalled { + h.handlerCalled = true + if gstr.HasPrefix(h.Condition, " WHERE ") { + h.removedWhere = true + h.Condition = gstr.TrimLeftStr(h.Condition, " WHERE ") + } + return h.handler(ctx, h) + } + if h.removedWhere { + h.Condition = " WHERE " + h.Condition + } + return h.db.DoUpdate(ctx, h.link, h.Table, h.Data, h.Condition, h.Args...) +} + +// Next calls the next hook handler. +func (h *HookDeleteInput) Next(ctx context.Context) (result sql.Result, err error) { + if h.handler != nil && !h.handlerCalled { + h.handlerCalled = true + if gstr.HasPrefix(h.Condition, " WHERE ") { + h.removedWhere = true + h.Condition = gstr.TrimLeftStr(h.Condition, " WHERE ") + } + return h.handler(ctx, h) + } + if h.removedWhere { + h.Condition = " WHERE " + h.Condition + } + return h.db.DoDelete(ctx, h.link, h.Table, h.Condition, h.Args...) +} + +// Hook sets the hook functions for current model. +func (m *Model) Hook(hook HookHandler) *Model { + model := m.getModel() + model.hook = hook + return model +} diff --git a/database/gdb/gdb_model_insert.go b/database/gdb/gdb_model_insert.go index d516ef5fb..4013ca7f7 100644 --- a/database/gdb/gdb_model_insert.go +++ b/database/gdb/gdb_model_insert.go @@ -310,7 +310,20 @@ func (m *Model) doInsertWithOption(insertOption int) (result sql.Result, err err if err != nil { return result, err } - return m.db.DoInsert(m.GetCtx(), m.getLink(true), m.tables, list, doInsertOption) + + in := &HookInsertInput{ + internalParamHookInsert: internalParamHookInsert{ + internalParamHook: internalParamHook{ + db: m.db, + link: m.getLink(true), + }, + handler: m.hook.Insert, + }, + Table: m.tables, + Data: list, + Option: doInsertOption, + } + return in.Next(m.GetCtx()) } func (m *Model) formatDoInsertOption(insertOption int, columnNames []string) (option DoInsertOption, err error) { diff --git a/database/gdb/gdb_model_select.go b/database/gdb/gdb_model_select.go index 9ca632272..e172c4a11 100644 --- a/database/gdb/gdb_model_select.go +++ b/database/gdb/gdb_model_select.go @@ -532,9 +532,21 @@ func (m *Model) doGetAllBySql(sql string, args ...interface{}) (result Result, e } } } - result, err = m.db.DoSelect( - m.GetCtx(), m.getLink(false), sql, m.mergeArguments(args)..., - ) + + in := &HookSelectInput{ + internalParamHookSelect: internalParamHookSelect{ + internalParamHook: internalParamHook{ + db: m.db, + link: m.getLink(false), + }, + handler: m.hook.Select, + }, + Table: m.tables, + Sql: sql, + Args: m.mergeArguments(args), + } + result, err = in.Next(m.GetCtx()) + // Cache the result. if cacheKey != "" && err == nil { if m.cacheOption.Duration < 0 { diff --git a/database/gdb/gdb_model_update.go b/database/gdb/gdb_model_update.go index 064fccceb..f1be33829 100644 --- a/database/gdb/gdb_model_update.go +++ b/database/gdb/gdb_model_update.go @@ -74,14 +74,21 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro if !gstr.ContainsI(conditionStr, " WHERE ") { return nil, gerror.NewCode(gcode.CodeMissingParameter, "there should be WHERE condition statement for UPDATE operation") } - return m.db.DoUpdate( - m.GetCtx(), - m.getLink(true), - m.tables, - newData, - conditionStr, - m.mergeArguments(conditionArgs)..., - ) + + in := &HookUpdateInput{ + internalParamHookUpdate: internalParamHookUpdate{ + internalParamHook: internalParamHook{ + db: m.db, + link: m.getLink(true), + }, + handler: m.hook.Update, + }, + Table: m.tables, + Data: newData, + Condition: conditionStr, + Args: m.mergeArguments(conditionArgs), + } + return in.Next(m.GetCtx()) } // Increment increments a column's value by a given amount. diff --git a/database/gdb/gdb_model_with.go b/database/gdb/gdb_model_with.go index e5846e883..3e953a034 100644 --- a/database/gdb/gdb_model_with.go +++ b/database/gdb/gdb_model_with.go @@ -143,7 +143,7 @@ func (m *Model) doWithScanStruct(pointer interface{}) error { } // Recursively with feature checks. - model = m.db.With(field.Value) + model = m.db.With(field.Value).Hook(m.hook) if m.withAll { model = model.WithAll() } else { @@ -258,7 +258,7 @@ func (m *Model) doWithScanStructs(pointer interface{}) error { fieldKeys = structType.FieldKeys() } // Recursively with feature checks. - model = m.db.With(field.Value) + model = m.db.With(field.Value).Hook(m.hook) if m.withAll { model = model.WithAll() } else { diff --git a/database/gdb/gdb_z_mysql_feature_hook_test.go b/database/gdb/gdb_z_mysql_feature_hook_test.go new file mode 100644 index 000000000..0183dfa1f --- /dev/null +++ b/database/gdb/gdb_z_mysql_feature_hook_test.go @@ -0,0 +1,136 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gdb_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + + "github.com/gogf/gf/v2/container/gvar" + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/test/gtest" +) + +func Test_Model_Hook_Select(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + m := db.Model(table).Hook(gdb.HookHandler{ + Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) { + result, err = in.Next(ctx) + if err != nil { + return + } + for i, record := range result { + record["test"] = gvar.New(100 + record["id"].Int()) + result[i] = record + } + return + }, + }) + all, err := m.Where(`id > 6`).OrderAsc(`id`).All() + t.AssertNil(err) + t.Assert(len(all), 4) + t.Assert(all[0]["id"].Int(), 7) + t.Assert(all[0]["test"].Int(), 107) + t.Assert(all[1]["test"].Int(), 108) + t.Assert(all[2]["test"].Int(), 109) + t.Assert(all[3]["test"].Int(), 110) + }) +} + +func Test_Model_Hook_Insert(t *testing.T) { + table := createTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + m := db.Model(table).Hook(gdb.HookHandler{ + Insert: func(ctx context.Context, in *gdb.HookInsertInput) (result sql.Result, err error) { + for i, item := range in.Data { + item["passport"] = fmt.Sprintf(`test_port_%d`, item["id"]) + item["nickname"] = fmt.Sprintf(`test_name_%d`, item["id"]) + in.Data[i] = item + } + return in.Next(ctx) + }, + }) + _, err := m.Insert(g.Map{ + "id": 1, + "nickname": "name_1", + }) + t.AssertNil(err) + one, err := m.One() + t.AssertNil(err) + t.Assert(one["id"].Int(), 1) + t.Assert(one["passport"], `test_port_1`) + t.Assert(one["nickname"], `test_name_1`) + }) +} + +func Test_Model_Hook_Update(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + m := db.Model(table).Hook(gdb.HookHandler{ + Update: func(ctx context.Context, in *gdb.HookUpdateInput) (result sql.Result, err error) { + switch value := in.Data.(type) { + case gdb.List: + for i, data := range value { + data["passport"] = `port` + data["nickname"] = `name` + value[i] = data + } + in.Data = value + + case gdb.Map: + value["passport"] = `port` + value["nickname"] = `name` + in.Data = value + } + return in.Next(ctx) + }, + }) + _, err := m.Data(g.Map{ + "nickname": "name_1", + }).WherePri(1).Update() + t.AssertNil(err) + + one, err := m.One() + t.AssertNil(err) + t.Assert(one["id"].Int(), 1) + t.Assert(one["passport"], `port`) + t.Assert(one["nickname"], `name`) + }) +} + +func Test_Model_Hook_Delete(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + m := db.Model(table).Hook(gdb.HookHandler{ + Delete: func(ctx context.Context, in *gdb.HookDeleteInput) (result sql.Result, err error) { + return db.Model(table).Data(g.Map{ + "nickname": `deleted`, + }).Where(in.Condition).Update() + }, + }) + _, err := m.Where(1).Delete() + t.AssertNil(err) + + all, err := m.All() + t.AssertNil(err) + for _, item := range all { + t.Assert(item["nickname"].String(), `deleted`) + } + }) +} diff --git a/os/glog/glog_logger.go b/os/glog/glog_logger.go index 8dd985759..97acbe557 100644 --- a/os/glog/glog_logger.go +++ b/os/glog/glog_logger.go @@ -117,7 +117,6 @@ func (l *Logger) print(ctx context.Context, level int, values ...interface{}) { input = &HandlerInput{ Logger: l, Buffer: bytes.NewBuffer(nil), - Ctx: ctx, Time: now, Color: defaultLevelColor[level], Level: level, @@ -221,13 +220,13 @@ func (l *Logger) print(ctx context.Context, level int, values ...interface{}) { if l.config.Flags&F_ASYNC > 0 { input.IsAsync = true err := asyncPool.Add(ctx, func(ctx context.Context) { - input.Next() + input.Next(ctx) }) if err != nil { intlog.Errorf(ctx, `%+v`, err) } } else { - input.Next() + input.Next(ctx) } } diff --git a/os/glog/glog_logger_handler.go b/os/glog/glog_logger_handler.go index bdeb4cdfc..5c7d9932c 100644 --- a/os/glog/glog_logger_handler.go +++ b/os/glog/glog_logger_handler.go @@ -17,30 +17,29 @@ type Handler func(ctx context.Context, in *HandlerInput) // HandlerInput is the input parameter struct for logging Handler. type HandlerInput struct { - Logger *Logger // Logger. - Ctx context.Context // Context. - Buffer *bytes.Buffer // Buffer for logging content outputs. - Time time.Time // Logging time, which is the time that logging triggers. - TimeFormat string // Formatted time string, like "2016-01-09 12:00:00". - Color int // Using color, like COLOR_RED, COLOR_BLUE, etc. - Level int // Using level, like LEVEL_INFO, LEVEL_ERRO, etc. - LevelFormat string // Formatted level string, like "DEBU", "ERRO", etc. - CallerFunc string // The source function name that calls logging. - CallerPath string // The source file path and its line number that calls logging. - CtxStr string // The retrieved context value string from context. - Prefix string // Custom prefix string for logging content. - Content string // Content is the main logging content that passed by you. - IsAsync bool // IsAsync marks it is in asynchronous logging. - handlerIndex int // Middleware handling index for internal usage. + Logger *Logger // Logger. + Buffer *bytes.Buffer // Buffer for logging content outputs. + Time time.Time // Logging time, which is the time that logging triggers. + TimeFormat string // Formatted time string, like "2016-01-09 12:00:00". + Color int // Using color, like COLOR_RED, COLOR_BLUE, etc. + Level int // Using level, like LEVEL_INFO, LEVEL_ERRO, etc. + LevelFormat string // Formatted level string, like "DEBU", "ERRO", etc. + CallerFunc string // The source function name that calls logging. + CallerPath string // The source file path and its line number that calls logging. + CtxStr string // The retrieved context value string from context. + Prefix string // Custom prefix string for logging content. + Content string // Content is the main logging content that passed by you. + IsAsync bool // IsAsync marks it is in asynchronous logging. + handlerIndex int // Middleware handling index for internal usage. } // Next calls the next logging handler in middleware way. -func (i *HandlerInput) Next() { +func (i *HandlerInput) Next(ctx context.Context) { if len(i.Logger.config.Handlers)-1 > i.handlerIndex { i.handlerIndex++ - i.Logger.config.Handlers[i.handlerIndex](i.Ctx, i) + i.Logger.config.Handlers[i.handlerIndex](ctx, i) } else { - defaultHandler(i.Ctx, i) + defaultHandler(ctx, i) } }