improve middleware and error logging for ghttp.Server

This commit is contained in:
John
2019-09-14 22:53:28 +08:00
parent 41a0b52939
commit 966c93af00
14 changed files with 251 additions and 65 deletions

View File

@ -0,0 +1,34 @@
package main
import (
"net/http"
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp"
)
func MiddlewareAuth(r *ghttp.Request) {
token := r.Get("token")
if token == "123456" {
r.Middleware.Next()
} else {
r.Response.WriteStatus(http.StatusForbidden)
}
}
func MiddlewareCORS(r *ghttp.Request) {
r.Response.CORSDefault()
r.Middleware.Next()
}
func main() {
s := g.Server()
s.Group("/api.v2", func(g *ghttp.RouterGroup) {
g.Middleware(MiddlewareAuth, MiddlewareCORS)
g.ALL("/user/list", func(r *ghttp.Request) {
r.Response.Write("list")
})
})
s.SetPort(8199)
s.Run()
}

View File

@ -0,0 +1,23 @@
package main
import (
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp"
)
func MiddlewareCORS(r *ghttp.Request) {
r.Response.CORSDefault()
r.Middleware.Next()
}
func main() {
s := g.Server()
s.Group("/api.v2", func(g *ghttp.RouterGroup) {
g.Middleware(MiddlewareCORS)
g.ALL("/user/list", func(r *ghttp.Request) {
r.Response.Write("list")
})
})
s.SetPort(8199)
s.Run()
}

View File

@ -0,0 +1,44 @@
package main
import (
"net/http"
"github.com/gogf/gf/os/glog"
"github.com/gogf/gf/frame/g"
"github.com/gogf/gf/net/ghttp"
)
func MiddlewareAuth(r *ghttp.Request) {
token := r.Get("token")
if token == "123456" {
r.Middleware.Next()
} else {
r.Response.WriteStatus(http.StatusForbidden)
}
}
func MiddlewareCORS(r *ghttp.Request) {
r.Response.CORSDefault()
r.Middleware.Next()
}
func MiddlewareLog(r *ghttp.Request) {
r.Middleware.Next()
glog.Println(r.Response.Status, r.URL.Path)
}
func main() {
s := g.Server()
s.Group("/", func(g *ghttp.RouterGroup) {
g.Middleware(MiddlewareLog)
})
s.Group("/api.v2", func(g *ghttp.RouterGroup) {
g.Middleware(MiddlewareAuth, MiddlewareCORS)
g.ALL("/user/list", func(r *ghttp.Request) {
panic("custom error")
})
})
s.SetPort(8199)
s.Run()
}

View File

@ -40,6 +40,7 @@
## 功能改进
1. `ghttp`
- 当`WebServer`产生`panic`异常错误时,默认打印调用链堆栈到错误日志中;
- `Cookie`及`Session`的`TTL`配置数据类型修改为`time.Duration`;
- 新增允许同时通过`Header/Cookie`传递`SessionId`
- 新增`ConfigFromMap/SetConfigWithMap`方法,支持通过`map`参数设置WebServer

View File

