From a06ca315308012c1913aed557ad0aee7f2dcb31f Mon Sep 17 00:00:00 2001 From: John Date: Wed, 4 Dec 2019 10:03:03 +0800 Subject: [PATCH] improve middleware feature for ghttp.Server --- debug/gdebug/gdebug.go | 19 +++ net/ghttp/ghttp_request.go | 3 +- net/ghttp/ghttp_request_middleware.go | 44 +++++-- net/ghttp/ghttp_server.go | 42 ++++--- net/ghttp/ghttp_server_domain.go | 42 +++++++ net/ghttp/ghttp_server_router_group.go | 111 ++++++++---------- net/ghttp/ghttp_server_router_hook.go | 5 +- net/ghttp/ghttp_server_router_middleware.go | 7 +- net/ghttp/ghttp_server_router_serve.go | 2 +- net/ghttp/ghttp_server_service_controller.go | 58 ++++++--- net/ghttp/ghttp_server_service_handler.go | 31 ++--- net/ghttp/ghttp_server_service_object.go | 93 +++++++++------ net/ghttp/ghttp_unit_middleware_test.go | 14 +-- .../ghttp_unit_router_group_group_test.go | 14 +-- 14 files changed, 300 insertions(+), 185 deletions(-) diff --git a/debug/gdebug/gdebug.go b/debug/gdebug/gdebug.go index 5961f3aa6..6809ccfa2 100644 --- a/debug/gdebug/gdebug.go +++ b/debug/gdebug/gdebug.go @@ -11,6 +11,7 @@ import ( "bytes" "fmt" "path/filepath" + "reflect" "runtime" "strconv" "strings" @@ -244,3 +245,21 @@ func CallerFileLineShort() string { _, path, line := Caller() return fmt.Sprintf(`%s:%d`, filepath.Base(path), line) } + +// FuncPath returns the complete function path of given . +func FuncPath(f interface{}) string { + return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name() +} + +// FuncName returns the function name of given . +func FuncName(f interface{}) string { + path := FuncPath(f) + if path == "" { + return "" + } + index := strings.LastIndexByte(path, '/') + if index < 0 { + index = strings.LastIndexByte(path, '\\') + } + return path[index+1:] +} diff --git a/net/ghttp/ghttp_request.go b/net/ghttp/ghttp_request.go index b14d5c54b..c64db0301 100644 --- a/net/ghttp/ghttp_request.go +++ b/net/ghttp/ghttp_request.go @@ -30,7 +30,6 @@ type Request struct { LeaveTime int64 // Request ending time in microseconds. Middleware *Middleware // The middleware manager. handlers []*handlerParsedItem // All matched handlers containing handler, hook and middleware for this request . - handlerIndex int // Index number for executing sequence purpose of handlers. hasHookHandler bool // A bool marking whether there's hook handler in the handlers for performance purpose. hasServeHandler bool // A bool marking whether there's serving handler in the handlers for performance purpose. parsedQuery bool // A bool marking whether the GET parameters parsed. @@ -125,7 +124,7 @@ func (r *Request) IsAjaxRequest() bool { return strings.EqualFold(r.Header.Get("X-Requested-With"), "XMLHttpRequest") } -// GetClientIp returns the client ip of this request. +// GetClientIp returns the client ip of this request without port. func (r *Request) GetClientIp() string { if len(r.clientIp) == 0 { if r.clientIp = r.Header.Get("X-Real-IP"); r.clientIp == "" { diff --git a/net/ghttp/ghttp_request_middleware.go b/net/ghttp/ghttp_request_middleware.go index b32e7603c..6bb704ce7 100644 --- a/net/ghttp/ghttp_request_middleware.go +++ b/net/ghttp/ghttp_request_middleware.go @@ -17,23 +17,25 @@ import ( // Middleware is the plugin for request workflow management. type Middleware struct { - served bool // Is the request served, which is used for checking response status 404. - request *Request // The request object pointer. + served bool // Is the request served, which is used for checking response status 404. + request *Request // The request object pointer. + handlerIndex int // Index number for executing sequence purpose for handler items. + handlerMDIndex int // Index number for executing sequence purpose for bound middleware of handler item. } // Next calls the next workflow handler. func (m *Middleware) Next() { - item := (*handlerParsedItem)(nil) - loop := true + var item *handlerParsedItem + var loop = true for loop { // Check whether the request is exited. - if m.request.IsExited() || m.request.handlerIndex >= len(m.request.handlers) { + if m.request.IsExited() || m.handlerIndex >= len(m.request.handlers) { break } - item = m.request.handlers[m.request.handlerIndex] - m.request.handlerIndex++ + item = m.request.handlers[m.handlerIndex] // Filter the HOOK handlers, which are designed to be called in another standalone procedure. if item.handler.itemType == gHANDLER_TYPE_HOOK { + m.handlerIndex++ continue } // Router values switching. @@ -42,7 +44,20 @@ func (m *Middleware) Next() { m.request.Router = item.handler.router gutil.TryCatch(func() { + // Execute bound middleware array of the item if it's not empty. + if m.handlerMDIndex < len(item.handler.middleware) { + md := item.handler.middleware[m.handlerMDIndex] + m.handlerMDIndex++ + niceCallFunc(func() { + md(m.request) + }) + loop = false + return + } + m.handlerIndex++ + switch item.handler.itemType { + // Service controller. case gHANDLER_TYPE_CONTROLLER: m.served = true if m.request.IsExited() { @@ -63,6 +78,7 @@ func (m *Middleware) Next() { }) } + // Service object. case gHANDLER_TYPE_OBJECT: m.served = true if m.request.IsExited() { @@ -84,6 +100,7 @@ func (m *Middleware) Next() { }) } + // Service handler. case gHANDLER_TYPE_HANDLER: m.served = true if m.request.IsExited() { @@ -93,6 +110,7 @@ func (m *Middleware) Next() { item.handler.itemFunc(m.request) }) + // Global middleware array. case gHANDLER_TYPE_MIDDLEWARE: niceCallFunc(func() { item.handler.itemFunc(m.request) @@ -107,11 +125,13 @@ func (m *Middleware) Next() { }) } // Check the http status code after all handler and middleware done. - if m.request.Response.Status == 0 { - if m.request.Middleware.served || m.request.Response.buffer.Len() > 0 { - m.request.Response.WriteHeader(http.StatusOK) - } else { - m.request.Response.WriteHeader(http.StatusNotFound) + if m.request.IsExited() || m.handlerIndex >= len(m.request.handlers) { + if m.request.Response.Status == 0 { + if m.request.Middleware.served { + m.request.Response.WriteHeader(http.StatusOK) + } else { + m.request.Response.WriteHeader(http.StatusNotFound) + } } } } diff --git a/net/ghttp/ghttp_server.go b/net/ghttp/ghttp_server.go index 4c0ce8326..28f59cecb 100644 --- a/net/ghttp/ghttp_server.go +++ b/net/ghttp/ghttp_server.go @@ -10,6 +10,7 @@ import ( "bytes" "errors" "fmt" + "github.com/gogf/gf/debug/gdebug" "net/http" "os" "reflect" @@ -62,15 +63,16 @@ type ( // 路由函数注册信息 handlerItem struct { - itemId int // 用于标识该注册函数的唯一性ID - itemName string // 注册的函数名称信息(用于路由信息打印) - itemType int // 注册函数类型(对象/函数/控制器/中间件/钩子函数) - itemFunc HandlerFunc // 函数内存地址(与以上两个参数二选一) - initFunc HandlerFunc // 初始化请求回调函数(对象注册方式下有效) - shutFunc HandlerFunc // 完成请求回调函数(对象注册方式下有效) - ctrlInfo *handlerController // 控制器服务函数反射信息 - hookName string // 钩子类型名称(注册函数类型为钩子函数下有效) - router *Router // 注册时绑定的路由对象 + itemId int // 用于标识该注册函数的唯一性ID + itemName string // 注册的函数名称信息(用于路由信息打印) + itemType int // 注册函数类型(对象/函数/控制器/中间件/钩子函数) + itemFunc HandlerFunc // 函数内存地址(与以上两个参数二选一) + initFunc HandlerFunc // 初始化请求回调函数(对象注册方式下有效) + shutFunc HandlerFunc // 完成请求回调函数(对象注册方式下有效) + middleware []HandlerFunc // 绑定的中间件列表 + ctrlInfo *handlerController // 控制器服务函数反射信息 + hookName string // 钩子类型名称(注册函数类型为钩子函数下有效) + router *Router // 注册时绑定的路由对象 } // 根据特定URL.Path解析后的路由检索结果项 @@ -317,12 +319,13 @@ func (s *Server) Start() error { // 打印展示路由表 func (s *Server) DumpRoutesMap() { if s.config.DumpRouteMap && len(s.routesMap) > 0 { - glog.Header(false).Println(fmt.Sprintf("\n%s", s.GetRouteMap())) + glog.Header(false).Println(fmt.Sprintf("\n%s", s.getRouteMapString())) } } // 获得路由表(格式化字符串) -func (s *Server) GetRouteMap() string { +func (s *Server) getRouteMapString() string { + // Route table for dumping. type tableItem struct { middleware string domain string @@ -342,12 +345,11 @@ func (s *Server) GetRouteMap() string { tablewriter.ALIGN_CENTER, tablewriter.ALIGN_CENTER, tablewriter.ALIGN_CENTER, - tablewriter.ALIGN_LEFT, + tablewriter.ALIGN_CENTER, tablewriter.ALIGN_CENTER, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, - tablewriter.ALIGN_CENTER, }) m := make(map[string]*garray.SortedArray) @@ -363,10 +365,20 @@ func (s *Server) GetRouteMap() string { priority: len(registeredItems) - index - 1, } if item.handler.itemType == gHANDLER_TYPE_MIDDLEWARE { - item.middleware = "MIDDLEWARE" + item.middleware = "GLOBAL MIDDLEWARE" } + if len(item.handler.middleware) > 0 { + for _, v := range item.handler.middleware { + if item.middleware != "" { + item.middleware += "," + } + item.middleware += gdebug.FuncName(v) + } + } + // If the domain does not exist in the dump map, it create the map. + // The value of the map is a custom sorted array. if _, ok := m[item.domain]; !ok { - // 注意排序函数的逻辑,从小到达排序 + // Sort in ASC order. m[item.domain] = garray.NewSortedArraySize(100, func(v1, v2 interface{}) int { item1 := v1.(*tableItem) item2 := v2.(*tableItem) diff --git a/net/ghttp/ghttp_server_domain.go b/net/ghttp/ghttp_server_domain.go index 757f1519f..1c831d823 100644 --- a/net/ghttp/ghttp_server_domain.go +++ b/net/ghttp/ghttp_server_domain.go @@ -35,6 +35,12 @@ func (d *Domain) BindHandler(pattern string, handler HandlerFunc) { } } +func (d *Domain) doBindHandler(pattern string, handler HandlerFunc, middleware []HandlerFunc) { + for domain, _ := range d.m { + d.s.doBindHandler(pattern+"@"+domain, handler, middleware) + } +} + // 执行对象方法 func (d *Domain) BindObject(pattern string, obj interface{}, methods ...string) { for domain, _ := range d.m { @@ -42,6 +48,12 @@ func (d *Domain) BindObject(pattern string, obj interface{}, methods ...string) } } +func (d *Domain) doBindObject(pattern string, obj interface{}, methods string, middleware []HandlerFunc) { + for domain, _ := range d.m { + d.s.doBindObject(pattern+"@"+domain, obj, methods, middleware) + } +} + // 执行对象方法注册,methods参数不区分大小写 func (d *Domain) BindObjectMethod(pattern string, obj interface{}, method string) { for domain, _ := range d.m { @@ -49,6 +61,12 @@ func (d *Domain) BindObjectMethod(pattern string, obj interface{}, method string } } +func (d *Domain) doBindObjectMethod(pattern string, obj interface{}, method string, middleware []HandlerFunc) { + for domain, _ := range d.m { + d.s.doBindObjectMethod(pattern+"@"+domain, obj, method, middleware) + } +} + // RESTful执行对象注册 func (d *Domain) BindObjectRest(pattern string, obj interface{}) { for domain, _ := range d.m { @@ -56,6 +74,12 @@ func (d *Domain) BindObjectRest(pattern string, obj interface{}) { } } +func (d *Domain) doBindObjectRest(pattern string, obj interface{}, middleware []HandlerFunc) { + for domain, _ := range d.m { + d.s.doBindObjectRest(pattern+"@"+domain, obj, middleware) + } +} + // 控制器注册 func (d *Domain) BindController(pattern string, c Controller, methods ...string) { for domain, _ := range d.m { @@ -63,6 +87,12 @@ func (d *Domain) BindController(pattern string, c Controller, methods ...string) } } +func (d *Domain) doBindController(pattern string, c Controller, methods string, middleware []HandlerFunc) { + for domain, _ := range d.m { + d.s.doBindController(pattern+"@"+domain, c, methods, middleware) + } +} + // 控制器方法注册,methods参数区分大小写 func (d *Domain) BindControllerMethod(pattern string, c Controller, method string) { for domain, _ := range d.m { @@ -70,6 +100,12 @@ func (d *Domain) BindControllerMethod(pattern string, c Controller, method strin } } +func (d *Domain) doBindControllerMethod(pattern string, c Controller, method string, middleware []HandlerFunc) { + for domain, _ := range d.m { + d.s.doBindControllerMethod(pattern+"@"+domain, c, method, middleware) + } +} + // RESTful控制器注册 func (d *Domain) BindControllerRest(pattern string, c Controller) { for domain, _ := range d.m { @@ -77,6 +113,12 @@ func (d *Domain) BindControllerRest(pattern string, c Controller) { } } +func (d *Domain) doBindControllerRest(pattern string, c Controller, middleware []HandlerFunc) { + for domain, _ := range d.m { + d.s.doBindControllerRest(pattern+"@"+domain, c, middleware) + } +} + // 绑定指定的hook回调函数, hook参数的值由ghttp server设定,参数不区分大小写 // 目前hook支持:Init/Shut func (d *Domain) BindHookHandler(pattern string, hook string, handler HandlerFunc) { diff --git a/net/ghttp/ghttp_server_router_group.go b/net/ghttp/ghttp_server_router_group.go index a7b53613c..66202fcd9 100644 --- a/net/ghttp/ghttp_server_router_group.go +++ b/net/ghttp/ghttp_server_router_group.go @@ -18,10 +18,11 @@ import ( // 分组路由对象 type RouterGroup struct { - parent *RouterGroup // 父级分组路由 - server *Server // Server - domain *Domain // Domain - prefix string // URI前缀 + parent *RouterGroup // 父级分组路由 + server *Server // Server + domain *Domain // Domain + prefix string // URI前缀 + middleware []HandlerFunc // 分组路由绑定的中间件 } // 分组路由批量绑定项 @@ -44,6 +45,7 @@ var ( // 处理预绑定路由项 func (s *Server) handlePreBindItems() { for _, item := range preBindItems { + // Handle the items of current server. if item.group.server != nil && item.group.server != s { continue } @@ -62,16 +64,16 @@ func (s *Server) Group(prefix string, groups ...func(group *RouterGroup)) *Route if prefix == "/" { prefix = "" } - rg := &RouterGroup{ + group := &RouterGroup{ server: s, prefix: prefix, } if len(groups) > 0 { for _, v := range groups { - v(rg) + v(group) } } - return rg + return group } // 获取分组路由对象(绑定域名) @@ -82,16 +84,16 @@ func (d *Domain) Group(prefix string, groups ...func(group *RouterGroup)) *Route if prefix == "/" { prefix = "" } - rg := &RouterGroup{ + group := &RouterGroup{ domain: d, prefix: prefix, } if len(groups) > 0 { for _, v := range groups { - v(rg) + v(group) } } - return rg + return group } // 层级递归创建分组路由注册项 @@ -99,27 +101,34 @@ func (g *RouterGroup) Group(prefix string, groups ...func(group *RouterGroup)) * if prefix == "/" { prefix = "" } - rg := &RouterGroup{ + group := &RouterGroup{ parent: g, server: g.server, domain: g.domain, prefix: prefix, } + if len(g.middleware) > 0 { + group.middleware = make([]HandlerFunc, len(g.middleware)) + copy(group.middleware, g.middleware) + } if len(groups) > 0 { for _, v := range groups { - v(rg) + v(group) } } - return rg + return group } func (g *RouterGroup) Clone() *RouterGroup { - return &RouterGroup{ - parent: g.parent, - server: g.server, - domain: g.domain, - prefix: g.prefix, + newGroup := &RouterGroup{ + parent: g.parent, + server: g.server, + domain: g.domain, + prefix: g.prefix, + middleware: make([]HandlerFunc, len(g.middleware)), } + copy(newGroup.middleware, g.middleware) + return newGroup } // 执行分组路由批量绑定 @@ -211,23 +220,8 @@ func (g *RouterGroup) Hook(pattern string, hook string, handler HandlerFunc) *Ro } func (g *RouterGroup) Middleware(handlers ...HandlerFunc) *RouterGroup { - group := g.Clone() - for _, handler := range handlers { - if gstr.Contains(g.prefix, "*") { - group.preBind("MIDDLEWARE", "/", handler) - } else { - group.preBind("MIDDLEWARE", "/*", handler) - } - } - return group -} - -func (g *RouterGroup) MiddlewarePattern(pattern string, handlers ...HandlerFunc) *RouterGroup { - group := g.Clone() - for _, handler := range handlers { - group.preBind("MIDDLEWARE", pattern, handler) - } - return group + g.middleware = append(g.middleware, handlers...) + return g } func (g *RouterGroup) preBind(bindType string, pattern string, object interface{}, params ...interface{}) *RouterGroup { @@ -279,64 +273,54 @@ func (g *RouterGroup) doBind(bindType string, pattern string, object interface{} bindType = "HOOK" } switch bindType { - case "MIDDLEWARE": - if h, ok := object.(HandlerFunc); ok { - if g.server != nil { - g.server.BindMiddleware(pattern, h) - } else { - g.domain.BindMiddleware(pattern, h) - } - } else { - glog.Fatalf("invalid middleware handler for pattern:%s", pattern) - } case "HANDLER": if h, ok := object.(HandlerFunc); ok { if g.server != nil { - g.server.BindHandler(pattern, h) + g.server.doBindHandler(pattern, h, g.middleware) } else { - g.domain.BindHandler(pattern, h) + g.domain.doBindHandler(pattern, h, g.middleware) } } else if g.isController(object) { if len(extras) > 0 { if g.server != nil { - g.server.BindControllerMethod(pattern, object.(Controller), extras[0]) + g.server.doBindControllerMethod(pattern, object.(Controller), extras[0], g.middleware) } else { - g.domain.BindControllerMethod(pattern, object.(Controller), extras[0]) + g.domain.doBindControllerMethod(pattern, object.(Controller), extras[0], g.middleware) } } else { if g.server != nil { - g.server.BindController(pattern, object.(Controller)) + g.server.doBindController(pattern, object.(Controller), "", g.middleware) } else { - g.domain.BindController(pattern, object.(Controller)) + g.domain.doBindController(pattern, object.(Controller), "", g.middleware) } } } else { if len(extras) > 0 { if g.server != nil { - g.server.BindObjectMethod(pattern, object, extras[0]) + g.server.doBindObjectMethod(pattern, object, extras[0], g.middleware) } else { - g.domain.BindObjectMethod(pattern, object, extras[0]) + g.domain.doBindObjectMethod(pattern, object, extras[0], g.middleware) } } else { if g.server != nil { - g.server.BindObject(pattern, object) + g.server.doBindObject(pattern, object, "", g.middleware) } else { - g.domain.BindObject(pattern, object) + g.domain.doBindObject(pattern, object, "", g.middleware) } } } case "REST": if g.isController(object) { if g.server != nil { - g.server.BindControllerRest(pattern, object.(Controller)) + g.server.doBindControllerRest(pattern, object.(Controller), g.middleware) } else { - g.domain.BindControllerRest(pattern, object.(Controller)) + g.domain.doBindControllerRest(pattern, object.(Controller), g.middleware) } } else { if g.server != nil { - g.server.BindObjectRest(pattern, object) + g.server.doBindObjectRest(pattern, object, g.middleware) } else { - g.domain.BindObjectRest(pattern, object) + g.domain.doBindObjectRest(pattern, object, g.middleware) } } case "HOOK": @@ -365,9 +349,12 @@ func (g *RouterGroup) isController(value interface{}) bool { if v.Kind() == reflect.Ptr { v = v.Elem() } - if v.FieldByName("Request").IsValid() && v.FieldByName("Response").IsValid() && - v.FieldByName("Server").IsValid() && v.FieldByName("Cookie").IsValid() && - v.FieldByName("Session").IsValid() && v.FieldByName("View").IsValid() { + if v.FieldByName("Request").IsValid() && + v.FieldByName("Response").IsValid() && + v.FieldByName("Server").IsValid() && + v.FieldByName("Cookie").IsValid() && + v.FieldByName("Session").IsValid() && + v.FieldByName("View").IsValid() { return true } return false diff --git a/net/ghttp/ghttp_server_router_hook.go b/net/ghttp/ghttp_server_router_hook.go index a22cc72c5..fcce0aa66 100644 --- a/net/ghttp/ghttp_server_router_hook.go +++ b/net/ghttp/ghttp_server_router_hook.go @@ -7,16 +7,15 @@ package ghttp import ( + "github.com/gogf/gf/debug/gdebug" "net/http" - "reflect" - "runtime" ) // 绑定指定的hook回调函数, pattern参数同BindHandler,支持命名路由;hook参数的值由ghttp server设定,参数不区分大小写 func (s *Server) BindHookHandler(pattern string, hook string, handler HandlerFunc) { s.setHandler(pattern, &handlerItem{ itemType: gHANDLER_TYPE_HOOK, - itemName: runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name(), + itemName: gdebug.FuncPath(handler), itemFunc: handler, hookName: hook, }) diff --git a/net/ghttp/ghttp_server_router_middleware.go b/net/ghttp/ghttp_server_router_middleware.go index 2f169ef57..31dc39e61 100644 --- a/net/ghttp/ghttp_server_router_middleware.go +++ b/net/ghttp/ghttp_server_router_middleware.go @@ -7,8 +7,7 @@ package ghttp import ( - "reflect" - "runtime" + "github.com/gogf/gf/debug/gdebug" ) const ( @@ -20,7 +19,7 @@ func (s *Server) BindMiddleware(pattern string, handlers ...HandlerFunc) { for _, handler := range handlers { s.setHandler(pattern, &handlerItem{ itemType: gHANDLER_TYPE_MIDDLEWARE, - itemName: runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name(), + itemName: gdebug.FuncPath(handler), itemFunc: handler, }) } @@ -31,7 +30,7 @@ func (s *Server) BindMiddlewareDefault(handlers ...HandlerFunc) { for _, handler := range handlers { s.setHandler(gDEFAULT_MIDDLEWARE_PATTERN, &handlerItem{ itemType: gHANDLER_TYPE_MIDDLEWARE, - itemName: runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name(), + itemName: gdebug.FuncPath(handler), itemFunc: handler, }) } diff --git a/net/ghttp/ghttp_server_router_serve.go b/net/ghttp/ghttp_server_router_serve.go index d6c4afec8..0f0f592f3 100644 --- a/net/ghttp/ghttp_server_router_serve.go +++ b/net/ghttp/ghttp_server_router_serve.go @@ -117,7 +117,7 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han // 注意当不带任何动态路由规则时,len(match) == 1 if match, err := gregex.MatchString(item.router.RegRule, path); err == nil && len(match) > 0 { parsedItem := &handlerParsedItem{item, nil} - // 如果需要query匹配,那么需要重新正则解析URL + // 如果需要路由规则中带有URI名称匹配,那么需要重新正则解析URL if len(item.router.RegNames) > 0 { if len(match) > len(item.router.RegNames) { parsedItem.values = make(map[string]string) diff --git a/net/ghttp/ghttp_server_service_controller.go b/net/ghttp/ghttp_server_service_controller.go index aca597b75..cf640d89c 100644 --- a/net/ghttp/ghttp_server_service_controller.go +++ b/net/ghttp/ghttp_server_service_controller.go @@ -20,7 +20,36 @@ import ( // 绑定控制器,控制器需要实现 gmvc.Controller 接口, // 这种方式绑定的控制器每一次请求都会初始化一个新的控制器对象进行处理,对应不同的请求会话, // 第三个参数methods用以指定需要注册的方法,支持多个方法名称,多个方法以英文“,”号分隔,区分大小写. -func (s *Server) BindController(pattern string, c Controller, methods ...string) { +func (s *Server) BindController(pattern string, controller Controller, method ...string) { + bindMethod := "" + if len(method) > 0 { + bindMethod = method[0] + } + s.doBindController(pattern, controller, bindMethod, nil) +} + +// 绑定路由到指定的方法执行, 第三个参数method仅支持一个方法注册,不支持多个,并且区分大小写。 +func (s *Server) BindControllerMethod(pattern string, controller Controller, method string) { + s.doBindControllerMethod(pattern, controller, method, nil) +} + +// 绑定控制器(RESTFul),控制器需要实现gmvc.Controller接口 +// 方法会识别HTTP方法,并做REST绑定处理,例如:Post方法会绑定到HTTP POST的方法请求处理,Delete方法会绑定到HTTP DELETE的方法请求处理 +// 因此只会绑定HTTP Method对应的方法,其他方法不会自动注册绑定 +// 这种方式绑定的控制器每一次请求都会初始化一个新的控制器对象进行处理,对应不同的请求会话 +func (s *Server) BindControllerRest(pattern string, controller Controller) { + s.doBindControllerRest(pattern, controller, nil) +} + +func (s *Server) doBindController(pattern string, controller Controller, method string, middleware []HandlerFunc) { + // Convert input method to map for convenience and high performance searching. + var methodMap map[string]bool + if len(method) > 0 { + methodMap = make(map[string]bool) + for _, v := range strings.Split(method, ",") { + methodMap[strings.TrimSpace(v)] = true + } + } // 当pattern中的method为all时,去掉该method,以便于后续方法判断 domain, method, path, err := s.parsePattern(pattern) if err != nil { @@ -30,17 +59,9 @@ func (s *Server) BindController(pattern string, c Controller, methods ...string) if strings.EqualFold(method, gDEFAULT_METHOD) { pattern = s.serveHandlerKey("", path, domain) } - - methodMap := (map[string]bool)(nil) - if len(methods) > 0 { - methodMap = make(map[string]bool) - for _, v := range strings.Split(methods[0], ",") { - methodMap[strings.TrimSpace(v)] = true - } - } // 遍历控制器,获取方法列表,并构造成uri m := make(handlerMap) - v := reflect.ValueOf(c) + v := reflect.ValueOf(controller) t := v.Type() sname := t.Elem().Name() pkgPath := t.Elem().PkgPath() @@ -77,6 +98,7 @@ func (s *Server) BindController(pattern string, c Controller, methods ...string) name: mname, reflect: v.Elem().Type(), }, + middleware: middleware, } // 如果方法中带有Index方法,那么额外自动增加一个路由规则匹配主URI, // 例如: pattern为/user, 那么会同时注册/user及/user/index, @@ -95,16 +117,16 @@ func (s *Server) BindController(pattern string, c Controller, methods ...string) name: mname, reflect: v.Elem().Type(), }, + middleware: middleware, } } } s.bindHandlerByMap(m) } -// 绑定路由到指定的方法执行, 第三个参数method仅支持一个方法注册,不支持多个,并且区分大小写。 -func (s *Server) BindControllerMethod(pattern string, c Controller, method string) { +func (s *Server) doBindControllerMethod(pattern string, controller Controller, method string, middleware []HandlerFunc) { m := make(handlerMap) - v := reflect.ValueOf(c) + v := reflect.ValueOf(controller) t := v.Type() sname := t.Elem().Name() mname := strings.TrimSpace(method) @@ -132,18 +154,15 @@ func (s *Server) BindControllerMethod(pattern string, c Controller, method strin name: mname, reflect: v.Elem().Type(), }, + middleware: middleware, } s.bindHandlerByMap(m) } -// 绑定控制器(RESTFul),控制器需要实现gmvc.Controller接口 -// 方法会识别HTTP方法,并做REST绑定处理,例如:Post方法会绑定到HTTP POST的方法请求处理,Delete方法会绑定到HTTP DELETE的方法请求处理 -// 因此只会绑定HTTP Method对应的方法,其他方法不会自动注册绑定 -// 这种方式绑定的控制器每一次请求都会初始化一个新的控制器对象进行处理,对应不同的请求会话 -func (s *Server) BindControllerRest(pattern string, c Controller) { +func (s *Server) doBindControllerRest(pattern string, controller Controller, middleware []HandlerFunc) { // 遍历控制器,获取方法列表,并构造成uri m := make(handlerMap) - v := reflect.ValueOf(c) + v := reflect.ValueOf(controller) t := v.Type() sname := t.Elem().Name() pkgPath := t.Elem().PkgPath() @@ -172,6 +191,7 @@ func (s *Server) BindControllerRest(pattern string, c Controller) { name: mname, reflect: v.Elem().Type(), }, + middleware: middleware, } } s.bindHandlerByMap(m) diff --git a/net/ghttp/ghttp_server_service_handler.go b/net/ghttp/ghttp_server_service_handler.go index ce201726e..92172ff47 100644 --- a/net/ghttp/ghttp_server_service_handler.go +++ b/net/ghttp/ghttp_server_service_handler.go @@ -8,38 +8,33 @@ package ghttp import ( "bytes" - "reflect" - "runtime" + "github.com/gogf/gf/debug/gdebug" "strings" - "github.com/gogf/gf/os/glog" "github.com/gogf/gf/text/gstr" ) // 注意该方法是直接绑定函数的内存地址,执行的时候直接执行该方法,不会存在初始化新的控制器逻辑 func (s *Server) BindHandler(pattern string, handler HandlerFunc) { - s.bindHandlerItem(pattern, &handlerItem{ - itemName: runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name(), - itemType: gHANDLER_TYPE_HANDLER, - itemFunc: handler, - }) + s.doBindHandler(pattern, handler, nil) } // 绑定URI到操作函数/方法 // pattern的格式形如:/user/list, put:/user, delete:/user, post:/user@johng.cn // 支持RESTful的请求格式,具体业务逻辑由绑定的处理方法来执行 -func (s *Server) bindHandlerItem(pattern string, item *handlerItem) { - if s.Status() == SERVER_STATUS_RUNNING { - glog.Error("server handlers cannot be changed while running") - return - } - s.setHandler(pattern, item) +func (s *Server) doBindHandler(pattern string, handler HandlerFunc, middleware []HandlerFunc) { + s.setHandler(pattern, &handlerItem{ + itemName: gdebug.FuncPath(handler), + itemType: gHANDLER_TYPE_HANDLER, + itemFunc: handler, + middleware: middleware, + }) } // 通过映射数组绑定URI到操作函数/方法 func (s *Server) bindHandlerByMap(m handlerMap) { for p, h := range m { - s.bindHandlerItem(p, h) + s.setHandler(p, h) } } @@ -48,8 +43,8 @@ func (s *Server) bindHandlerByMap(m handlerMap) { // 规则2:pattern中的URI包含{.method}关键字,则替换该关键字为方法名称; // 规则2:如果不满足规则1,那么直接将防发明附加到pattern中的URI后面; func (s *Server) mergeBuildInNameToPattern(pattern string, structName, methodName string, allowAppend bool) string { - structName = s.nameToUrlPart(structName) - methodName = s.nameToUrlPart(methodName) + structName = s.nameToUri(structName) + methodName = s.nameToUri(methodName) pattern = strings.Replace(pattern, "{.struct}", structName, -1) if strings.Index(pattern, "{.method}") != -1 { return strings.Replace(pattern, "{.method}", methodName, -1) @@ -75,7 +70,7 @@ func (s *Server) mergeBuildInNameToPattern(pattern string, structName, methodNam // 规则1: 不处理名称,以原有名称构建成URI // 规则2: 仅转为小写,单词间不使用连接符号 // 规则3: 采用驼峰命名方式 -func (s *Server) nameToUrlPart(name string) string { +func (s *Server) nameToUri(name string) string { switch s.config.NameToUriType { case URI_TYPE_FULLNAME: return name diff --git a/net/ghttp/ghttp_server_service_object.go b/net/ghttp/ghttp_server_service_object.go index 3ed5568cc..e3147ad86 100644 --- a/net/ghttp/ghttp_server_service_object.go +++ b/net/ghttp/ghttp_server_service_object.go @@ -19,7 +19,35 @@ import ( // 绑定对象到URI请求处理中,会自动识别方法名称,并附加到对应的URI地址后面 // 第三个参数methods用以指定需要注册的方法,支持多个方法名称,多个方法以英文“,”号分隔,区分大小写 -func (s *Server) BindObject(pattern string, obj interface{}, methods ...string) { +func (s *Server) BindObject(pattern string, object interface{}, method ...string) { + bindMethod := "" + if len(method) > 0 { + bindMethod = method[0] + } + s.doBindObject(pattern, object, bindMethod, nil) +} + +// 绑定对象到URI请求处理中,会自动识别方法名称,并附加到对应的URI地址后面, +// 第三个参数method仅支持一个方法注册,不支持多个,并且区分大小写。 +func (s *Server) BindObjectMethod(pattern string, object interface{}, method string) { + s.doBindObjectMethod(pattern, object, method, nil) +} + +// 绑定对象到URI请求处理中,会自动识别方法名称,并附加到对应的URI地址后面, +// 需要注意对象方法的定义必须按照 ghttp.HandlerFunc 来定义 +func (s *Server) BindObjectRest(pattern string, object interface{}) { + s.doBindObjectRest(pattern, object, nil) +} + +func (s *Server) doBindObject(pattern string, object interface{}, method string, middleware []HandlerFunc) { + // Convert input method to map for convenience and high performance searching. + var methodMap map[string]bool + if len(method) > 0 { + methodMap = make(map[string]bool) + for _, v := range strings.Split(method, ",") { + methodMap[strings.TrimSpace(v)] = true + } + } // 当pattern中的method为all时,去掉该method,以便于后续方法判断 domain, method, path, err := s.parsePattern(pattern) if err != nil { @@ -30,15 +58,8 @@ func (s *Server) BindObject(pattern string, obj interface{}, methods ...string) pattern = s.serveHandlerKey("", path, domain) } - methodMap := (map[string]bool)(nil) - if len(methods) > 0 { - methodMap = make(map[string]bool) - for _, v := range strings.Split(methods[0], ",") { - methodMap[strings.TrimSpace(v)] = true - } - } m := make(handlerMap) - v := reflect.ValueOf(obj) + v := reflect.ValueOf(object) t := v.Type() sname := t.Elem().Name() initFunc := (func(*Request))(nil) @@ -78,11 +99,12 @@ func (s *Server) BindObject(pattern string, obj interface{}, methods ...string) } key := s.mergeBuildInNameToPattern(pattern, sname, mname, true) m[key] = &handlerItem{ - itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname), - itemType: gHANDLER_TYPE_OBJECT, - itemFunc: itemFunc, - initFunc: initFunc, - shutFunc: shutFunc, + itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname), + itemType: gHANDLER_TYPE_OBJECT, + itemFunc: itemFunc, + initFunc: initFunc, + shutFunc: shutFunc, + middleware: middleware, } // 如果方法中带有Index方法,那么额外自动增加一个路由规则匹配主URI。 // 注意,当pattern带有内置变量时,不会自动加该路由。 @@ -93,11 +115,12 @@ func (s *Server) BindObject(pattern string, obj interface{}, methods ...string) k = "/" + k } m[k] = &handlerItem{ - itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname), - itemType: gHANDLER_TYPE_OBJECT, - itemFunc: itemFunc, - initFunc: initFunc, - shutFunc: shutFunc, + itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname), + itemType: gHANDLER_TYPE_OBJECT, + itemFunc: itemFunc, + initFunc: initFunc, + shutFunc: shutFunc, + middleware: middleware, } } } @@ -106,9 +129,9 @@ func (s *Server) BindObject(pattern string, obj interface{}, methods ...string) // 绑定对象到URI请求处理中,会自动识别方法名称,并附加到对应的URI地址后面, // 第三个参数method仅支持一个方法注册,不支持多个,并且区分大小写。 -func (s *Server) BindObjectMethod(pattern string, obj interface{}, method string) { +func (s *Server) doBindObjectMethod(pattern string, object interface{}, method string, middleware []HandlerFunc) { m := make(handlerMap) - v := reflect.ValueOf(obj) + v := reflect.ValueOf(object) t := v.Type() sname := t.Elem().Name() mname := strings.TrimSpace(method) @@ -139,21 +162,20 @@ func (s *Server) BindObjectMethod(pattern string, obj interface{}, method string } key := s.mergeBuildInNameToPattern(pattern, sname, mname, false) m[key] = &handlerItem{ - itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname), - itemType: gHANDLER_TYPE_OBJECT, - itemFunc: itemFunc, - initFunc: initFunc, - shutFunc: shutFunc, + itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname), + itemType: gHANDLER_TYPE_OBJECT, + itemFunc: itemFunc, + initFunc: initFunc, + shutFunc: shutFunc, + middleware: middleware, } s.bindHandlerByMap(m) } -// 绑定对象到URI请求处理中,会自动识别方法名称,并附加到对应的URI地址后面, -// 需要注意对象方法的定义必须按照 ghttp.HandlerFunc 来定义 -func (s *Server) BindObjectRest(pattern string, obj interface{}) { +func (s *Server) doBindObjectRest(pattern string, object interface{}, middleware []HandlerFunc) { m := make(handlerMap) - v := reflect.ValueOf(obj) + v := reflect.ValueOf(object) t := v.Type() sname := t.Elem().Name() initFunc := (func(*Request))(nil) @@ -184,11 +206,12 @@ func (s *Server) BindObjectRest(pattern string, obj interface{}) { } key := s.mergeBuildInNameToPattern(mname+":"+pattern, sname, mname, false) m[key] = &handlerItem{ - itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname), - itemType: gHANDLER_TYPE_OBJECT, - itemFunc: itemFunc, - initFunc: initFunc, - shutFunc: shutFunc, + itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname), + itemType: gHANDLER_TYPE_OBJECT, + itemFunc: itemFunc, + initFunc: initFunc, + shutFunc: shutFunc, + middleware: middleware, } } s.bindHandlerByMap(m) diff --git a/net/ghttp/ghttp_unit_middleware_test.go b/net/ghttp/ghttp_unit_middleware_test.go index 110f3128d..08685d913 100644 --- a/net/ghttp/ghttp_unit_middleware_test.go +++ b/net/ghttp/ghttp_unit_middleware_test.go @@ -131,7 +131,7 @@ func Test_BindMiddleware_Basic3(t *testing.T) { }) } -func Test_BindMiddleware_Must_Be_Called(t *testing.T) { +func Test_BindMiddleware_Basic4(t *testing.T) { p := ports.PopRand() s := g.Server(p) s.Group("/", func(group *ghttp.RouterGroup) { @@ -157,9 +157,9 @@ func Test_BindMiddleware_Must_Be_Called(t *testing.T) { client := ghttp.NewClient() client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) - gtest.Assert(client.GetContent("/"), "12") + gtest.Assert(client.GetContent("/"), "Not Found") gtest.Assert(client.GetContent("/test"), "1test2") - gtest.Assert(client.PutContent("/test/none"), "12") + gtest.Assert(client.PutContent("/test/none"), "Not Found") }) } @@ -188,7 +188,7 @@ func Test_Middleware_With_Static(t *testing.T) { gtest.Assert(client.GetContent("/"), "index") gtest.Assert(client.GetContent("/test.html"), "test") - gtest.Assert(client.GetContent("/none"), "12") + gtest.Assert(client.GetContent("/none"), "Not Found") gtest.Assert(client.GetContent("/user/list"), "1list2") }) } @@ -214,7 +214,7 @@ func Test_Middleware_Status(t *testing.T) { client := ghttp.NewClient() client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) - gtest.Assert(client.GetContent("/"), "404") + gtest.Assert(client.GetContent("/"), "Not Found") gtest.Assert(client.GetContent("/user/list"), "200") resp, err := client.Get("/") @@ -268,7 +268,7 @@ func Test_Middleware_Hook_With_Static(t *testing.T) { time.Sleep(100 * time.Millisecond) gtest.Assert(a.Len(), 4) - gtest.Assert(client.GetContent("/none"), "a12b") + gtest.Assert(client.GetContent("/none"), "ab") time.Sleep(100 * time.Millisecond) gtest.Assert(a.Len(), 6) @@ -610,7 +610,7 @@ func Test_Middleware_CORSAndAuth(t *testing.T) { client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) gtest.Assert(client.GetContent("/"), "Not Found") - gtest.Assert(client.GetContent("/api.v2"), "Forbidden") + gtest.Assert(client.GetContent("/api.v2"), "Not Found") gtest.Assert(client.GetContent("/api.v2/user/list"), "Forbidden") gtest.Assert(client.GetContent("/api.v2/user/list", "token=123456"), "list") }) diff --git a/net/ghttp/ghttp_unit_router_group_group_test.go b/net/ghttp/ghttp_unit_router_group_group_test.go index c6cbd5d54..74a473560 100644 --- a/net/ghttp/ghttp_unit_router_group_group_test.go +++ b/net/ghttp/ghttp_unit_router_group_group_test.go @@ -67,17 +67,17 @@ func Test_Router_Group_Group(t *testing.T) { client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) gtest.Assert(client.GetContent("/"), "Not Found") - gtest.Assert(client.GetContent("/api.v2"), "12") + gtest.Assert(client.GetContent("/api.v2"), "Not Found") gtest.Assert(client.GetContent("/api.v2/test"), "1test2") - gtest.Assert(client.GetContent("/api.v2/hook"), "hook any12") - gtest.Assert(client.GetContent("/api.v2/hook/name"), "hook namehook any12") - gtest.Assert(client.GetContent("/api.v2/hook/name/any"), "hook any12") + gtest.Assert(client.GetContent("/api.v2/hook"), "hook any") + gtest.Assert(client.GetContent("/api.v2/hook/name"), "hook namehook any") + gtest.Assert(client.GetContent("/api.v2/hook/name/any"), "hook any") gtest.Assert(client.GetContent("/api.v2/order/list"), "1list2") - gtest.Assert(client.GetContent("/api.v2/order/update"), "12") + gtest.Assert(client.GetContent("/api.v2/order/update"), "Not Found") gtest.Assert(client.PutContent("/api.v2/order/update"), "1update2") - gtest.Assert(client.GetContent("/api.v2/user/drop"), "12") + gtest.Assert(client.GetContent("/api.v2/user/drop"), "Not Found") gtest.Assert(client.DeleteContent("/api.v2/user/drop"), "1drop2") - gtest.Assert(client.GetContent("/api.v2/user/edit"), "12") + gtest.Assert(client.GetContent("/api.v2/user/edit"), "Not Found") gtest.Assert(client.PostContent("/api.v2/user/edit"), "1edit2") gtest.Assert(client.GetContent("/api.v2/user/info"), "1info2") })