mirror of
https://gitee.com/johng/gf
synced 2026-06-06 02:25:47 +08:00
improve middleware and error logging for ghttp.Server
This commit is contained in:
34
.example/net/ghttp/server/middleware/auth.go
Normal file
34
.example/net/ghttp/server/middleware/auth.go
Normal file
@ -0,0 +1,34 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/net/ghttp"
|
||||
)
|
||||
|
||||
func MiddlewareAuth(r *ghttp.Request) {
|
||||
token := r.Get("token")
|
||||
if token == "123456" {
|
||||
r.Middleware.Next()
|
||||
} else {
|
||||
r.Response.WriteStatus(http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func MiddlewareCORS(r *ghttp.Request) {
|
||||
r.Response.CORSDefault()
|
||||
r.Middleware.Next()
|
||||
}
|
||||
|
||||
func main() {
|
||||
s := g.Server()
|
||||
s.Group("/api.v2", func(g *ghttp.RouterGroup) {
|
||||
g.Middleware(MiddlewareAuth, MiddlewareCORS)
|
||||
g.ALL("/user/list", func(r *ghttp.Request) {
|
||||
r.Response.Write("list")
|
||||
})
|
||||
})
|
||||
s.SetPort(8199)
|
||||
s.Run()
|
||||
}
|
||||
23
.example/net/ghttp/server/middleware/cors.go
Normal file
23
.example/net/ghttp/server/middleware/cors.go
Normal file
@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/net/ghttp"
|
||||
)
|
||||
|
||||
func MiddlewareCORS(r *ghttp.Request) {
|
||||
r.Response.CORSDefault()
|
||||
r.Middleware.Next()
|
||||
}
|
||||
|
||||
func main() {
|
||||
s := g.Server()
|
||||
s.Group("/api.v2", func(g *ghttp.RouterGroup) {
|
||||
g.Middleware(MiddlewareCORS)
|
||||
g.ALL("/user/list", func(r *ghttp.Request) {
|
||||
r.Response.Write("list")
|
||||
})
|
||||
})
|
||||
s.SetPort(8199)
|
||||
s.Run()
|
||||
}
|
||||
44
.example/net/ghttp/server/middleware/log.go
Normal file
44
.example/net/ghttp/server/middleware/log.go
Normal file
@ -0,0 +1,44 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gogf/gf/os/glog"
|
||||
|
||||
"github.com/gogf/gf/frame/g"
|
||||
"github.com/gogf/gf/net/ghttp"
|
||||
)
|
||||
|
||||
func MiddlewareAuth(r *ghttp.Request) {
|
||||
token := r.Get("token")
|
||||
if token == "123456" {
|
||||
r.Middleware.Next()
|
||||
} else {
|
||||
r.Response.WriteStatus(http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func MiddlewareCORS(r *ghttp.Request) {
|
||||
r.Response.CORSDefault()
|
||||
r.Middleware.Next()
|
||||
}
|
||||
|
||||
func MiddlewareLog(r *ghttp.Request) {
|
||||
r.Middleware.Next()
|
||||
glog.Println(r.Response.Status, r.URL.Path)
|
||||
}
|
||||
|
||||
func main() {
|
||||
s := g.Server()
|
||||
s.Group("/", func(g *ghttp.RouterGroup) {
|
||||
g.Middleware(MiddlewareLog)
|
||||
})
|
||||
s.Group("/api.v2", func(g *ghttp.RouterGroup) {
|
||||
g.Middleware(MiddlewareAuth, MiddlewareCORS)
|
||||
g.ALL("/user/list", func(r *ghttp.Request) {
|
||||
panic("custom error")
|
||||
})
|
||||
})
|
||||
s.SetPort(8199)
|
||||
s.Run()
|
||||
}
|
||||
@ -40,6 +40,7 @@
|
||||
## 功能改进
|
||||
|
||||
1. `ghttp`
|
||||
- 当`WebServer`产生`panic`异常错误时,默认打印调用链堆栈到错误日志中;
|
||||
- `Cookie`及`Session`的`TTL`配置数据类型修改为`time.Duration`;
|
||||
- 新增允许同时通过`Header/Cookie`传递`SessionId`;
|
||||
- 新增`ConfigFromMap/SetConfigWithMap`方法,支持通过`map`参数设置WebServer;
|
||||
|
||||
@ -40,6 +40,7 @@ type Request struct {
|
||||
parsedPost bool // POST参数是否已经解析
|
||||
queryVars map[string][]string // GET参数
|
||||
routerVars map[string][]string // 路由解析参数
|
||||
error error // 当前请求执行错误
|
||||
exit bool // 是否退出当前请求流程执行
|
||||
params map[string]interface{} // 开发者自定义参数(请求流程中有效)
|
||||
parsedHost string // 解析过后不带端口号的服务器域名名称
|
||||
@ -92,12 +93,8 @@ func (r *Request) GetVar(key string, def ...interface{}) *gvar.Var {
|
||||
|
||||
// 获取原始请求输入二进制。
|
||||
func (r *Request) GetRaw() []byte {
|
||||
err := error(nil)
|
||||
if r.rawContent == nil {
|
||||
r.rawContent, err = ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
r.Error("error reading request body: ", err)
|
||||
}
|
||||
r.rawContent, _ = ioutil.ReadAll(r.Body)
|
||||
}
|
||||
return r.rawContent
|
||||
}
|
||||
|
||||
@ -1,14 +0,0 @@
|
||||
// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the MIT License.
|
||||
// If a copy of the MIT was not distributed with this file,
|
||||
// You can obtain one at https://github.com/gogf/gf.
|
||||
|
||||
package ghttp
|
||||
|
||||
import "fmt"
|
||||
|
||||
// 打印error日志
|
||||
func (r *Request) Error(value ...interface{}) {
|
||||
r.Server.handleErrorLog(fmt.Sprint(value...), r)
|
||||
}
|
||||
@ -6,7 +6,14 @@
|
||||
|
||||
package ghttp
|
||||
|
||||
import "reflect"
|
||||
import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
|
||||
"github.com/gogf/gf/errors/gerror"
|
||||
|
||||
"github.com/gogf/gf/util/gutil"
|
||||
)
|
||||
|
||||
// 中间件对象
|
||||
type Middleware struct {
|
||||
@ -17,7 +24,8 @@ type Middleware struct {
|
||||
// 执行下一个请求流程处理函数
|
||||
func (m *Middleware) Next() {
|
||||
item := (*handlerParsedItem)(nil)
|
||||
for {
|
||||
loop := true
|
||||
for loop {
|
||||
// 是否停止请求执行
|
||||
if m.request.IsExited() || m.request.handlerIndex >= len(m.request.handlers) {
|
||||
return
|
||||
@ -34,49 +42,69 @@ func (m *Middleware) Next() {
|
||||
}
|
||||
m.request.Router = item.handler.router
|
||||
// 执行函数处理
|
||||
switch item.handler.itemType {
|
||||
case gHANDLER_TYPE_CONTROLLER:
|
||||
m.served = true
|
||||
c := reflect.New(item.handler.ctrlInfo.reflect)
|
||||
niceCallFunc(func() {
|
||||
c.MethodByName("Init").Call([]reflect.Value{reflect.ValueOf(m.request)})
|
||||
})
|
||||
if !m.request.IsExited() {
|
||||
gutil.TryCatch(func() {
|
||||
switch item.handler.itemType {
|
||||
case gHANDLER_TYPE_CONTROLLER:
|
||||
m.served = true
|
||||
if m.request.IsExited() {
|
||||
break
|
||||
}
|
||||
c := reflect.New(item.handler.ctrlInfo.reflect)
|
||||
niceCallFunc(func() {
|
||||
c.MethodByName(item.handler.ctrlInfo.name).Call(nil)
|
||||
c.MethodByName("Init").Call([]reflect.Value{reflect.ValueOf(m.request)})
|
||||
})
|
||||
}
|
||||
if !m.request.IsExited() {
|
||||
niceCallFunc(func() {
|
||||
c.MethodByName("Shut").Call(nil)
|
||||
})
|
||||
}
|
||||
case gHANDLER_TYPE_OBJECT:
|
||||
m.served = true
|
||||
if item.handler.initFunc != nil {
|
||||
niceCallFunc(func() {
|
||||
item.handler.initFunc(m.request)
|
||||
})
|
||||
}
|
||||
if !m.request.IsExited() {
|
||||
if !m.request.IsExited() {
|
||||
niceCallFunc(func() {
|
||||
c.MethodByName(item.handler.ctrlInfo.name).Call(nil)
|
||||
})
|
||||
}
|
||||
if !m.request.IsExited() {
|
||||
niceCallFunc(func() {
|
||||
c.MethodByName("Shut").Call(nil)
|
||||
})
|
||||
}
|
||||
|
||||
case gHANDLER_TYPE_OBJECT:
|
||||
m.served = true
|
||||
if m.request.IsExited() {
|
||||
break
|
||||
}
|
||||
if item.handler.initFunc != nil {
|
||||
niceCallFunc(func() {
|
||||
item.handler.initFunc(m.request)
|
||||
})
|
||||
}
|
||||
if !m.request.IsExited() {
|
||||
niceCallFunc(func() {
|
||||
item.handler.itemFunc(m.request)
|
||||
})
|
||||
}
|
||||
if !m.request.IsExited() && item.handler.shutFunc != nil {
|
||||
niceCallFunc(func() {
|
||||
item.handler.shutFunc(m.request)
|
||||
})
|
||||
}
|
||||
|
||||
case gHANDLER_TYPE_HANDLER:
|
||||
m.served = true
|
||||
if m.request.IsExited() {
|
||||
break
|
||||
}
|
||||
niceCallFunc(func() {
|
||||
item.handler.itemFunc(m.request)
|
||||
})
|
||||
}
|
||||
if !m.request.IsExited() && item.handler.shutFunc != nil {
|
||||
|
||||
case gHANDLER_TYPE_MIDDLEWARE:
|
||||
niceCallFunc(func() {
|
||||
item.handler.shutFunc(m.request)
|
||||
item.handler.itemFunc(m.request)
|
||||
})
|
||||
// 中间件默认不会进一步执行,
|
||||
// 需要内部调用Next方法决定是否进一步执行,以便于请求流程控制。
|
||||
loop = false
|
||||
}
|
||||
case gHANDLER_TYPE_HANDLER:
|
||||
m.served = true
|
||||
niceCallFunc(func() {
|
||||
item.handler.itemFunc(m.request)
|
||||
})
|
||||
case gHANDLER_TYPE_MIDDLEWARE:
|
||||
niceCallFunc(func() {
|
||||
item.handler.itemFunc(m.request)
|
||||
})
|
||||
}
|
||||
}, func(exception interface{}) {
|
||||
m.request.error = gerror.Newf("%v", exception)
|
||||
m.request.Response.WriteStatus(http.StatusInternalServerError, exception)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -47,6 +47,9 @@ func (r *Response) Write(content ...interface{}) {
|
||||
if len(content) == 0 {
|
||||
return
|
||||
}
|
||||
if r.Status == 0 {
|
||||
r.Status = http.StatusOK
|
||||
}
|
||||
for _, v := range content {
|
||||
switch value := v.(type) {
|
||||
case []byte:
|
||||
|
||||
@ -34,7 +34,7 @@ const (
|
||||
)
|
||||
|
||||
// 自定义日志处理方法类型
|
||||
type LogHandler func(r *Request, error ...interface{})
|
||||
type LogHandler func(r *Request, err ...error)
|
||||
|
||||
// HTTP Server 设置结构体,静态配置
|
||||
type ServerConfig struct {
|
||||
@ -70,6 +70,7 @@ type ServerConfig struct {
|
||||
LogPath string // Logging: 存放日志的目录路径(默认为空,表示不写文件)
|
||||
LogHandler LogHandler // Logging: 日志配置: 自定义日志处理回调方法(默认为空)
|
||||
LogStdout bool // Logging: 是否打印日志到终端(默认开启)
|
||||
ErrorStack bool // Logging: 当产生错误时打印调用链详细堆栈
|
||||
ErrorLogEnabled bool // Logging: 是否开启error log(默认开启)
|
||||
AccessLogEnabled bool // Logging: 是否开启access log(默认关闭)
|
||||
NameToUriType int // Mess: 服务注册时对象和方法名称转换为URI时的规则
|
||||
@ -100,6 +101,7 @@ var defaultServerConfig = ServerConfig{
|
||||
SessionMaxAge: time.Hour * 24,
|
||||
SessionIdName: "gfsessionid",
|
||||
LogStdout: true,
|
||||
ErrorStack: true,
|
||||
ErrorLogEnabled: true,
|
||||
AccessLogEnabled: false,
|
||||
DumpRouteMap: true,
|
||||
|
||||
@ -54,6 +54,15 @@ func (s *Server) SetErrorLogEnabled(enabled bool) {
|
||||
s.config.ErrorLogEnabled = enabled
|
||||
}
|
||||
|
||||
// 设置是否开启error stack打印功能
|
||||
func (s *Server) SetErrorStack(enabled bool) {
|
||||
if s.Status() == SERVER_STATUS_RUNNING {
|
||||
glog.Error(gCHANGE_CONFIG_WHILE_RUNNING_ERROR)
|
||||
return
|
||||
}
|
||||
s.config.ErrorStack = enabled
|
||||
}
|
||||
|
||||
// 设置日志写入的回调函数
|
||||
func (s *Server) SetLogHandler(handler LogHandler) {
|
||||
if s.Status() == SERVER_STATUS_RUNNING {
|
||||
@ -12,6 +12,8 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/errors/gerror"
|
||||
|
||||
"github.com/gogf/gf/os/gres"
|
||||
|
||||
"github.com/gogf/gf/encoding/ghtml"
|
||||
@ -70,11 +72,17 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
request.Response.WriteStatus(http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
// error log
|
||||
if e := recover(); e != nil {
|
||||
request.Response.WriteStatus(http.StatusInternalServerError)
|
||||
s.handleErrorLog(e, request)
|
||||
if request.error != nil {
|
||||
s.handleErrorLog(request.error, request)
|
||||
} else {
|
||||
if exception := recover(); exception != nil {
|
||||
request.Response.WriteStatus(http.StatusInternalServerError)
|
||||
s.handleErrorLog(gerror.Newf("%v", exception), request)
|
||||
}
|
||||
}
|
||||
|
||||
// access log
|
||||
s.handleAccessLog(request)
|
||||
}()
|
||||
|
||||
@ -9,6 +9,8 @@ package ghttp
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/errors/gerror"
|
||||
|
||||
"github.com/gogf/gf/os/gtime"
|
||||
)
|
||||
|
||||
@ -40,7 +42,7 @@ func (s *Server) handleAccessLog(r *Request) {
|
||||
}
|
||||
|
||||
// 处理服务错误信息,主要是panic,http请求的status由access log进行管理
|
||||
func (s *Server) handleErrorLog(error interface{}, r *Request) {
|
||||
func (s *Server) handleErrorLog(err error, r *Request) {
|
||||
// 错误输出默认是开启的
|
||||
if !s.IsErrorLogEnabled() {
|
||||
return
|
||||
@ -48,7 +50,7 @@ func (s *Server) handleErrorLog(error interface{}, r *Request) {
|
||||
|
||||
// 自定义错误处理
|
||||
if v := s.GetLogHandler(); v != nil {
|
||||
v(r, error)
|
||||
v(r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -57,12 +59,17 @@ func (s *Server) handleErrorLog(error interface{}, r *Request) {
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
content := fmt.Sprintf(`%v, "%s %s %s %s %s"`, error, r.Method, scheme, r.Host, r.URL.String(), r.Proto)
|
||||
content := fmt.Sprintf(`%v, "%s %s %s %s %s"`, err, r.Method, scheme, r.Host, r.URL.String(), r.Proto)
|
||||
if r.LeaveTime > r.EnterTime {
|
||||
content += fmt.Sprintf(` %.3f`, float64(r.LeaveTime-r.EnterTime)/1000)
|
||||
} else {
|
||||
content += fmt.Sprintf(` %.3f`, float64(gtime.Microsecond()-r.EnterTime)/1000)
|
||||
}
|
||||
content += fmt.Sprintf(`, %s, "%s", "%s"`, r.GetClientIp(), r.Referer(), r.UserAgent())
|
||||
s.logger.Cat("error").StackWithFilter(gPATH_FILTER_KEY).Stdout(s.config.LogStdout).Error(content)
|
||||
if s.config.ErrorStack {
|
||||
if stack := gerror.Stack(err); stack != "" {
|
||||
content += "\n" + stack
|
||||
}
|
||||
}
|
||||
s.logger.Cat("error").Stack(false).Stdout(s.config.LogStdout).Error(content)
|
||||
}
|
||||
|
||||
@ -56,6 +56,10 @@ func (s *Server) handlePreBindItems() {
|
||||
|
||||
// 获取分组路由对象
|
||||
func (s *Server) Group(prefix string, groups ...func(g *RouterGroup)) *RouterGroup {
|
||||
// 自动识别并加上/前缀
|
||||
if prefix[0] != '/' {
|
||||
prefix = "/" + prefix
|
||||
}
|
||||
if prefix == "/" {
|
||||
prefix = ""
|
||||
}
|
||||
|
||||
@ -8,6 +8,7 @@ package ghttp_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -409,3 +410,42 @@ func Test_Hook_Middleware_Basic1(t *testing.T) {
|
||||
gtest.Assert(client.GetContent("/test/test"), "ac13test42bd")
|
||||
})
|
||||
}
|
||||
|
||||
func MiddlewareAuth(r *ghttp.Request) {
|
||||
token := r.Get("token")
|
||||
if token == "123456" {
|
||||
r.Middleware.Next()
|
||||
} else {
|
||||
r.Response.WriteStatus(http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func MiddlewareCORS(r *ghttp.Request) {
|
||||
r.Response.CORSDefault()
|
||||
r.Middleware.Next()
|
||||
}
|
||||
|
||||
func Test_Middleware_CORSAndAuth(t *testing.T) {
|
||||
p := ports.PopRand()
|
||||
s := g.Server(p)
|
||||
s.Group("/api.v2", func(g *ghttp.RouterGroup) {
|
||||
g.Middleware(MiddlewareAuth, MiddlewareCORS)
|
||||
g.ALL("/user/list", func(r *ghttp.Request) {
|
||||
r.Response.Write("list")
|
||||
})
|
||||
})
|
||||
s.SetPort(p)
|
||||
s.SetDumpRouteMap(false)
|
||||
s.Start()
|
||||
defer s.Shutdown()
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
gtest.Case(t, func() {
|
||||
client := ghttp.NewClient()
|
||||
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
|
||||
|
||||
gtest.Assert(client.GetContent("/"), "Not Found")
|
||||
gtest.Assert(client.GetContent("/api.v2"), "Not Found")
|
||||
gtest.Assert(client.GetContent("/api.v2/user/list"), "Forbidden")
|
||||
gtest.Assert(client.GetContent("/api.v2/user/list", "token=123456"), "list")
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user