fix issue of error code lost in middleware handling for package ghttp

This commit is contained in:
John Guo
2020-12-30 12:56:24 +08:00
parent 820befa1a0
commit bdf23ef48f
4 changed files with 101 additions and 36 deletions

View File

@ -48,9 +48,12 @@ func (c *Core) Ctx(ctx context.Context) DB {
}
// GetCtx returns the context for current DB.
// Note that it might be nil.
// It returns `context.Background()` is there's no context previously set.
func (c *Core) GetCtx() context.Context {
return c.ctx
if c.ctx != nil {
return c.ctx
}
return context.Background()
}
// Master creates and returns a connection from master node if master-slave configured.
@ -78,15 +81,11 @@ func (c *Core) Query(sql string, args ...interface{}) (rows *sql.Rows, err error
// DoQuery commits the sql string and its arguments to underlying driver
// through given link object and returns the execution result.
func (c *Core) DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Rows, err error) {
ctx := c.DB.GetCtx()
if ctx == nil {
ctx = context.Background()
}
sql, args = formatSql(sql, args)
sql, args = c.DB.HandleSqlBeforeCommit(link, sql, args)
if c.DB.GetDebug() {
mTime1 := gtime.TimestampMilli()
rows, err = link.QueryContext(ctx, sql, args...)
rows, err = link.QueryContext(c.DB.GetCtx(), sql, args...)
mTime2 := gtime.TimestampMilli()
s := &Sql{
Sql: sql,
@ -99,7 +98,7 @@ func (c *Core) DoQuery(link Link, sql string, args ...interface{}) (rows *sql.Ro
}
c.writeSqlToLogger(s)
} else {
rows, err = link.QueryContext(ctx, sql, args...)
rows, err = link.QueryContext(c.DB.GetCtx(), sql, args...)
}
if err == nil {
return rows, nil
@ -122,16 +121,12 @@ func (c *Core) Exec(sql string, args ...interface{}) (result sql.Result, err err
// DoExec commits the sql string and its arguments to underlying driver
// through given link object and returns the execution result.
func (c *Core) DoExec(link Link, sql string, args ...interface{}) (result sql.Result, err error) {
ctx := c.DB.GetCtx()
if ctx == nil {
ctx = context.Background()
}
sql, args = formatSql(sql, args)
sql, args = c.DB.HandleSqlBeforeCommit(link, sql, args)
if c.DB.GetDebug() {
mTime1 := gtime.TimestampMilli()
if !c.DB.GetDryRun() {
result, err = link.ExecContext(ctx, sql, args...)
result, err = link.ExecContext(c.DB.GetCtx(), sql, args...)
} else {
result = new(SqlResult)
}
@ -148,7 +143,7 @@ func (c *Core) DoExec(link Link, sql string, args ...interface{}) (result sql.Re
c.writeSqlToLogger(s)
} else {
if !c.DB.GetDryRun() {
result, err = link.ExecContext(ctx, sql, args...)
result, err = link.ExecContext(c.DB.GetCtx(), sql, args...)
} else {
result = new(SqlResult)
}
@ -183,11 +178,7 @@ func (c *Core) Prepare(sql string, execOnMaster ...bool) (*sql.Stmt, error) {
// doPrepare calls prepare function on given link object and returns the statement object.
func (c *Core) DoPrepare(link Link, sql string) (*sql.Stmt, error) {
ctx := c.DB.GetCtx()
if ctx == nil {
ctx = context.Background()
}
return link.PrepareContext(ctx, sql)
return link.PrepareContext(c.DB.GetCtx(), sql)
}
// GetAll queries and returns data records from database.

View File

@ -2974,23 +2974,48 @@ func Test_Model_Issue1002(t *testing.T) {
t.Assert(v.Int(), 1)
})
// where + time.Time arguments, UTC.
t1, _ := time.Parse("2006-01-02 15:04:05", "2020-10-27 19:03:32")
t2, _ := time.Parse("2006-01-02 15:04:05", "2020-10-27 19:03:34")
gtest.C(t, func(t *gtest.T) {
v, err := db.Table(table).Fields("id").Where("create_time>? and create_time<?", t1, t2).Value()
t.Assert(err, nil)
t.Assert(v.Int(), 1)
})
gtest.C(t, func(t *gtest.T) {
v, err := db.Table(table).Fields("id").Where("create_time>? and create_time<?", t1, t2).FindValue()
t.Assert(err, nil)
t.Assert(v.Int(), 1)
})
gtest.C(t, func(t *gtest.T) {
v, err := db.Table(table).Where("create_time>? and create_time<?", t1, t2).FindValue("id")
t.Assert(err, nil)
t.Assert(v.Int(), 1)
t1, _ := time.Parse("2006-01-02 15:04:05", "2020-10-27 19:03:32")
t2, _ := time.Parse("2006-01-02 15:04:05", "2020-10-27 19:03:34")
{
v, err := db.Table(table).Fields("id").Where("create_time>? and create_time<?", t1, t2).Value()
t.Assert(err, nil)
t.Assert(v.Int(), 1)
}
{
v, err := db.Table(table).Fields("id").Where("create_time>? and create_time<?", t1, t2).FindValue()
t.Assert(err, nil)
t.Assert(v.Int(), 1)
}
{
v, err := db.Table(table).Where("create_time>? and create_time<?", t1, t2).FindValue("id")
t.Assert(err, nil)
t.Assert(v.Int(), 1)
}
})
// where + time.Time arguments, +8.
//gtest.C(t, func(t *gtest.T) {
// // Change current timezone to +8 zone.
// location, err := time.LoadLocation("Asia/Shanghai")
// t.Assert(err, nil)
// t1, _ := time.ParseInLocation("2006-01-02 15:04:05", "2020-10-27 19:03:32", location)
// t2, _ := time.ParseInLocation("2006-01-02 15:04:05", "2020-10-27 19:03:34", location)
// {
// v, err := db.Table(table).Fields("id").Where("create_time>? and create_time<?", t1, t2).Value()
// t.Assert(err, nil)
// t.Assert(v.Int(), 1)
// }
// {
// v, err := db.Table(table).Fields("id").Where("create_time>? and create_time<?", t1, t2).FindValue()
// t.Assert(err, nil)
// t.Assert(v.Int(), 1)
// }
// {
// v, err := db.Table(table).Where("create_time>? and create_time<?", t1, t2).FindValue("id")
// t.Assert(err, nil)
// t.Assert(v.Int(), 1)
// }
//})
}
func createTableForTimeZoneTest() string {

View File

@ -0,0 +1,45 @@
// Copyright GoFrame Author(https://github.com/gogf/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
// static service testing.
package ghttp_test
import (
"fmt"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/net/ghttp"
"testing"
"time"
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/test/gtest"
)
func Test_Error_Code(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
p, _ := ports.PopRand()
s := g.Server(p)
s.Group("/", func(group *ghttp.RouterGroup) {
group.Middleware(func(r *ghttp.Request) {
r.Middleware.Next()
r.Response.ClearBuffer()
r.Response.Write(gerror.Code(r.GetError()))
})
group.ALL("/", func(r *ghttp.Request) {
panic(gerror.NewCode(10000, "test error"))
})
})
s.SetPort(p)
s.Start()
defer s.Shutdown()
time.Sleep(100 * time.Millisecond)
c := g.Client()
c.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
t.Assert(c.GetContent("/"), "10000")
})
}

View File

@ -35,8 +35,12 @@ func Try(try func()) (err error) {
// It automatically calls function <catch> if any exception occurs ans passes the exception as an error.
func TryCatch(try func(), catch ...func(exception error)) {
defer func() {
if e := recover(); e != nil && len(catch) > 0 {
catch[0](fmt.Errorf(`%v`, e))
if exception := recover(); exception != nil && len(catch) > 0 {
if err, ok := exception.(error); ok {
catch[0](err)
} else {
catch[0](fmt.Errorf(`%v`, exception))
}
}
}()
try()