improve hook and sharding feature for package gdb

This commit is contained in:
John Guo
2022-03-24 15:33:30 +08:00
parent a5e20e4939
commit cc01629b57
7 changed files with 178 additions and 76 deletions

View File

@ -53,6 +53,7 @@ func (c *Core) Ctx(ctx context.Context) DB {
panic(err)
}
newCore.ctx = WithDB(ctx, newCore.db)
newCore.ctx = c.injectInternalCtxData(newCore.ctx)
return newCore.db
}
@ -62,7 +63,7 @@ func (c *Core) GetCtx() context.Context {
if c.ctx != nil {
return c.ctx
}
return context.TODO()
return c.injectInternalCtxData(context.TODO())
}
// GetCtxTimeout returns the context and cancel function for specified timeout type.
@ -264,6 +265,7 @@ func (c *Core) UnionAll(unions ...*Model) *Model {
func (c *Core) doUnion(unionType int, unions ...*Model) *Model {
var (
ctx = c.db.GetCtx()
unionTypeStr string
composedSqlStr string
composedArgs = make([]interface{}, 0)
@ -274,7 +276,7 @@ func (c *Core) doUnion(unionType int, unions ...*Model) *Model {
unionTypeStr = "UNION"
}
for _, v := range unions {
sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(queryTypeNormal, false)
sqlWithHolder, holderArgs := v.getFormattedSqlAndArgs(ctx, queryTypeNormal, false)
if composedSqlStr == "" {
composedSqlStr += fmt.Sprintf(`(%s)`, sqlWithHolder)
} else {

View File

@ -0,0 +1,46 @@
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package gdb
import (
"context"
"github.com/gogf/gf/v2/os/gctx"
)
// internalCtxData stores data in ctx for internal usage purpose.
type internalCtxData struct {
// Operation DB.
DB DB
// The first column in result response from database server.
// This attribute is used for Value/Count selection statement purpose,
// which is to avoid HOOK handler that might modify the result columns
// that can confuse the Value/Count selection statement logic.
FirstResultColumn string
}
const (
internalCtxDataKeyInCtx gctx.StrKey = "InternalCtxData"
)
func (c *Core) injectInternalCtxData(ctx context.Context) context.Context {
// If the internal data is already injected, it does nothing.
if ctx.Value(internalCtxDataKeyInCtx) != nil {
return ctx
}
return context.WithValue(ctx, internalCtxDataKeyInCtx, &internalCtxData{
DB: c.db,
})
}
func (c *Core) getInternalCtxDataFromCtx(ctx context.Context) *internalCtxData {
if v := ctx.Value(internalCtxDataKeyInCtx); v != nil {
return v.(*internalCtxData)
}
return nil
}

View File

@ -66,9 +66,10 @@ type callShardingHandlerFromCtxInput struct {
}
type callShardingHandlerFromCtxOutput struct {
Sql string
Table string
Schema string
Sql string
Table string
Schema string
ParsedSqlOutput *parseFormattedSqlOutput
}
func (c *Core) callShardingHandlerFromCtx(
@ -80,6 +81,7 @@ func (c *Core) callShardingHandlerFromCtx(
shardingHandler ShardingHandler
ok bool
)
// If no sharding handler, it does nothing.
if ctxValue = ctx.Value(ctxKeyForShardingHandler); ctxValue == nil {
return nil, nil
}
@ -108,9 +110,10 @@ func (c *Core) callShardingHandlerFromCtx(
}
}
out = &callShardingHandlerFromCtxOutput{
Sql: newSql,
Table: shardingOut.Table,
Schema: shardingOut.Schema,
Sql: newSql,
Table: shardingOut.Table,
Schema: shardingOut.Schema,
ParsedSqlOutput: parsedOut,
}
return out, nil
}
@ -143,10 +146,11 @@ func (c *Core) formatSqlWithNewTable(sql, table string) (newSql string, err erro
}
type parseFormattedSqlOutput struct {
Table string
OperationData map[string]Value
ConditionData map[string]Value
ParsedStmt sqlparser.Statement
Table string
OperationData map[string]Value
ConditionData map[string]Value
ParsedStmt sqlparser.Statement
SelectedFields []string
}
func (c *Core) parseFormattedSql(formattedSql string) (*parseFormattedSqlOutput, error) {
@ -154,8 +158,9 @@ func (c *Core) parseFormattedSql(formattedSql string) (*parseFormattedSqlOutput,
condition sqlparser.Expr
err error
out = &parseFormattedSqlOutput{
OperationData: make(map[string]Value),
ConditionData: make(map[string]Value),
SelectedFields: make([]string, 0),
OperationData: make(map[string]Value),
ConditionData: make(map[string]Value),
}
)
out.ParsedStmt, err = sqlparser.NewParser(strings.NewReader(formattedSql)).ParseStatement()
@ -164,18 +169,30 @@ func (c *Core) parseFormattedSql(formattedSql string) (*parseFormattedSqlOutput,
}
switch stmt := out.ParsedStmt.(type) {
case *sqlparser.SelectStatement:
table, ok := stmt.FromItems.(*sqlparser.TableName)
if !ok {
return nil, gerror.Newf(
`invalid table name "%s" in SQL: %s`,
stmt.FromItems.String(), formattedSql,
)
if stmt.FromItems != nil {
table, ok := stmt.FromItems.(*sqlparser.TableName)
if !ok {
return nil, gerror.Newf(
`invalid table name "%s" in SQL: %s`,
stmt.FromItems.String(), formattedSql,
)
}
out.Table = table.TableName()
}
out.Table = table.TableName()
condition = stmt.Condition
if stmt.Columns != nil {
for _, column := range *stmt.Columns {
if column.Alias != nil {
out.SelectedFields = append(out.SelectedFields, column.Alias.Name)
} else if column.Expr != nil {
out.SelectedFields = append(out.SelectedFields, column.Expr.String())
}
}
}
case *sqlparser.InsertStatement:
out.Table = stmt.TableName.TableName()
if len(stmt.Expressions) > 0 {
if len(stmt.Expressions) > 0 && len(stmt.ColumnNames) > 0 {
names := make([]string, len(stmt.ColumnNames))
for i, ident := range stmt.ColumnNames {
names[i] = ident.Name

View File

@ -127,8 +127,44 @@ func (c *Core) DoFilter(ctx context.Context, link Link, sql string, args []inter
return sql, args, nil
}
type sqlParsingHandlerInput struct {
DoCommitInput
FormattedSql string
}
type sqlParsingHandlerOutput struct {
DoCommitInput
}
func (c *Core) sqlParsingHandler(ctx context.Context, in sqlParsingHandlerInput) (out *sqlParsingHandlerOutput, err error) {
var shardingOut *callShardingHandlerFromCtxOutput
// Sharding handling.
shardingOut, err = c.callShardingHandlerFromCtx(ctx, callShardingHandlerFromCtxInput{
Sql: in.Sql,
FormattedSql: in.FormattedSql,
})
if err != nil {
return
}
if shardingOut != nil {
if shardingOut.Sql != "" {
in.Sql = shardingOut.Sql
}
// If schema changes, it here creates and uses a new DB link operation object.
if shardingOut.Schema != c.db.GetSchema() {
in.Link, err = c.db.GetCore().GetLink(ctx, in.Link.IsOnMaster(), shardingOut.Schema)
}
}
out = &sqlParsingHandlerOutput{
DoCommitInput: in.DoCommitInput,
}
return
}
// DoCommit commits current sql and arguments to underlying sql driver.
func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutput, err error) {
// Inject internal data into ctx, just for double check.
ctx = c.injectInternalCtxData(ctx)
var (
sqlTx *sql.Tx
@ -138,30 +174,21 @@ func (c *Core) DoCommit(ctx context.Context, in DoCommitInput) (out DoCommitOutp
stmtSqlRows *sql.Rows
stmtSqlRow *sql.Row
rowsAffected int64
shardingOut *callShardingHandlerFromCtxOutput
cancelFuncForTimeout context.CancelFunc
formattedSql = FormatSqlWithArgs(in.Sql, in.Args)
timestampMilli1 = gtime.TimestampMilli()
)
shardingOut, err = c.callShardingHandlerFromCtx(ctx, callShardingHandlerFromCtxInput{
Sql: in.Sql,
FormattedSql: formattedSql,
// SQL parser handler.
sqlParsingHandlerOut, err := c.sqlParsingHandler(ctx, sqlParsingHandlerInput{
DoCommitInput: in,
FormattedSql: formattedSql,
})
if err != nil {
return
}
// Sharding handling.
if shardingOut != nil {
if shardingOut.Sql != "" {
in.Sql = shardingOut.Sql
}
// If schema changes, it here creates and uses a new DB link operation object.
if shardingOut.Schema != c.db.GetSchema() {
in.Link, err = c.db.GetCore().GetLink(ctx, in.Link.IsOnMaster(), shardingOut.Schema)
if err != nil {
return
}
}
if sqlParsingHandlerOut != nil {
in = sqlParsingHandlerOut.DoCommitInput
}
// Trace span start.
@ -372,6 +399,11 @@ func (c *Core) RowsToResult(ctx context.Context, rows *sql.Rows) (Result, error)
columnTypes[k] = v.DatabaseTypeName()
columnNames[k] = v.Name()
}
if len(columnNames) > 0 {
if internalData := c.getInternalCtxDataFromCtx(ctx); internalData != nil {
internalData.FirstResultColumn = columnNames[0]
}
}
var (
values = make([]interface{}, len(columnNames))
result = make(Result, 0)

View File

@ -523,7 +523,7 @@ func formatWhereHolder(db DB, in formatWhereHolderInput) (newWhere string, newAr
whereStr, _ = gregex.ReplaceStringFunc(`(\?)`, whereStr, func(s string) string {
index++
if i+len(newArgs) == index {
sqlWithHolder, holderArgs := model.getFormattedSqlAndArgs(queryTypeNormal, false)
sqlWithHolder, holderArgs := model.getFormattedSqlAndArgs(model.GetCtx(), queryTypeNormal, false)
newArgs = append(newArgs, holderArgs...)
// Automatically adding the brackets.
return "(" + sqlWithHolder + ")"

View File

@ -39,8 +39,7 @@ type internalParamHook struct {
type internalParamHookSelect struct {
internalParamHook
handler HookFuncSelect
queryType int
handler HookFuncSelect
}
type internalParamHookInsert struct {
@ -111,11 +110,6 @@ func (h *HookSelectInput) Next(ctx context.Context) (result Result, err error) {
return h.model.db.DoSelect(ctx, h.link, h.Sql, h.Args...)
}
// IsCountStatement checks and returns whether current SELECT statement is COUNT statement.
func (h *HookSelectInput) IsCountStatement() bool {
return h.queryType == queryTypeCount
}
// Next calls the next hook handler.
func (h *HookInsertInput) Next(ctx context.Context) (result sql.Result, err error) {
if h.handler != nil && !h.handlerCalled {

View File

@ -7,6 +7,7 @@
package gdb
import (
"context"
"fmt"
"reflect"
@ -29,7 +30,7 @@ import (
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) All(where ...interface{}) (Result, error) {
return m.doGetAll(false, where...)
return m.doGetAll(m.GetCtx(), false, where...)
}
// doGetAll does "SELECT FROM ..." statement for the model.
@ -39,12 +40,12 @@ func (m *Model) All(where ...interface{}) (Result, error) {
// The parameter `limit1` specifies whether limits querying only one record if m.limit is not set.
// The optional parameter `where` is the same as the parameter of Model.Where function,
// see Model.Where.
func (m *Model) doGetAll(limit1 bool, where ...interface{}) (Result, error) {
func (m *Model) doGetAll(ctx context.Context, limit1 bool, where ...interface{}) (Result, error) {
if len(where) > 0 {
return m.Where(where[0], where[1:]...).All()
}
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(queryTypeNormal, limit1)
return m.doGetAllBySql(queryTypeNormal, sqlWithHolder, holderArgs...)
sqlWithHolder, holderArgs := m.getFormattedSqlAndArgs(m.GetCtx(), queryTypeNormal, limit1)
return m.doGetAllBySql(ctx, queryTypeNormal, sqlWithHolder, holderArgs...)
}
// getFieldsFiltered checks the fields and fieldsEx attributes, filters and returns the fields that will
@ -133,7 +134,7 @@ func (m *Model) One(where ...interface{}) (Record, error) {
if len(where) > 0 {
return m.Where(where[0], where[1:]...).One()
}
all, err := m.doGetAll(true)
all, err := m.doGetAll(m.GetCtx(), true)
if err != nil {
return nil, err
}
@ -159,14 +160,24 @@ func (m *Model) Value(fieldsAndWhere ...interface{}) (Value, error) {
return m.Fields(gconv.String(fieldsAndWhere[0])).Value()
}
}
one, err := m.One()
if err != nil {
return gvar.New(nil), err
var (
all Result
err error
ctx = m.GetCtx()
)
if all, err = m.doGetAll(ctx, true); err != nil {
return nil, err
}
for _, v := range one {
return v, nil
if len(all) == 0 {
return gvar.New(nil), nil
}
return gvar.New(nil), nil
if internalData := m.db.GetCore().getInternalCtxDataFromCtx(ctx); internalData != nil {
record := all[0]
if v, ok := record[internalData.FirstResultColumn]; ok {
return v, nil
}
}
return nil, gerror.NewCode(gcode.CodeInternalError, `query value error`)
}
// Array queries and returns data values as slice from database.
@ -366,16 +377,21 @@ func (m *Model) Count(where ...interface{}) (int, error) {
return m.Where(where[0], where[1:]...).Count()
}
var (
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(queryTypeCount, false)
list, err = m.doGetAllBySql(queryTypeCount, sqlWithHolder, holderArgs...)
ctx = m.GetCtx()
sqlWithHolder, holderArgs = m.getFormattedSqlAndArgs(ctx, queryTypeCount, false)
all, err = m.doGetAllBySql(ctx, queryTypeCount, sqlWithHolder, holderArgs...)
)
if err != nil {
return 0, err
}
if len(list) > 0 {
for _, v := range list[0] {
return v.Int(), nil
if len(all) > 0 {
if internalData := m.db.GetCore().getInternalCtxDataFromCtx(ctx); internalData != nil {
record := all[0]
if v, ok := record[internalData.FirstResultColumn]; ok {
return v.Int(), nil
}
}
return 0, gerror.NewCode(gcode.CodeInternalError, `query count error`)
}
return 0, nil
}
@ -502,10 +518,9 @@ func (m *Model) Having(having interface{}, args ...interface{}) *Model {
}
// doGetAllBySql does the select statement on the database.
func (m *Model) doGetAllBySql(queryType int, sql string, args ...interface{}) (result Result, err error) {
func (m *Model) doGetAllBySql(ctx context.Context, queryType int, sql string, args ...interface{}) (result Result, err error) {
var (
ok bool
ctx = m.GetCtx()
cacheKey = ""
cacheObj = m.db.GetCache()
)
@ -539,14 +554,13 @@ func (m *Model) doGetAllBySql(queryType int, sql string, args ...interface{}) (r
link: m.getLink(false),
model: m,
},
handler: m.hookHandler.Select,
queryType: queryType,
handler: m.hookHandler.Select,
},
Table: m.tables,
Sql: sql,
Args: m.mergeArguments(args),
}
result, err = in.Next(m.GetCtx())
result, err = in.Next(ctx)
// Cache the result.
if cacheKey != "" && err == nil {
@ -567,22 +581,22 @@ func (m *Model) doGetAllBySql(queryType int, sql string, args ...interface{}) (r
return result, err
}
func (m *Model) getFormattedSqlAndArgs(queryType int, limit1 bool) (sqlWithHolder string, holderArgs []interface{}) {
func (m *Model) getFormattedSqlAndArgs(ctx context.Context, queryType int, limit1 bool) (sqlWithHolder string, holderArgs []interface{}) {
switch queryType {
case queryTypeCount:
countFields := "COUNT(1)"
queryFields := "COUNT(1)"
if m.fields != "" && m.fields != "*" {
// DO NOT quote the m.fields here, in case of fields like:
// DISTINCT t.user_id uid
countFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.fields)
queryFields = fmt.Sprintf(`COUNT(%s%s)`, m.distinct, m.fields)
}
// Raw SQL Model.
if m.rawSql != "" {
sqlWithHolder = fmt.Sprintf("SELECT %s FROM (%s) AS T", countFields, m.rawSql)
sqlWithHolder = fmt.Sprintf("SELECT %s FROM (%s) AS T", queryFields, m.rawSql)
return sqlWithHolder, nil
}
conditionWhere, conditionExtra, conditionArgs := m.formatCondition(false, true)
sqlWithHolder = fmt.Sprintf("SELECT %s FROM %s%s", countFields, m.tables, conditionWhere+conditionExtra)
sqlWithHolder = fmt.Sprintf("SELECT %s FROM %s%s", queryFields, m.tables, conditionWhere+conditionExtra)
if len(m.groupBy) > 0 {
sqlWithHolder = fmt.Sprintf("SELECT COUNT(1) FROM (%s) count_alias", sqlWithHolder)
}
@ -603,10 +617,7 @@ func (m *Model) getFormattedSqlAndArgs(queryType int, limit1 bool) (sqlWithHolde
// DISTINCT t.user_id uid
sqlWithHolder = fmt.Sprintf(
"SELECT %s%s FROM %s%s",
m.distinct,
m.getFieldsFiltered(),
m.tables,
conditionWhere+conditionExtra,
m.distinct, m.getFieldsFiltered(), m.tables, conditionWhere+conditionExtra,
)
return sqlWithHolder, conditionArgs
}