diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 1bb335d9f..e4ed2312e 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -53,6 +53,7 @@ func (c *Core) Ctx(ctx context.Context) DB { panic(err) } newCore.ctx = WithDB(ctx, newCore.db) + newCore.ctx = c.injectInternalCtxData(newCore.ctx) return newCore.db } @@ -62,7 +63,7 @@ func (c *Core) GetCtx() context.Context { if c.ctx != nil { return c.ctx } - return context.TODO() + return c.injectInternalCtxData(context.TODO()) } // GetCtxTimeout returns the context and cancel function for specified timeout type. @@ -264,6 +265,7 @@ func (c *Core) UnionAll(unions ...*Model) *Model { func (c *Core) doUnion(unionType int, unions ...*Model) *Model { var ( + ctx = c.db.GetCtx() unionTypeStr string composedSqlStr string composedArgs = make([]interface{}, 0) @@ -274,7 +276,7 @@ func (c *Core) doUnion(unionType int, unions ...*Model) *Model { unionTypeStr = "UNION" } for _, v := range unions { - sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(queryTypeNormal, false) + sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, queryTypeNormal, false) if composedSqlStr == "" { composedSqlStr += fmt.Sprintf(`(%s)`, sqlWithHolder) } else { diff --git a/database/gdb/gdb_core_ctx.go b/database/gdb/gdb_core_ctx.go new file mode 100644 index 000000000..d41db8133 --- /dev/null +++ b/database/gdb/gdb_core_ctx.go @@ -0,0 +1,46 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gdb + +import ( + "context" + + "github.com/gogf/gf/v2/os/gctx" +) + +// internalCtxData stores data in ctx for internal usage purpose. +type internalCtxData struct { + // Operation DB. + DB DB + + // The first column in result response from database server. + // This attribute is used for Value/Count selection statement purpose, + // which is to avoid HOOK handler that might modify the result columns + // that can confuse the Value/Count selection statement logic. + FirstResultColumn string +} + +const ( + internalCtxDataKeyInCtx gctx.StrKey = "InternalCtxData" +) + +func (c *Core) injectInternalCtxData(ctx context.Context) context.Context { + // If the internal data is already injected, it does nothing. + if ctx.Value(internalCtxDataKeyInCtx) != nil { + return ctx + } + return context.WithValue(ctx, internalCtxDataKeyInCtx, &internalCtxData{ + DB: c.db, + }) +} + +func (c *Core) getInternalCtxDataFromCtx(ctx context.Context) *internalCtxData { + if v := ctx.Value(internalCtxDataKeyInCtx); v != nil { + return v.(*internalCtxData) + } + return nil +} diff --git a/database/gdb/gdb_core_sharding.go b/database/gdb/gdb_core_sharding.go index f27c26b23..4869f6f20 100644 --- a/database/gdb/gdb_core_sharding.go +++ b/database/gdb/gdb_core_sharding.go @@ -66,9 +66,10 @@ type callShardingHandlerFromCtxInput struct { } type callShardingHandlerFromCtxOutput struct { - Sql string - Table string - Schema string + Sql string + Table string + Schema string + ParsedSqlOutput *parseFormattedSqlOutput } func (c *Core) callShardingHandlerFromCtx( @@ -80,6 +81,7 @@ func (c *Core) callShardingHandlerFromCtx( shardingHandler ShardingHandler ok bool ) + // If no sharding handler, it does nothing. if ctxValue = ctx.Value(ctxKeyForShardingHandler); ctxValue == nil { return nil, nil } @@ -108,9 +110,10 @@ func (c *Core) callShardingHandlerFromCtx( } } out = &callShardingHandlerFromCtxOutput{ - Sql: newSql, - Table: shardingOut.Table, - Schema: shardingOut.Schema, + Sql: newSql, + Table: shardingOut.Table, + Schema: shardingOut.Schema, + ParsedSqlOutput: parsedOut, } return out, nil } @@ -143,10 +146,11 @@ func (c *Core) formatSqlWithNewTable(sql, table string) (newSql string, err erro } type parseFormattedSqlOutput struct { - Table string - OperationData map[string]Value - ConditionData map[string]Value - ParsedStmt sqlparser.Statement + Table string + OperationData map[string]Value + ConditionData map[string]Value + ParsedStmt sqlparser.Statement + SelectedFields []string } func (c *Core) parseFormattedSql(formattedSql string) (*parseFormattedSqlOutput, error) { @@ -154,8 +158,9 @@ func (c *Core) parseFormattedSql(formattedSql string) (*parseFormattedSqlOutput, condition sqlparser.Expr err error out = &parseFormattedSqlOutput{ - OperationData: make(map[string]Value), - ConditionData: make(map[string]Value), + SelectedFields: make([]string, 0), + OperationData: make(map[string]Value), + ConditionData: make(map[string]Value), } ) out.ParsedStmt, err = sqlparser.NewParser(strings.NewReader(formattedSql)).ParseStatement() @@ -164,18 +169,30 @@ func (c *Core) parseFormattedSql(formattedSql string) (*parseFormattedSqlOutput, } switch stmt := out.ParsedStmt.(type) { case *sqlparser.SelectStatement: - table, ok := stmt.FromItems.(*sqlparser.TableName) - if !ok { - return nil, gerror.Newf( - `invalid table name "%s" in SQL: %s`, - stmt.FromItems.String(), formattedSql, - ) + if stmt.FromItems != nil { + table, ok := stmt.FromItems.(*sqlparser.TableName) + if !ok { + return nil, gerror.Newf( + `invalid table name "%s" in SQL: %s`, + stmt.FromItems.String(), formattedSql, + ) + } + out.Table = table.TableName() } - out.Table = table.TableName() condition = stmt.Condition + if stmt.Columns != nil { + for _, column := range *stmt.Columns { + if column.Alias != nil { + out.SelectedFields = append(out.SelectedFields, column.Alias.Name) + } else if column.Expr != nil { + out.SelectedFields = append(out.SelectedFields, column.Expr.String()) + } + } + } + case *sqlparser.InsertStatement: out.Table = stmt.TableName.TableName() - if len(stmt.Expressions) > 0 { + if len(stmt.Expressions) > 0 && len(stmt.ColumnNames) > 0 { names := make([]string, len(stmt.ColumnNames)) for i, ident := range stmt.ColumnNames { names[i] = ident.Name diff --git a/database/gdb/gdb_core_underlying.go b/database/gdb/gdb_core_underlying.go index d9e564c16..c9b051823 100644 --- a/database/gdb/gdb_core_underlying.go +++ b/database/gdb/gdb_core_underlying.go @@ -127,8 +127,44 @@ func (c *Core) DoFilter(ctx context.Context, link Link, sql string, args []inter return sql, args, nil } +type sqlParsingHandlerInput struct { + DoCommitInput + FormattedSql string +} + +type sqlParsingHandlerOutput struct { + DoCommitInput +} + +func (c *Core) sqlParsingHandler(ctx context.Context, in sqlParsingHandlerInput) (out *sqlParsingHandlerOutput, err error) { + var shardingOut *callShardingHandlerFromCtxOutput + // Sharding handling. + shardingOut, err = c.callShardingHandlerFromCtx(ctx, callShardingHandlerFromCtxInput{ + Sql: in.Sql, + FormattedSql: in.FormattedSql, + }) + if err != nil { + return + } + if shardingOut != nil { + if shardingOut.Sql != "" { + in.Sql = shardingOut.Sql + } + // If schema changes, it here creates and uses a new DB link operation object. + if shardingOut.Schema != c.db.GetSchema() { + in.Link, err = c.db.GetCore().GetLink(ctx, in.Link.IsOnMaster(), shardingOut.Schema) + } + } + out = &sqlParsingHandlerOutput{ + DoCommitInput: in.DoCommitInput, + } + return +} + // DoCommit commits current sql and arguments to underlying sql driver. func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutput, err error) { + // Inject internal data into ctx, just for double check. + ctx = c.injectInternalCtxData(ctx) var ( sqlTx *sql.Tx @@ -138,30 +174,21 @@ func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutp stmtSqlRows *sql.Rows stmtSqlRow *sql.Row rowsAffected int64 - shardingOut *callShardingHandlerFromCtxOutput cancelFuncForTimeout context.CancelFunc formattedSql = FormatSqlWithArgs(in.Sql, in.Args) timestampMilli1 = gtime.TimestampMilli() ) - shardingOut, err = c.callShardingHandlerFromCtx(ctx, callShardingHandlerFromCtxInput{ - Sql: in.Sql, - FormattedSql: formattedSql, + + // SQL parser handler. + sqlParsingHandlerOut, err := c.sqlParsingHandler(ctx, sqlParsingHandlerInput{ + DoCommitInput: in, + FormattedSql: formattedSql, }) if err != nil { return } - // Sharding handling. - if shardingOut != nil { - if shardingOut.Sql != "" { - in.Sql = shardingOut.Sql - } - // If schema changes, it here creates and uses a new DB link operation object. - if shardingOut.Schema != c.db.GetSchema() { - in.Link, err = c.db.GetCore().GetLink(ctx, in.Link.IsOnMaster(), shardingOut.Schema) - if err != nil { - return - } - } + if sqlParsingHandlerOut != nil { + in = sqlParsingHandlerOut.DoCommitInput } // Trace span start. @@ -372,6 +399,11 @@ func (c *Core) RowsToResult(ctx context.Context, rows *sql.Rows) (Result, error) columnTypes[k] = v.DatabaseTypeName() columnNames[k] = v.Name() } + if len(columnNames) > 0 { + if internalData := c.getInternalCtxDataFromCtx(ctx); internalData != nil { + internalData.FirstResultColumn = columnNames[0] + } + } var ( values = make([]interface{}, len(columnNames)) result = make(Result, 0) diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index 04cdd17f6..8878d0682 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -523,7 +523,7 @@ func formatWhereHolder(db DB, in formatWhereHolderInput) (newWhere string, newAr whereStr, _ = gregex.ReplaceStringFunc(`(\?)`, whereStr, func(s string) string { index++ if i+len(newArgs) == index { - sqlWithHolder, holderArgs := model.getFormattedSqlAndArgs(queryTypeNormal, false) + sqlWithHolder, holderArgs := model.getFormattedSqlAndArgs(model.GetCtx(), queryTypeNormal, false) newArgs = append(newArgs, holderArgs...) // Automatically adding the brackets. return "(" + sqlWithHolder + ")" diff --git a/database/gdb/gdb_model_hook.go b/database/gdb/gdb_model_hook.go index 8f54db9f7..fd823495d 100644 --- a/database/gdb/gdb_model_hook.go +++ b/database/gdb/gdb_model_hook.go @@ -39,8 +39,7 @@ type internalParamHook struct { type internalParamHookSelect struct { internalParamHook - handler HookFuncSelect - queryType int + handler HookFuncSelect } type internalParamHookInsert struct { @@ -111,11 +110,6 @@ func (h *HookSelectInput) Next(ctx context.Context) (result Result, err error) { return h.model.db.DoSelect(ctx, h.link, h.Sql, h.Args...) } -// IsCountStatement checks and returns whether current SELECT statement is COUNT statement. -func (h *HookSelectInput) IsCountStatement() bool { - return h.queryType == queryTypeCount -} - // Next calls the next hook handler. func (h *HookInsertInput) Next(ctx context.Context) (result sql.Result, err error) { if h.handler != nil && !h.handlerCalled { diff --git a/database/gdb/gdb_model_select.go b/database/gdb/gdb_model_select.go index 0585dfe47..057b2f146 100644 --- a/database/gdb/gdb_model_select.go +++ b/database/gdb/gdb_model_select.go @@ -7,6 +7,7 @@ package gdb import ( + "context" "fmt" "reflect" @@ -29,7 +30,7 @@ import ( // The optional parameter `where` is the same as the parameter of Model.Where function, // see Model.Where. func (m *Model) All(where ...interface{}) (Result, error) { - return m.doGetAll(false, where...) + return m.doGetAll(m.GetCtx(), false, where...) } // doGetAll does "SELECT FROM ..." statement for the model. @@ -39,12 +40,12 @@ func (m *Model) All(where ...interface{}) (Result, error) { // The parameter `limit1` specifies whether limits querying only one record if m.limit is not set. // The optional parameter `where` is the same as the parameter of Model.Where function, // see Model.Where. -func (m *Model) doGetAll(limit1 bool, where ...interface{}) (Result, error) { +func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{}) (Result, error) { if len(where) > 0 { return m.Where(where[0], where[1:]...).All() } - sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(queryTypeNormal, limit1) - return m.doGetAllBySql(queryTypeNormal, sqlWithHolder, holderArgs...) + sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(m.GetCtx(), queryTypeNormal, limit1) + return m.doGetAllBySql(ctx, queryTypeNormal, sqlWithHolder, holderArgs...) } // getFieldsFiltered checks the fields and fieldsEx attributes, filters and returns the fields that will @@ -133,7 +134,7 @@ func (m *Model) One(where ...interface{}) (Record, error) { if len(where) > 0 { return m.Where(where[0], where[1:]...).One() } - all, err := m.doGetAll(true) + all, err := m.doGetAll(m.GetCtx(), true) if err != nil { return nil, err } @@ -159,14 +160,24 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) { return m.Fields(gconv.String(fieldsAndWhere[0])).Value() } } - one, err := m.One() - if err != nil { - return gvar.New(nil), err + var ( + all Result + err error + ctx = m.GetCtx() + ) + if all, err = m.doGetAll(ctx, true); err != nil { + return nil, err } - for _, v := range one { - return v, nil + if len(all) == 0 { + return gvar.New(nil), nil } - return gvar.New(nil), nil + if internalData := m.db.GetCore().getInternalCtxDataFromCtx(ctx); internalData != nil { + record := all[0] + if v, ok := record[internalData.FirstResultColumn]; ok { + return v, nil + } + } + return nil, gerror.NewCode(gcode.CodeInternalError, `query value error`) } // Array queries and returns data values as slice from database. @@ -366,16 +377,21 @@ func (m *Model) Count(where ...interface{}) (int, error) { return m.Where(where[0], where[1:]...).Count() } var ( - sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(queryTypeCount, false) - list, err = m.doGetAllBySql(queryTypeCount, sqlWithHolder, holderArgs...) + ctx = m.GetCtx() + sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeCount, false) + all, err = m.doGetAllBySql(ctx, queryTypeCount, sqlWithHolder, holderArgs...) ) if err != nil { return 0, err } - if len(list) > 0 { - for _, v := range list[0] { - return v.Int(), nil + if len(all) > 0 { + if internalData := m.db.GetCore().getInternalCtxDataFromCtx(ctx); internalData != nil { + record := all[0] + if v, ok := record[internalData.FirstResultColumn]; ok { + return v.Int(), nil + } } + return 0, gerror.NewCode(gcode.CodeInternalError, `query count error`) } return 0, nil } @@ -502,10 +518,9 @@ func (m *Model) Having(having interface{}, args ...interface{}) *Model { } // doGetAllBySql does the select statement on the database. -func (m *Model) doGetAllBySql(queryType int, sql string, args ...interface{}) (result Result, err error) { +func (m *Model) doGetAllBySql(ctx context.Context, queryType int, sql string, args ...interface{}) (result Result, err error) { var ( ok bool - ctx = m.GetCtx() cacheKey = "" cacheObj = m.db.GetCache() ) @@ -539,14 +554,13 @@ func (m *Model) doGetAllBySql(queryType int, sql string, args ...interface{}) (r link: m.getLink(false), model: m, }, - handler: m.hookHandler.Select, - queryType: queryType, + handler: m.hookHandler.Select, }, Table: m.tables, Sql: sql, Args: m.mergeArguments(args), } - result, err = in.Next(m.GetCtx()) + result, err = in.Next(ctx) // Cache the result. if cacheKey != "" && err == nil { @@ -567,22 +581,22 @@ func (m *Model) doGetAllBySql(queryType int, sql string, args ...interface{}) (r return result, err } -func (m *Model) getFormattedSqlAndArgs(queryType int, limit1 bool) (sqlWithHolder string, holderArgs []interface{}) { +func (m *Model) getFormattedSqlAndArgs(ctx context.Context, queryType int, limit1 bool) (sqlWithHolder string, holderArgs []interface{}) { switch queryType { case queryTypeCount: - countFields := "COUNT(1)" + queryFields := "COUNT(1)" if m.fields != "" && m.fields != "*" { // DO NOT quote the m.fields here, in case of fields like: // DISTINCT t.user_id uid - countFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.fields) + queryFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.fields) } // Raw SQL Model. if m.rawSql != "" { - sqlWithHolder = fmt.Sprintf("SELECT %s FROM (%s) AS T", countFields, m.rawSql) + sqlWithHolder = fmt.Sprintf("SELECT %s FROM (%s) AS T", queryFields, m.rawSql) return sqlWithHolder, nil } conditionWhere, conditionExtra, conditionArgs := m.formatCondition(false, true) - sqlWithHolder = fmt.Sprintf("SELECT %s FROM %s%s", countFields, m.tables, conditionWhere+conditionExtra) + sqlWithHolder = fmt.Sprintf("SELECT %s FROM %s%s", queryFields, m.tables, conditionWhere+conditionExtra) if len(m.groupBy) > 0 { sqlWithHolder = fmt.Sprintf("SELECT COUNT(1) FROM (%s) count_alias", sqlWithHolder) } @@ -603,10 +617,7 @@ func (m *Model) getFormattedSqlAndArgs(queryType int, limit1 bool) (sqlWithHolde // DISTINCT t.user_id uid sqlWithHolder = fmt.Sprintf( "SELECT %s%s FROM %s%s", - m.distinct, - m.getFieldsFiltered(), - m.tables, - conditionWhere+conditionExtra, + m.distinct, m.getFieldsFiltered(), m.tables, conditionWhere+conditionExtra, ) return sqlWithHolder, conditionArgs }