From bbcf49db98f304d4562b3e0c09773f999db2cb9c Mon Sep 17 00:00:00 2001 From: John Guo Date: Tue, 16 Apr 2024 19:31:06 +0800 Subject: [PATCH] fix: #3238 first column might be overwritten in interal context data in multiple goroutines querying (#3476) --- .../drivers/mysql/mysql_z_unit_issue_test.go | 35 +++++++++++++ database/gdb/gdb.go | 6 +-- database/gdb/gdb_core.go | 4 +- database/gdb/gdb_core_config.go | 8 +-- database/gdb/gdb_core_ctx.go | 49 +++++++++++++++---- database/gdb/gdb_core_transaction.go | 4 +- database/gdb/gdb_core_underlying.go | 5 +- database/gdb/gdb_model_cache.go | 17 ++++--- database/gdb/gdb_model_select.go | 22 ++++++--- 9 files changed, 114 insertions(+), 36 deletions(-) diff --git a/contrib/drivers/mysql/mysql_z_unit_issue_test.go b/contrib/drivers/mysql/mysql_z_unit_issue_test.go index 775060630..072fe1b93 100644 --- a/contrib/drivers/mysql/mysql_z_unit_issue_test.go +++ b/contrib/drivers/mysql/mysql_z_unit_issue_test.go @@ -9,6 +9,7 @@ package mysql_test import ( "context" "fmt" + "sync" "testing" "time" @@ -1125,3 +1126,37 @@ func Test_Issue2643(t *testing.T) { t.Assert(gstr.Contains(sqlContent, expectKey2), true) }) } + +// https://github.com/gogf/gf/issues/3238 +func Test_Issue3238(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + gtest.C(t, func(t *gtest.T) { + for i := 0; i < 100; i++ { + _, err := db.Ctx(ctx).Model(table).Hook(gdb.HookHandler{ + Select: func(ctx context.Context, in *gdb.HookSelectInput) (result gdb.Result, err error) { + result, err = in.Next(ctx) + if err != nil { + return + } + var wg sync.WaitGroup + for _, record := range result { + wg.Add(1) + go func(record gdb.Record) { + defer wg.Done() + id, _ := db.Ctx(ctx).Model(table).WherePri(1).Value(`id`) + nickname, _ := db.Ctx(ctx).Model(table).WherePri(1).Value(`nickname`) + t.Assert(id.Int(), 1) + t.Assert(nickname.String(), "name_1") + }(record) + } + wg.Wait() + return + }, + }, + ).All() + t.AssertNil(err) + } + }) +} diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index 9145ddeea..20fd8b507 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -744,10 +744,10 @@ func (c *Core) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error node.Name = nodeSchema } // Update the configuration object in internal data. - internalData := c.GetInternalCtxDataFromCtx(ctx) - if internalData != nil { - internalData.ConfigNode = node + if err = c.setConfigNodeToCtx(ctx, node); err != nil { + return } + // Cache the underlying connection pool object by node. var ( instanceCacheFunc = func() interface{} { diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 73cfb541d..80ffef5cd 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -56,7 +56,7 @@ func (c *Core) Ctx(ctx context.Context) DB { panic(err) } newCore.ctx = WithDB(ctx, newCore.db) - newCore.ctx = c.InjectInternalCtxData(newCore.ctx) + newCore.ctx = c.injectInternalCtxData(newCore.ctx) return newCore.db } @@ -67,7 +67,7 @@ func (c *Core) GetCtx() context.Context { if ctx == nil { ctx = context.TODO() } - return c.InjectInternalCtxData(ctx) + return c.injectInternalCtxData(ctx) } // GetCtxTimeout returns the context and cancel function for specified timeout type. diff --git a/database/gdb/gdb_core_config.go b/database/gdb/gdb_core_config.go index 4c3f08740..d649f8998 100644 --- a/database/gdb/gdb_core_config.go +++ b/database/gdb/gdb_core_config.go @@ -208,15 +208,15 @@ func (c *Core) SetMaxConnLifeTime(d time.Duration) { // GetConfig returns the current used node configuration. func (c *Core) GetConfig() *ConfigNode { - internalData := c.GetInternalCtxDataFromCtx(c.db.GetCtx()) - if internalData != nil && internalData.ConfigNode != nil { + var configNode = c.getConfigNodeFromCtx(c.db.GetCtx()) + if configNode != nil { // Note: // It so here checks and returns the config from current DB, // if different schemas between current DB and config.Name from context, // for example, in nested transaction scenario, the context is passed all through the logic procedure, // but the config.Name from context may be still the original one from the first transaction object. - if c.config.Name == internalData.ConfigNode.Name { - return internalData.ConfigNode + if c.config.Name == configNode.Name { + return configNode } } return c.config diff --git a/database/gdb/gdb_core_ctx.go b/database/gdb/gdb_core_ctx.go index 3c4c1ab4f..ea8ca4b66 100644 --- a/database/gdb/gdb_core_ctx.go +++ b/database/gdb/gdb_core_ctx.go @@ -8,18 +8,22 @@ package gdb import ( "context" + "sync" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/os/gctx" ) // internalCtxData stores data in ctx for internal usage purpose. type internalCtxData struct { - // Operation DB. - DB DB - + sync.Mutex // Used configuration node in current operation. ConfigNode *ConfigNode +} +// column stores column data in ctx for internal usage purpose. +type internalColumnData struct { // The first column in result response from database server. // This attribute is used for Value/Count selection statement purpose, // which is to avoid HOOK handler that might modify the result columns @@ -28,7 +32,8 @@ type internalCtxData struct { } const ( - internalCtxDataKeyInCtx gctx.StrKey = "InternalCtxData" + internalCtxDataKeyInCtx gctx.StrKey = "InternalCtxData" + internalColumnDataKeyInCtx gctx.StrKey = "InternalColumnData" // `ignoreResultKeyInCtx` is a mark for some db drivers that do not support `RowsAffected` function, // for example: `clickhouse`. The `clickhouse` does not support fetching insert/update results, @@ -37,20 +42,46 @@ const ( ignoreResultKeyInCtx gctx.StrKey = "IgnoreResult" ) -func (c *Core) InjectInternalCtxData(ctx context.Context) context.Context { +func (c *Core) injectInternalCtxData(ctx context.Context) context.Context { // If the internal data is already injected, it does nothing. if ctx.Value(internalCtxDataKeyInCtx) != nil { return ctx } return context.WithValue(ctx, internalCtxDataKeyInCtx, &internalCtxData{ - DB: c.db, ConfigNode: c.config, }) } -func (c *Core) GetInternalCtxDataFromCtx(ctx context.Context) *internalCtxData { - if v := ctx.Value(internalCtxDataKeyInCtx); v != nil { - return v.(*internalCtxData) +func (c *Core) setConfigNodeToCtx(ctx context.Context, node *ConfigNode) error { + value := ctx.Value(internalCtxDataKeyInCtx) + if value == nil { + return gerror.NewCode(gcode.CodeInternalError, `no internal data found in context`) + } + + data := value.(*internalCtxData) + data.Lock() + defer data.Unlock() + data.ConfigNode = node + return nil +} + +func (c *Core) getConfigNodeFromCtx(ctx context.Context) *ConfigNode { + if value := ctx.Value(internalCtxDataKeyInCtx); value != nil { + data := value.(*internalCtxData) + data.Lock() + defer data.Unlock() + return data.ConfigNode + } + return nil +} + +func (c *Core) injectInternalColumn(ctx context.Context) context.Context { + return context.WithValue(ctx, internalColumnDataKeyInCtx, &internalColumnData{}) +} + +func (c *Core) getInternalColumnFromCtx(ctx context.Context) *internalColumnData { + if v := ctx.Value(internalColumnDataKeyInCtx); v != nil { + return v.(*internalColumnData) } return nil } diff --git a/database/gdb/gdb_core_transaction.go b/database/gdb/gdb_core_transaction.go index 33b3aa443..e098982c2 100644 --- a/database/gdb/gdb_core_transaction.go +++ b/database/gdb/gdb_core_transaction.go @@ -72,7 +72,7 @@ func (c *Core) Transaction(ctx context.Context, f func(ctx context.Context, tx T if ctx == nil { ctx = c.db.GetCtx() } - ctx = c.InjectInternalCtxData(ctx) + ctx = c.injectInternalCtxData(ctx) // Check transaction object from context. var tx TX tx = TXFromCtx(ctx, c.db.GetGroup()) @@ -160,7 +160,7 @@ func (tx *TXCore) transactionKeyForNestedPoint() string { func (tx *TXCore) Ctx(ctx context.Context) TX { tx.ctx = ctx if tx.ctx != nil { - tx.ctx = tx.db.GetCore().InjectInternalCtxData(tx.ctx) + tx.ctx = tx.db.GetCore().injectInternalCtxData(tx.ctx) } return tx } diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index f574a1948..dec32513a 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -156,9 +156,6 @@ func (c *Core) DoFilter(ctx context.Context, link Link, sql string, args []inter // DoCommit commits current sql and arguments to underlying sql driver. func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutput, err error) { - // Inject internal data into ctx, especially for transaction creating. - ctx = c.InjectInternalCtxData(ctx) - var ( sqlTx *sql.Tx sqlStmt *sql.Stmt @@ -420,7 +417,7 @@ func (c *Core) RowsToResult(ctx context.Context, rows *sql.Rows) (Result, error) } if len(columnTypes) > 0 { - if internalData := c.GetInternalCtxDataFromCtx(ctx); internalData != nil { + if internalData := c.getInternalColumnFromCtx(ctx); internalData != nil { internalData.FirstResultColumn = columnTypes[0].Name() } } diff --git a/database/gdb/gdb_model_cache.go b/database/gdb/gdb_model_cache.go index 793a9696d..2bdab1c8e 100644 --- a/database/gdb/gdb_model_cache.go +++ b/database/gdb/gdb_model_cache.go @@ -69,10 +69,11 @@ func (m *Model) getSelectResultFromCache(ctx context.Context, sql string, args . cacheItem *selectCacheItem cacheKey = m.makeSelectCacheKey(sql, args...) cacheObj = m.db.GetCache() + core = m.db.GetCore() ) defer func() { if cacheItem != nil { - if internalData := m.db.GetCore().GetInternalCtxDataFromCtx(ctx); internalData != nil { + if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil { if cacheItem.FirstResultColumn != "" { internalData.FirstResultColumn = cacheItem.FirstResultColumn } @@ -106,9 +107,10 @@ func (m *Model) saveSelectResultToCache( } // Special handler for Value/Count operations result. if len(result) > 0 { + var core = m.db.GetCore() switch queryType { case queryTypeValue, queryTypeCount: - if internalData := m.db.GetCore().GetInternalCtxDataFromCtx(ctx); internalData != nil { + if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil { if result[0][internalData.FirstResultColumn].IsEmpty() { result = nil } @@ -124,10 +126,13 @@ func (m *Model) saveSelectResultToCache( result = nil } } - var cacheItem = &selectCacheItem{ - Result: result, - } - if internalData := m.db.GetCore().GetInternalCtxDataFromCtx(ctx); internalData != nil { + var ( + core = m.db.GetCore() + cacheItem = &selectCacheItem{ + Result: result, + } + ) + if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil { cacheItem.FirstResultColumn = internalData.FirstResultColumn } if errCache := cacheObj.Set(ctx, cacheKey, cacheItem, m.cacheOption.Duration); errCache != nil { diff --git a/database/gdb/gdb_model_select.go b/database/gdb/gdb_model_select.go index 69bfd38c3..975ca6dc4 100644 --- a/database/gdb/gdb_model_select.go +++ b/database/gdb/gdb_model_select.go @@ -139,9 +139,13 @@ func (m *Model) Array(fieldsAndWhere ...interface{}) ([]Value, error) { if err != nil { return nil, err } - var field string + var ( + field string + core = m.db.GetCore() + ctx = core.injectInternalColumn(m.GetCtx()) + ) if len(all) > 0 { - if internalData := m.db.GetCore().GetInternalCtxDataFromCtx(m.GetCtx()); internalData != nil { + if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil { field = internalData.FirstResultColumn } else { return nil, gerror.NewCode( @@ -376,7 +380,10 @@ func (m *Model) ScanList(structSlicePointer interface{}, bindToAttrName string, // and fieldsAndWhere[1:] is treated as where condition fields. // Also see Model.Fields and Model.Where functions. func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) { - var ctx = m.GetCtx() + var ( + core = m.db.GetCore() + ctx = core.injectInternalColumn(m.GetCtx()) + ) if len(fieldsAndWhere) > 0 { if len(fieldsAndWhere) > 2 { return m.Fields(gconv.String(fieldsAndWhere[0])).Where(fieldsAndWhere[1], fieldsAndWhere[2:]...).Value() @@ -394,7 +401,7 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) { return nil, err } if len(all) > 0 { - if internalData := m.db.GetCore().GetInternalCtxDataFromCtx(ctx); internalData != nil { + if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil { if v, ok := all[0][internalData.FirstResultColumn]; ok { return v, nil } @@ -412,7 +419,10 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) { // The optional parameter `where` is the same as the parameter of Model.Where function, // see Model.Where. func (m *Model) Count(where ...interface{}) (int, error) { - var ctx = m.GetCtx() + var ( + core = m.db.GetCore() + ctx = core.injectInternalColumn(m.GetCtx()) + ) if len(where) > 0 { return m.Where(where[0], where[1:]...).Count() } @@ -424,7 +434,7 @@ func (m *Model) Count(where ...interface{}) (int, error) { return 0, err } if len(all) > 0 { - if internalData := m.db.GetCore().GetInternalCtxDataFromCtx(ctx); internalData != nil { + if internalData := core.getInternalColumnFromCtx(ctx); internalData != nil { if v, ok := all[0][internalData.FirstResultColumn]; ok { return v.Int(), nil }