@ -40,6 +40,7 @@ type Request struct {
parsedPost bool // POST参数是否已经解析
queryVars map[string][]string // GET参数
routerVars map[string][]string // 路由解析参数
error error // 当前请求执行错误
exit bool // 是否退出当前请求流程执行
params map[string]interface{} // 开发者自定义参数(请求流程中有效)
parsedHost string // 解析过后不带端口号的服务器域名名称
@ -92,12 +93,8 @@ func (r *Request) GetVar(key string, def ...interface{}) *gvar.Var {
// 获取原始请求输入二进制。
func (r *Request) GetRaw() []byte {
err := error(nil)
if r.rawContent == nil {
r.rawContent, err = ioutil.ReadAll(r.Body)
if err != nil {
r.Error("error reading request body: ", err)
}
r.rawContent, _ = ioutil.ReadAll(r.Body)
}
return r.rawContent
}

View File

@ -1,14 +0,0 @@
// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package ghttp
import "fmt"
// 打印error日志
func (r *Request) Error(value ...interface{}) {
r.Server.handleErrorLog(fmt.Sprint(value...), r)
}

View File

@ -6,7 +6,14 @@
package ghttp
import "reflect"
import (
"net/http"
"reflect"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/util/gutil"
)
// 中间件对象
type Middleware struct {
@ -17,7 +24,8 @@ type Middleware struct {
// 执行下一个请求流程处理函数
func (m *Middleware) Next() {
item := (*handlerParsedItem)(nil)
for {
loop := true
for loop {
// 是否停止请求执行
if m.request.IsExited() || m.request.handlerIndex >= len(m.request.handlers) {
return
@ -34,49 +42,69 @@ func (m *Middleware) Next() {
}
m.request.Router = item.handler.router
// 执行函数处理
switch item.handler.itemType {
case gHANDLER_TYPE_CONTROLLER:
m.served = true
c := reflect.New(item.handler.ctrlInfo.reflect)
niceCallFunc(func() {
c.MethodByName("Init").Call([]reflect.Value{reflect.ValueOf(m.request)})
})
if !m.request.IsExited() {
gutil.TryCatch(func() {
switch item.handler.itemType {
case gHANDLER_TYPE_CONTROLLER:
m.served = true
if m.request.IsExited() {
break
}
c := reflect.New(item.handler.ctrlInfo.reflect)
niceCallFunc(func() {
c.MethodByName(item.handler.ctrlInfo.name).Call(nil)
c.MethodByName("Init").Call([]reflect.Value{reflect.ValueOf(m.request)})
})
}
if !m.request.IsExited() {
niceCallFunc(func() {
c.MethodByName("Shut").Call(nil)
})
}
case gHANDLER_TYPE_OBJECT:
m.served = true
if item.handler.initFunc != nil {
niceCallFunc(func() {
item.handler.initFunc(m.request)
})
}
if !m.request.IsExited() {
if !m.request.IsExited() {
niceCallFunc(func() {
c.MethodByName(item.handler.ctrlInfo.name).Call(nil)
})
}
if !m.request.IsExited() {
niceCallFunc(func() {
c.MethodByName("Shut").Call(nil)
})
}
case gHANDLER_TYPE_OBJECT:
m.served = true
if m.request.IsExited() {
break
}
if item.handler.initFunc != nil {
niceCallFunc(func() {
item.handler.initFunc(m.request)
})
}
if !m.request.IsExited() {
niceCallFunc(func() {
item.handler.itemFunc(m.request)
})
}
if !m.request.IsExited() && item.handler.shutFunc != nil {
niceCallFunc(func() {
item.handler.shutFunc(m.request)
})
}
case gHANDLER_TYPE_HANDLER:
m.served = true
if m.request.IsExited() {
break
}
niceCallFunc(func() {
item.handler.itemFunc(m.request)
})
}
if !m.request.IsExited() && item.handler.shutFunc != nil {
case gHANDLER_TYPE_MIDDLEWARE:
niceCallFunc(func() {
item.handler.shutFunc(m.request)
item.handler.itemFunc(m.request)
})
// 中间件默认不会进一步执行,
// 需要内部调用Next方法决定是否进一步执行以便于请求流程控制。
loop = false
}
case gHANDLER_TYPE_HANDLER:
m.served = true
niceCallFunc(func() {
item.handler.itemFunc(m.request)
})
case gHANDLER_TYPE_MIDDLEWARE:
niceCallFunc(func() {
item.handler.itemFunc(m.request)
})
}
}, func(exception interface{}) {
m.request.error = gerror.Newf("%v", exception)
m.request.Response.WriteStatus(http.StatusInternalServerError, exception)
})
}
}

View File

@ -47,6 +47,9 @@ func (r *Response) Write(content ...interface{}) {
if len(content) == 0 {
return
}
if r.Status == 0 {
r.Status = http.StatusOK
}
for _, v := range content {
switch value := v.(type) {
case []byte:

View File

@ -34,7 +34,7 @@ const (
)
// 自定义日志处理方法类型
type LogHandler func(r *Request, error ...interface{})
type LogHandler func(r *Request, err ...error)
// HTTP Server 设置结构体,静态配置
type ServerConfig struct {
@ -70,6 +70,7 @@ type ServerConfig struct {
LogPath string // Logging: 存放日志的目录路径(默认为空,表示不写文件)
LogHandler LogHandler // Logging: 日志配置: 自定义日志处理回调方法(默认为空)
LogStdout bool // Logging: 是否打印日志到终端(默认开启)
ErrorStack bool // Logging: 当产生错误时打印调用链详细堆栈
ErrorLogEnabled bool // Logging: 是否开启error log(默认开启)
AccessLogEnabled bool // Logging: 是否开启access log(默认关闭)
NameToUriType int // Mess: 服务注册时对象和方法名称转换为URI时的规则
@ -100,6 +101,7 @@ var defaultServerConfig = ServerConfig{
SessionMaxAge: time.Hour * 24,
SessionIdName: "gfsessionid",
LogStdout: true,
ErrorStack: true,
ErrorLogEnabled: true,
AccessLogEnabled: false,
DumpRouteMap: true,

View File

@ -54,6 +54,15 @@ func (s *Server) SetErrorLogEnabled(enabled bool) {
s.config.ErrorLogEnabled = enabled
}
// 设置是否开启error stack打印功能
func (s *Server) SetErrorStack(enabled bool) {
if s.Status() == SERVER_STATUS_RUNNING {
glog.Error(gCHANGE_CONFIG_WHILE_RUNNING_ERROR)
return
}
s.config.ErrorStack = enabled
}
// 设置日志写入的回调函数
func (s *Server) SetLogHandler(handler LogHandler) {
if s.Status() == SERVER_STATUS_RUNNING {

View File

@ -12,6 +12,8 @@ import (
"sort"
"strings"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/os/gres"
"github.com/gogf/gf/encoding/ghtml"
@ -70,11 +72,17 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
request.Response.WriteStatus(http.StatusNotFound)
}
}
// error log
if e := recover(); e != nil {
request.Response.WriteStatus(http.StatusInternalServerError)
s.handleErrorLog(e, request)
if request.error != nil {
s.handleErrorLog(request.error, request)
} else {
if exception := recover(); exception != nil {
request.Response.WriteStatus(http.StatusInternalServerError)
s.handleErrorLog(gerror.Newf("%v", exception), request)
}
}
// access log
s.handleAccessLog(request)
}()

View File

@ -9,6 +9,8 @@ package ghttp
import (
"fmt"
"github.com/gogf/gf/errors/gerror"
"github.com/gogf/gf/os/gtime"
)
@ -40,7 +42,7 @@ func (s *Server) handleAccessLog(r *Request) {
}
// 处理服务错误信息主要是panichttp请求的status由access log进行管理
func (s *Server) handleErrorLog(error interface{}, r *Request) {
func (s *Server) handleErrorLog(err error, r *Request) {
// 错误输出默认是开启的
if !s.IsErrorLogEnabled() {
return
@ -48,7 +50,7 @@ func (s *Server) handleErrorLog(error interface{}, r *Request) {
// 自定义错误处理
if v := s.GetLogHandler(); v != nil {
v(r, error)
v(r, err)
return
}
@ -57,12 +59,17 @@ func (s *Server) handleErrorLog(error interface{}, r *Request) {
if r.TLS != nil {
scheme = "https"
}
content := fmt.Sprintf(`%v, "%s %s %s %s %s"`, error, r.Method, scheme, r.Host, r.URL.String(), r.Proto)
content := fmt.Sprintf(`%v, "%s %s %s %s %s"`, err, r.Method, scheme, r.Host, r.URL.String(), r.Proto)
if r.LeaveTime > r.EnterTime {
content += fmt.Sprintf(` %.3f`, float64(r.LeaveTime-r.EnterTime)/1000)
} else {
content += fmt.Sprintf(` %.3f`, float64(gtime.Microsecond()-r.EnterTime)/1000)
}
content += fmt.Sprintf(`, %s, "%s", "%s"`, r.GetClientIp(), r.Referer(), r.UserAgent())
s.logger.Cat("error").StackWithFilter(gPATH_FILTER_KEY).Stdout(s.config.LogStdout).Error(content)
if s.config.ErrorStack {
if stack := gerror.Stack(err); stack != "" {
content += "\n" + stack
}
}
s.logger.Cat("error").Stack(false).Stdout(s.config.LogStdout).Error(content)
}

View File

@ -56,6 +56,10 @@ func (s *Server) handlePreBindItems() {
// 获取分组路由对象
func (s *Server) Group(prefix string, groups ...func(g *RouterGroup)) *RouterGroup {
// 自动识别并加上/前缀
if prefix[0] != '/' {
prefix = "/" + prefix
}
if prefix == "/" {
prefix = ""
}

View File

@ -8,6 +8,7 @@ package ghttp_test
import (
"fmt"
"net/http"
"testing"
"time"
@ -409,3 +410,42 @@ func Test_Hook_Middleware_Basic1(t *testing.T) {
gtest.Assert(client.GetContent("/test/test"), "ac13test42bd")
})
}
func MiddlewareAuth(r *ghttp.Request) {
token := r.Get("token")
if token == "123456" {
r.Middleware.Next()
} else {
r.Response.WriteStatus(http.StatusForbidden)
}
}
func MiddlewareCORS(r *ghttp.Request) {
r.Response.CORSDefault()
r.Middleware.Next()
}
func Test_Middleware_CORSAndAuth(t *testing.T) {
p := ports.PopRand()
s := g.Server(p)
s.Group("/api.v2", func(g *ghttp.RouterGroup) {
g.Middleware(MiddlewareAuth, MiddlewareCORS)
g.ALL("/user/list", func(r *ghttp.Request) {
r.Response.Write("list")
})
})
s.SetPort(p)
s.SetDumpRouteMap(false)
s.Start()
defer s.Shutdown()
time.Sleep(200 * time.Millisecond)
gtest.Case(t, func() {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
gtest.Assert(client.GetContent("/"), "Not Found")
gtest.Assert(client.GetContent("/api.v2"), "Not Found")
gtest.Assert(client.GetContent("/api.v2/user/list"), "Forbidden")
gtest.Assert(client.GetContent("/api.v2/user/list", "token=123456"), "list")
})
}