diff --git a/.example/net/ghttp/server/middleware/auth.go b/.example/net/ghttp/server/middleware/auth.go new file mode 100644 index 000000000..79b325aa9 --- /dev/null +++ b/.example/net/ghttp/server/middleware/auth.go @@ -0,0 +1,34 @@ +package main + +import ( + "net/http" + + "github.com/gogf/gf/frame/g" + "github.com/gogf/gf/net/ghttp" +) + +func MiddlewareAuth(r *ghttp.Request) { + token := r.Get("token") + if token == "123456" { + r.Middleware.Next() + } else { + r.Response.WriteStatus(http.StatusForbidden) + } +} + +func MiddlewareCORS(r *ghttp.Request) { + r.Response.CORSDefault() + r.Middleware.Next() +} + +func main() { + s := g.Server() + s.Group("/api.v2", func(g *ghttp.RouterGroup) { + g.Middleware(MiddlewareAuth, MiddlewareCORS) + g.ALL("/user/list", func(r *ghttp.Request) { + r.Response.Write("list") + }) + }) + s.SetPort(8199) + s.Run() +} diff --git a/.example/net/ghttp/server/middleware/cors.go b/.example/net/ghttp/server/middleware/cors.go new file mode 100644 index 000000000..6ddcd3a0d --- /dev/null +++ b/.example/net/ghttp/server/middleware/cors.go @@ -0,0 +1,23 @@ +package main + +import ( + "github.com/gogf/gf/frame/g" + "github.com/gogf/gf/net/ghttp" +) + +func MiddlewareCORS(r *ghttp.Request) { + r.Response.CORSDefault() + r.Middleware.Next() +} + +func main() { + s := g.Server() + s.Group("/api.v2", func(g *ghttp.RouterGroup) { + g.Middleware(MiddlewareCORS) + g.ALL("/user/list", func(r *ghttp.Request) { + r.Response.Write("list") + }) + }) + s.SetPort(8199) + s.Run() +} diff --git a/.example/net/ghttp/server/middleware/log.go b/.example/net/ghttp/server/middleware/log.go new file mode 100644 index 000000000..607962a4f --- /dev/null +++ b/.example/net/ghttp/server/middleware/log.go @@ -0,0 +1,44 @@ +package main + +import ( + "net/http" + + "github.com/gogf/gf/os/glog" + + "github.com/gogf/gf/frame/g" + "github.com/gogf/gf/net/ghttp" +) + +func MiddlewareAuth(r *ghttp.Request) { + token := r.Get("token") + if token == "123456" { + r.Middleware.Next() + } else { + r.Response.WriteStatus(http.StatusForbidden) + } +} + +func MiddlewareCORS(r *ghttp.Request) { + r.Response.CORSDefault() + r.Middleware.Next() +} + +func MiddlewareLog(r *ghttp.Request) { + r.Middleware.Next() + glog.Println(r.Response.Status, r.URL.Path) +} + +func main() { + s := g.Server() + s.Group("/", func(g *ghttp.RouterGroup) { + g.Middleware(MiddlewareLog) + }) + s.Group("/api.v2", func(g *ghttp.RouterGroup) { + g.Middleware(MiddlewareAuth, MiddlewareCORS) + g.ALL("/user/list", func(r *ghttp.Request) { + panic("custom error") + }) + }) + s.SetPort(8199) + s.Run() +} diff --git a/RELEASE.2.MD b/RELEASE.2.MD index faecd8f8e..3eb7d8c76 100644 --- a/RELEASE.2.MD +++ b/RELEASE.2.MD @@ -40,6 +40,7 @@ ## 功能改进 1. `ghttp` + - 当`WebServer`产生`panic`异常错误时,默认打印调用链堆栈到错误日志中; - `Cookie`及`Session`的`TTL`配置数据类型修改为`time.Duration`; - 新增允许同时通过`Header/Cookie`传递`SessionId`; - 新增`ConfigFromMap/SetConfigWithMap`方法,支持通过`map`参数设置WebServer; diff --git a/net/ghttp/ghttp_request.go b/net/ghttp/ghttp_request.go index 627411bbb..55b041d3a 100644 --- a/net/ghttp/ghttp_request.go +++ b/net/ghttp/ghttp_request.go @@ -40,6 +40,7 @@ type Request struct { parsedPost bool // POST参数是否已经解析 queryVars map[string][]string // GET参数 routerVars map[string][]string // 路由解析参数 + error error // 当前请求执行错误 exit bool // 是否退出当前请求流程执行 params map[string]interface{} // 开发者自定义参数(请求流程中有效) parsedHost string // 解析过后不带端口号的服务器域名名称 @@ -92,12 +93,8 @@ func (r *Request) GetVar(key string, def ...interface{}) *gvar.Var { // 获取原始请求输入二进制。 func (r *Request) GetRaw() []byte { - err := error(nil) if r.rawContent == nil { - r.rawContent, err = ioutil.ReadAll(r.Body) - if err != nil { - r.Error("error reading request body: ", err) - } + r.rawContent, _ = ioutil.ReadAll(r.Body) } return r.rawContent } diff --git a/net/ghttp/ghttp_request_log.go b/net/ghttp/ghttp_request_log.go deleted file mode 100644 index 7b1e7ea2f..000000000 --- a/net/ghttp/ghttp_request_log.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved. -// -// This Source Code Form is subject to the terms of the MIT License. -// If a copy of the MIT was not distributed with this file, -// You can obtain one at https://github.com/gogf/gf. - -package ghttp - -import "fmt" - -// 打印error日志 -func (r *Request) Error(value ...interface{}) { - r.Server.handleErrorLog(fmt.Sprint(value...), r) -} diff --git a/net/ghttp/ghttp_request_middleware.go b/net/ghttp/ghttp_request_middleware.go index f3778a546..746ae690b 100644 --- a/net/ghttp/ghttp_request_middleware.go +++ b/net/ghttp/ghttp_request_middleware.go @@ -6,7 +6,14 @@ package ghttp -import "reflect" +import ( + "net/http" + "reflect" + + "github.com/gogf/gf/errors/gerror" + + "github.com/gogf/gf/util/gutil" +) // 中间件对象 type Middleware struct { @@ -17,7 +24,8 @@ type Middleware struct { // 执行下一个请求流程处理函数 func (m *Middleware) Next() { item := (*handlerParsedItem)(nil) - for { + loop := true + for loop { // 是否停止请求执行 if m.request.IsExited() || m.request.handlerIndex >= len(m.request.handlers) { return @@ -34,49 +42,69 @@ func (m *Middleware) Next() { } m.request.Router = item.handler.router // 执行函数处理 - switch item.handler.itemType { - case gHANDLER_TYPE_CONTROLLER: - m.served = true - c := reflect.New(item.handler.ctrlInfo.reflect) - niceCallFunc(func() { - c.MethodByName("Init").Call([]reflect.Value{reflect.ValueOf(m.request)}) - }) - if !m.request.IsExited() { + gutil.TryCatch(func() { + switch item.handler.itemType { + case gHANDLER_TYPE_CONTROLLER: + m.served = true + if m.request.IsExited() { + break + } + c := reflect.New(item.handler.ctrlInfo.reflect) niceCallFunc(func() { - c.MethodByName(item.handler.ctrlInfo.name).Call(nil) + c.MethodByName("Init").Call([]reflect.Value{reflect.ValueOf(m.request)}) }) - } - if !m.request.IsExited() { - niceCallFunc(func() { - c.MethodByName("Shut").Call(nil) - }) - } - case gHANDLER_TYPE_OBJECT: - m.served = true - if item.handler.initFunc != nil { - niceCallFunc(func() { - item.handler.initFunc(m.request) - }) - } - if !m.request.IsExited() { + if !m.request.IsExited() { + niceCallFunc(func() { + c.MethodByName(item.handler.ctrlInfo.name).Call(nil) + }) + } + if !m.request.IsExited() { + niceCallFunc(func() { + c.MethodByName("Shut").Call(nil) + }) + } + + case gHANDLER_TYPE_OBJECT: + m.served = true + if m.request.IsExited() { + break + } + if item.handler.initFunc != nil { + niceCallFunc(func() { + item.handler.initFunc(m.request) + }) + } + if !m.request.IsExited() { + niceCallFunc(func() { + item.handler.itemFunc(m.request) + }) + } + if !m.request.IsExited() && item.handler.shutFunc != nil { + niceCallFunc(func() { + item.handler.shutFunc(m.request) + }) + } + + case gHANDLER_TYPE_HANDLER: + m.served = true + if m.request.IsExited() { + break + } niceCallFunc(func() { item.handler.itemFunc(m.request) }) - } - if !m.request.IsExited() && item.handler.shutFunc != nil { + + case gHANDLER_TYPE_MIDDLEWARE: niceCallFunc(func() { - item.handler.shutFunc(m.request) + item.handler.itemFunc(m.request) }) + // 中间件默认不会进一步执行, + // 需要内部调用Next方法决定是否进一步执行,以便于请求流程控制。 + loop = false } - case gHANDLER_TYPE_HANDLER: - m.served = true - niceCallFunc(func() { - item.handler.itemFunc(m.request) - }) - case gHANDLER_TYPE_MIDDLEWARE: - niceCallFunc(func() { - item.handler.itemFunc(m.request) - }) - } + }, func(exception interface{}) { + m.request.error = gerror.Newf("%v", exception) + m.request.Response.WriteStatus(http.StatusInternalServerError, exception) + }) } } diff --git a/net/ghttp/ghttp_response.go b/net/ghttp/ghttp_response.go index 1c6cd3f29..c37f823bf 100644 --- a/net/ghttp/ghttp_response.go +++ b/net/ghttp/ghttp_response.go @@ -47,6 +47,9 @@ func (r *Response) Write(content ...interface{}) { if len(content) == 0 { return } + if r.Status == 0 { + r.Status = http.StatusOK + } for _, v := range content { switch value := v.(type) { case []byte: diff --git a/net/ghttp/ghttp_server_config.go b/net/ghttp/ghttp_server_config.go index 4c72a9d65..5d6fe7116 100644 --- a/net/ghttp/ghttp_server_config.go +++ b/net/ghttp/ghttp_server_config.go @@ -34,7 +34,7 @@ const ( ) // 自定义日志处理方法类型 -type LogHandler func(r *Request, error ...interface{}) +type LogHandler func(r *Request, err ...error) // HTTP Server 设置结构体,静态配置 type ServerConfig struct { @@ -70,6 +70,7 @@ type ServerConfig struct { LogPath string // Logging: 存放日志的目录路径(默认为空,表示不写文件) LogHandler LogHandler // Logging: 日志配置: 自定义日志处理回调方法(默认为空) LogStdout bool // Logging: 是否打印日志到终端(默认开启) + ErrorStack bool // Logging: 当产生错误时打印调用链详细堆栈 ErrorLogEnabled bool // Logging: 是否开启error log(默认开启) AccessLogEnabled bool // Logging: 是否开启access log(默认关闭) NameToUriType int // Mess: 服务注册时对象和方法名称转换为URI时的规则 @@ -100,6 +101,7 @@ var defaultServerConfig = ServerConfig{ SessionMaxAge: time.Hour * 24, SessionIdName: "gfsessionid", LogStdout: true, + ErrorStack: true, ErrorLogEnabled: true, AccessLogEnabled: false, DumpRouteMap: true, diff --git a/net/ghttp/ghttp_server_config_logger.go b/net/ghttp/ghttp_server_config_logging.go similarity index 91% rename from net/ghttp/ghttp_server_config_logger.go rename to net/ghttp/ghttp_server_config_logging.go index a558f869e..ab24e97ca 100644 --- a/net/ghttp/ghttp_server_config_logger.go +++ b/net/ghttp/ghttp_server_config_logging.go @@ -54,6 +54,15 @@ func (s *Server) SetErrorLogEnabled(enabled bool) { s.config.ErrorLogEnabled = enabled } +// 设置是否开启error stack打印功能 +func (s *Server) SetErrorStack(enabled bool) { + if s.Status() == SERVER_STATUS_RUNNING { + glog.Error(gCHANGE_CONFIG_WHILE_RUNNING_ERROR) + return + } + s.config.ErrorStack = enabled +} + // 设置日志写入的回调函数 func (s *Server) SetLogHandler(handler LogHandler) { if s.Status() == SERVER_STATUS_RUNNING { diff --git a/net/ghttp/ghttp_server_handler.go b/net/ghttp/ghttp_server_handler.go index 8287ad7d5..184ef1395 100644 --- a/net/ghttp/ghttp_server_handler.go +++ b/net/ghttp/ghttp_server_handler.go @@ -12,6 +12,8 @@ import ( "sort" "strings" + "github.com/gogf/gf/errors/gerror" + "github.com/gogf/gf/os/gres" "github.com/gogf/gf/encoding/ghtml" @@ -70,11 +72,17 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { request.Response.WriteStatus(http.StatusNotFound) } } + // error log - if e := recover(); e != nil { - request.Response.WriteStatus(http.StatusInternalServerError) - s.handleErrorLog(e, request) + if request.error != nil { + s.handleErrorLog(request.error, request) + } else { + if exception := recover(); exception != nil { + request.Response.WriteStatus(http.StatusInternalServerError) + s.handleErrorLog(gerror.Newf("%v", exception), request) + } } + // access log s.handleAccessLog(request) }() diff --git a/net/ghttp/ghttp_server_log.go b/net/ghttp/ghttp_server_log.go index cebb1a835..d1cd414d7 100644 --- a/net/ghttp/ghttp_server_log.go +++ b/net/ghttp/ghttp_server_log.go @@ -9,6 +9,8 @@ package ghttp import ( "fmt" + "github.com/gogf/gf/errors/gerror" + "github.com/gogf/gf/os/gtime" ) @@ -40,7 +42,7 @@ func (s *Server) handleAccessLog(r *Request) { } // 处理服务错误信息,主要是panic,http请求的status由access log进行管理 -func (s *Server) handleErrorLog(error interface{}, r *Request) { +func (s *Server) handleErrorLog(err error, r *Request) { // 错误输出默认是开启的 if !s.IsErrorLogEnabled() { return @@ -48,7 +50,7 @@ func (s *Server) handleErrorLog(error interface{}, r *Request) { // 自定义错误处理 if v := s.GetLogHandler(); v != nil { - v(r, error) + v(r, err) return } @@ -57,12 +59,17 @@ func (s *Server) handleErrorLog(error interface{}, r *Request) { if r.TLS != nil { scheme = "https" } - content := fmt.Sprintf(`%v, "%s %s %s %s %s"`, error, r.Method, scheme, r.Host, r.URL.String(), r.Proto) + content := fmt.Sprintf(`%v, "%s %s %s %s %s"`, err, r.Method, scheme, r.Host, r.URL.String(), r.Proto) if r.LeaveTime > r.EnterTime { content += fmt.Sprintf(` %.3f`, float64(r.LeaveTime-r.EnterTime)/1000) } else { content += fmt.Sprintf(` %.3f`, float64(gtime.Microsecond()-r.EnterTime)/1000) } content += fmt.Sprintf(`, %s, "%s", "%s"`, r.GetClientIp(), r.Referer(), r.UserAgent()) - s.logger.Cat("error").StackWithFilter(gPATH_FILTER_KEY).Stdout(s.config.LogStdout).Error(content) + if s.config.ErrorStack { + if stack := gerror.Stack(err); stack != "" { + content += "\n" + stack + } + } + s.logger.Cat("error").Stack(false).Stdout(s.config.LogStdout).Error(content) } diff --git a/net/ghttp/ghttp_server_router_group.go b/net/ghttp/ghttp_server_router_group.go index aa4f93759..976a0b91d 100644 --- a/net/ghttp/ghttp_server_router_group.go +++ b/net/ghttp/ghttp_server_router_group.go @@ -56,6 +56,10 @@ func (s *Server) handlePreBindItems() { // 获取分组路由对象 func (s *Server) Group(prefix string, groups ...func(g *RouterGroup)) *RouterGroup { + // 自动识别并加上/前缀 + if prefix[0] != '/' { + prefix = "/" + prefix + } if prefix == "/" { prefix = "" } diff --git a/net/ghttp/ghttp_unit_middleware_test.go b/net/ghttp/ghttp_unit_middleware_test.go index 3f28be8e8..34618f49b 100644 --- a/net/ghttp/ghttp_unit_middleware_test.go +++ b/net/ghttp/ghttp_unit_middleware_test.go @@ -8,6 +8,7 @@ package ghttp_test import ( "fmt" + "net/http" "testing" "time" @@ -409,3 +410,42 @@ func Test_Hook_Middleware_Basic1(t *testing.T) { gtest.Assert(client.GetContent("/test/test"), "ac13test42bd") }) } + +func MiddlewareAuth(r *ghttp.Request) { + token := r.Get("token") + if token == "123456" { + r.Middleware.Next() + } else { + r.Response.WriteStatus(http.StatusForbidden) + } +} + +func MiddlewareCORS(r *ghttp.Request) { + r.Response.CORSDefault() + r.Middleware.Next() +} + +func Test_Middleware_CORSAndAuth(t *testing.T) { + p := ports.PopRand() + s := g.Server(p) + s.Group("/api.v2", func(g *ghttp.RouterGroup) { + g.Middleware(MiddlewareAuth, MiddlewareCORS) + g.ALL("/user/list", func(r *ghttp.Request) { + r.Response.Write("list") + }) + }) + s.SetPort(p) + s.SetDumpRouteMap(false) + s.Start() + defer s.Shutdown() + time.Sleep(200 * time.Millisecond) + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + + gtest.Assert(client.GetContent("/"), "Not Found") + gtest.Assert(client.GetContent("/api.v2"), "Not Found") + gtest.Assert(client.GetContent("/api.v2/user/list"), "Forbidden") + gtest.Assert(client.GetContent("/api.v2/user/list", "token=123456"), "list") + }) +}