From 0140808460c573c9c255dc2b7f70068899a62015 Mon Sep 17 00:00:00 2001 From: John Guo Date: Tue, 13 Jul 2021 23:01:31 +0800 Subject: [PATCH] add handler extension feature for package ghttp --- frame/g/g_func.go | 6 ++ net/ghttp/ghttp.go | 27 ++++-- .../ghttp_middleware_handler_response.go | 29 +++++++ net/ghttp/ghttp_request.go | 20 +++++ net/ghttp/ghttp_request_middleware.go | 54 ++++++++++-- net/ghttp/ghttp_response.go | 4 +- net/ghttp/ghttp_server_config.go | 12 ++- net/ghttp/ghttp_server_domain.go | 31 ++----- net/ghttp/ghttp_server_router.go | 8 +- net/ghttp/ghttp_server_router_group.go | 17 +++- net/ghttp/ghttp_server_router_hook.go | 8 +- net/ghttp/ghttp_server_router_middleware.go | 15 +++- net/ghttp/ghttp_server_service_controller.go | 42 ++++++---- net/ghttp/ghttp_server_service_handler.go | 76 ++++++++++++++--- net/ghttp/ghttp_server_service_object.go | 70 ++++++---------- ...ghttp_unit_router_handler_extended_test.go | 82 +++++++++++++++++++ net/ghttp/ghttp_unit_router_names_test.go | 2 +- 17 files changed, 373 insertions(+), 130 deletions(-) create mode 100644 net/ghttp/ghttp_middleware_handler_response.go create mode 100644 net/ghttp/ghttp_unit_router_handler_extended_test.go diff --git a/frame/g/g_func.go b/frame/g/g_func.go index 64b284482..5034a8f08 100644 --- a/frame/g/g_func.go +++ b/frame/g/g_func.go @@ -7,6 +7,7 @@ package g import ( + "context" "github.com/gogf/gf/container/gvar" "github.com/gogf/gf/internal/empty" "github.com/gogf/gf/net/ghttp" @@ -75,3 +76,8 @@ func IsNil(value interface{}, traceSource ...bool) bool { func IsEmpty(value interface{}) bool { return empty.IsEmpty(value) } + +// RequestFromCtx retrieves and returns the Request object from context. +func RequestFromCtx(ctx context.Context) *ghttp.Request { + return ghttp.RequestFromCtx(ctx) +} diff --git a/net/ghttp/ghttp.go b/net/ghttp/ghttp.go index 54e2a4bbd..0f9ed208b 100644 --- a/net/ghttp/ghttp.go +++ b/net/ghttp/ghttp.go @@ -58,20 +58,30 @@ type ( handler *handlerItem // The handler. } + // HandlerFunc is request handler function. + HandlerFunc = func(r *Request) + + // handlerFuncInfo contains the HandlerFunc address and its reflect type. + handlerFuncInfo struct { + Func HandlerFunc // Handler function address. + Type reflect.Type // Reflect type information for current handler, which is used for extension of handler feature. + Value reflect.Value // Reflect value information for current handler, which is used for extension of handler feature. + } + // handlerItem is the registered handler for route handling, // including middleware and hook functions. handlerItem struct { itemId int // Unique handler item id mark. itemName string // Handler name, which is automatically retrieved from runtime stack when registered. itemType int // Handler type: object/handler/controller/middleware/hook. - itemFunc HandlerFunc // Handler address. - initFunc HandlerFunc // Initialization function when request enters the object(only available for object register type). - shutFunc HandlerFunc // Shutdown function when request leaves out the object(only available for object register type). + itemInfo handlerFuncInfo // Handler function information. + initFunc HandlerFunc // Initialization function when request enters the object (only available for object register type). + shutFunc HandlerFunc // Shutdown function when request leaves out the object (only available for object register type). middleware []HandlerFunc // Bound middleware array. ctrlInfo *handlerController // Controller information for reflect usage. - hookName string // Hook type name. + hookName string // Hook type name, only available for hook type. router *Router // Router object. - source string // Source file path:line when registering. + source string // Registering source file `path:line`. } // handlerParsedItem is the item parsed from URL.Path. @@ -98,9 +108,6 @@ type ( Stack() string } - // HandlerFunc is request handler function. - HandlerFunc = func(r *Request) - // Listening file descriptor mapping. // The key is either "http" or "https" and the value is its FD. listenerFdMap = map[string]string @@ -126,6 +133,10 @@ const ( exceptionExitAll = "exit_all" exceptionExitHook = "exit_hook" routeCacheDuration = time.Hour + methodNameInit = "Init" + methodNameShut = "Shut" + methodNameExit = "Exit" + ctxKeyForRequest = "gHttpRequestObject" ) var ( diff --git a/net/ghttp/ghttp_middleware_handler_response.go b/net/ghttp/ghttp_middleware_handler_response.go new file mode 100644 index 000000000..d84815275 --- /dev/null +++ b/net/ghttp/ghttp_middleware_handler_response.go @@ -0,0 +1,29 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package ghttp + +import ( + "github.com/gogf/gf/errors/gerror" + "github.com/gogf/gf/internal/intlog" +) + +// MiddlewareHandlerResponse is the default middleware handling handler response object and its error. +func MiddlewareHandlerResponse(r *Request) { + r.Middleware.Next() + res, err := r.GetHandlerResponse() + if err != nil { + r.Response.Writef( + `{"code":%d,"message":"%s"}`, + gerror.Code(err), + err.Error(), + ) + return + } + if exception := r.Response.WriteJson(res); exception != nil { + intlog.Error(r.Context(), exception) + } +} diff --git a/net/ghttp/ghttp_request.go b/net/ghttp/ghttp_request.go index fa86d02c5..e01dc1ac8 100644 --- a/net/ghttp/ghttp_request.go +++ b/net/ghttp/ghttp_request.go @@ -37,6 +37,7 @@ type Request struct { StaticFile *staticFile // Static file object for static file serving. context context.Context // Custom context for internal usage purpose. handlers []*handlerParsedItem // All matched handlers containing handler, hook and middleware for this request. + handlerResponse handlerResponse // Handler response object and its error value. hasHookHandler bool // A bool marking whether there's hook handler in the handlers for performance purpose. hasServeHandler bool // A bool marking whether there's serving handler in the handlers for performance purpose. parsedQuery bool // A bool marking whether the GET parameters parsed. @@ -57,6 +58,11 @@ type Request struct { viewParams gview.Params // Custom template view variables for this response. } +type handlerResponse struct { + Object interface{} + Error error +} + // staticFile is the file struct for static file service. type staticFile struct { File *gres.File // Resource file object. @@ -96,6 +102,15 @@ func newRequest(s *Server, r *http.Request, w http.ResponseWriter) *Request { return request } +// RequestFromCtx retrieves and returns the Request object from context. +func RequestFromCtx(ctx context.Context) *Request { + result := ctx.Value(ctxKeyForRequest) + if result != nil { + return result.(*Request) + } + return nil +} + // WebSocket upgrades current request as a websocket request. // It returns a new WebSocket object if success, or the error if failure. // Note that the request should be a websocket request, or it will surely fail upgrading. @@ -236,3 +251,8 @@ func (r *Request) ReloadParam() { r.parsedQuery = false r.bodyContent = nil } + +// GetHandlerResponse retrieves and returns the handler response object and its error. +func (r *Request) GetHandlerResponse() (res interface{}, err error) { + return r.handlerResponse.Object, r.handlerResponse.Error +} diff --git a/net/ghttp/ghttp_request_middleware.go b/net/ghttp/ghttp_request_middleware.go index caa0e0bb0..75321da0a 100644 --- a/net/ghttp/ghttp_request_middleware.go +++ b/net/ghttp/ghttp_request_middleware.go @@ -7,6 +7,7 @@ package ghttp import ( + "context" "github.com/gogf/gf/errors/gerror" "net/http" "reflect" @@ -91,9 +92,7 @@ func (m *middleware) Next() { }) } if !m.request.IsExited() { - niceCallFunc(func() { - item.handler.itemFunc(m.request) - }) + m.callHandlerFunc(item.handler.itemInfo) } if !m.request.IsExited() && item.handler.shutFunc != nil { niceCallFunc(func() { @@ -108,13 +107,13 @@ func (m *middleware) Next() { break } niceCallFunc(func() { - item.handler.itemFunc(m.request) + m.callHandlerFunc(item.handler.itemInfo) }) // Global middleware array. case handlerTypeMiddleware: niceCallFunc(func() { - item.handler.itemFunc(m.request) + item.handler.itemInfo.Func(m.request) }) // It does not continue calling next middleware after another middleware done. // There should be a "Next" function to be called in the middleware in order to manage the workflow. @@ -145,3 +144,48 @@ func (m *middleware) Next() { } } } + +func (m *middleware) callHandlerFunc(funcInfo handlerFuncInfo) { + niceCallFunc(func() { + if funcInfo.Func != nil { + funcInfo.Func(m.request) + } else { + var inputValues = []reflect.Value{ + reflect.ValueOf(context.WithValue( + m.request.Context(), ctxKeyForRequest, m.request, + )), + } + if funcInfo.Type.NumIn() == 2 { + var ( + request reflect.Value + ) + if funcInfo.Type.In(1).Kind() == reflect.Ptr { + request = reflect.New(funcInfo.Type.In(1).Elem()) + m.request.handlerResponse.Error = m.request.Parse(request.Interface()) + } else { + request = reflect.New(funcInfo.Type.In(1).Elem()).Elem() + m.request.handlerResponse.Error = m.request.Parse(request.Addr().Interface()) + } + if m.request.handlerResponse.Error != nil { + return + } + inputValues = append(inputValues, request) + } + + // Call handler with dynamic created parameter values. + results := funcInfo.Value.Call(inputValues) + switch len(results) { + case 1: + m.request.handlerResponse.Error = results[0].Interface().(error) + + case 2: + m.request.handlerResponse.Object = results[0].Interface() + if !results[1].IsNil() { + if v := results[1].Interface(); v != nil { + m.request.handlerResponse.Error = v.(error) + } + } + } + } + }) +} diff --git a/net/ghttp/ghttp_response.go b/net/ghttp/ghttp_response.go index a33613957..04524046a 100644 --- a/net/ghttp/ghttp_response.go +++ b/net/ghttp/ghttp_response.go @@ -110,7 +110,7 @@ func (r *Response) RedirectBack(code ...int) { r.RedirectTo(r.Request.GetReferer(), code...) } -// BufferString returns the buffered content as []byte. +// Buffer returns the buffered content as []byte. func (r *Response) Buffer() []byte { return r.buffer.Bytes() } @@ -136,7 +136,7 @@ func (r *Response) ClearBuffer() { r.buffer.Reset() } -// Output outputs the buffer content to the client and clears the buffer. +// Flush outputs the buffer content to the client and clears the buffer. func (r *Response) Flush() { if r.Server.config.ServerAgent != "" { r.Header().Set("Server", r.Server.config.ServerAgent) diff --git a/net/ghttp/ghttp_server_config.go b/net/ghttp/ghttp_server_config.go index 5021c26d8..9f2e8be33 100644 --- a/net/ghttp/ghttp_server_config.go +++ b/net/ghttp/ghttp_server_config.go @@ -30,10 +30,14 @@ import ( const ( defaultHttpAddr = ":80" // Default listening port for HTTP. defaultHttpsAddr = ":443" // Default listening port for HTTPS. - URI_TYPE_DEFAULT = 0 // Method name to URI converting type, which converts name to its lower case and joins the words using char '-'. - URI_TYPE_FULLNAME = 1 // Method name to URI converting type, which does no converting to the method name. - URI_TYPE_ALLLOWER = 2 // Method name to URI converting type, which converts name to its lower case. - URI_TYPE_CAMEL = 3 // Method name to URI converting type, which converts name to its camel case. + URI_TYPE_DEFAULT = 0 // Deprecated, please use UriTypeDefault instead. + URI_TYPE_FULLNAME = 1 // Deprecated, please use UriTypeFullName instead. + URI_TYPE_ALLLOWER = 2 // Deprecated, please use UriTypeAllLower instead. + URI_TYPE_CAMEL = 3 // Deprecated, please use UriTypeCamel instead. + UriTypeDefault = 0 // Method name to URI converting type, which converts name to its lower case and joins the words using char '-'. + UriTypeFullName = 1 // Method name to URI converting type, which does no converting to the method name. + UriTypeAllLower = 2 // Method name to URI converting type, which converts name to its lower case. + UriTypeCamel = 3 // Method name to URI converting type, which converts name to its camel case. ) // ServerConfig is the HTTP Server configuration manager. diff --git a/net/ghttp/ghttp_server_domain.go b/net/ghttp/ghttp_server_domain.go index 3fcb066e0..17a1e6e68 100644 --- a/net/ghttp/ghttp_server_domain.go +++ b/net/ghttp/ghttp_server_domain.go @@ -28,18 +28,15 @@ func (s *Server) Domain(domains string) *Domain { return d } -func (d *Domain) BindHandler(pattern string, handler HandlerFunc) { +func (d *Domain) BindHandler(pattern string, handler interface{}) { for domain, _ := range d.domains { d.server.BindHandler(pattern+"@"+domain, handler) } } -func (d *Domain) doBindHandler( - pattern string, handler HandlerFunc, - middleware []HandlerFunc, source string, -) { +func (d *Domain) doBindHandler(pattern string, funcInfo handlerFuncInfo, middleware []HandlerFunc, source string) { for domain, _ := range d.domains { - d.server.doBindHandler(pattern+"@"+domain, handler, middleware, source) + d.server.doBindHandler(pattern+"@"+domain, funcInfo, middleware, source) } } @@ -49,10 +46,7 @@ func (d *Domain) BindObject(pattern string, obj interface{}, methods ...string) } } -func (d *Domain) doBindObject( - pattern string, obj interface{}, methods string, - middleware []HandlerFunc, source string, -) { +func (d *Domain) doBindObject(pattern string, obj interface{}, methods string, middleware []HandlerFunc, source string) { for domain, _ := range d.domains { d.server.doBindObject(pattern+"@"+domain, obj, methods, middleware, source) } @@ -79,10 +73,7 @@ func (d *Domain) BindObjectRest(pattern string, obj interface{}) { } } -func (d *Domain) doBindObjectRest( - pattern string, obj interface{}, - middleware []HandlerFunc, source string, -) { +func (d *Domain) doBindObjectRest(pattern string, obj interface{}, middleware []HandlerFunc, source string) { for domain, _ := range d.domains { d.server.doBindObjectRest(pattern+"@"+domain, obj, middleware, source) } @@ -94,10 +85,7 @@ func (d *Domain) BindController(pattern string, c Controller, methods ...string) } } -func (d *Domain) doBindController( - pattern string, c Controller, methods string, - middleware []HandlerFunc, source string, -) { +func (d *Domain) doBindController(pattern string, c Controller, methods string, middleware []HandlerFunc, source string) { for domain, _ := range d.domains { d.server.doBindController(pattern+"@"+domain, c, methods, middleware, source) } @@ -109,10 +97,7 @@ func (d *Domain) BindControllerMethod(pattern string, c Controller, method strin } } -func (d *Domain) doBindControllerMethod( - pattern string, c Controller, method string, - middleware []HandlerFunc, source string, -) { +func (d *Domain) doBindControllerMethod(pattern string, c Controller, method string, middleware []HandlerFunc, source string) { for domain, _ := range d.domains { d.server.doBindControllerMethod(pattern+"@"+domain, c, method, middleware, source) } @@ -171,7 +156,7 @@ func (d *Domain) BindMiddleware(pattern string, handlers ...HandlerFunc) { func (d *Domain) BindMiddlewareDefault(handlers ...HandlerFunc) { for domain, _ := range d.domains { - d.server.BindMiddleware(gDEFAULT_MIDDLEWARE_PATTERN+"@"+domain, handlers...) + d.server.BindMiddleware(defaultMiddlewarePattern+"@"+domain, handlers...) } } diff --git a/net/ghttp/ghttp_server_router.go b/net/ghttp/ghttp_server_router.go index 6d31d5e78..85b6480a5 100644 --- a/net/ghttp/ghttp_server_router.go +++ b/net/ghttp/ghttp_server_router.go @@ -234,7 +234,7 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte // Compare the length of their URI, // but the fuzzy and named parts of the URI are not calculated to the result. - // Eg: + // Example: // /admin-goods-{page} > /admin-{page} // /{hash}.{type} > /{hash} var uriNew, uriOld string @@ -252,7 +252,7 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte } // Route type checks: {xxx} > :xxx > *xxx. - // Eg: + // Example: // /name/act > /{name}/:act var ( fuzzyCountFieldNew int @@ -321,9 +321,7 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte // If they have different router type, // the new router item has more priority than the other one. - if newItem.itemType == handlerTypeHandler || - newItem.itemType == handlerTypeObject || - newItem.itemType == handlerTypeController { + if newItem.itemType == handlerTypeHandler || newItem.itemType == handlerTypeObject || newItem.itemType == handlerTypeController { return true } diff --git a/net/ghttp/ghttp_server_router_group.go b/net/ghttp/ghttp_server_router_group.go index 3b4d01d0b..09b1ea84c 100644 --- a/net/ghttp/ghttp_server_router_group.go +++ b/net/ghttp/ghttp_server_router_group.go @@ -155,8 +155,10 @@ func (g *RouterGroup) Bind(items []GroupItem) *RouterGroup { switch bindType { case "REST": group.preBindToLocalArray("REST", gconv.String(item[0])+":"+gconv.String(item[1]), item[2]) + case "MIDDLEWARE": group.preBindToLocalArray("MIDDLEWARE", gconv.String(item[0])+":"+gconv.String(item[1]), item[2]) + default: if strings.EqualFold(bindType, "ALL") { bindType = "" @@ -309,11 +311,16 @@ func (g *RouterGroup) doBindRoutersToServer(item *preBindItem) *RouterGroup { } switch bindType { case "HANDLER": - if h, ok := object.(HandlerFunc); ok { + if reflect.ValueOf(object).Kind() == reflect.Func { + funcInfo, err := g.server.checkAndCreateFuncInfo(object, "", "", "") + if err != nil { + g.server.Logger().Error(err.Error()) + return g + } if g.server != nil { - g.server.doBindHandler(pattern, h, g.middleware, source) + g.server.doBindHandler(pattern, funcInfo, g.middleware, source) } else { - g.domain.doBindHandler(pattern, h, g.middleware, source) + g.domain.doBindHandler(pattern, funcInfo, g.middleware, source) } } else if g.isController(object) { if len(extras) > 0 { @@ -373,6 +380,7 @@ func (g *RouterGroup) doBindRoutersToServer(item *preBindItem) *RouterGroup { } } } else { + // At last, it treats the `object` as Object registering type. if g.server != nil { g.server.doBindObject(pattern, object, "", g.middleware, source) } else { @@ -380,6 +388,7 @@ func (g *RouterGroup) doBindRoutersToServer(item *preBindItem) *RouterGroup { } } } + case "REST": if g.isController(object) { if g.server != nil { @@ -398,6 +407,7 @@ func (g *RouterGroup) doBindRoutersToServer(item *preBindItem) *RouterGroup { g.domain.doBindObjectRest(pattern, object, g.middleware, source) } } + case "HOOK": if h, ok := object.(HandlerFunc); ok { if g.server != nil { @@ -414,6 +424,7 @@ func (g *RouterGroup) doBindRoutersToServer(item *preBindItem) *RouterGroup { // isController checks and returns whether given is a controller. // A controller should contains attributes: Request/Response/Server/Cookie/Session/View. +// Deprecated. func (g *RouterGroup) isController(value interface{}) bool { // Whether implements interface Controller. if _, ok := value.(Controller); !ok { diff --git a/net/ghttp/ghttp_server_router_hook.go b/net/ghttp/ghttp_server_router_hook.go index 4da68ad5e..7963a7531 100644 --- a/net/ghttp/ghttp_server_router_hook.go +++ b/net/ghttp/ghttp_server_router_hook.go @@ -9,6 +9,7 @@ package ghttp import ( "github.com/gogf/gf/debug/gdebug" "net/http" + "reflect" ) // BindHookHandler registers handler for specified hook. @@ -20,7 +21,10 @@ func (s *Server) doBindHookHandler(pattern string, hook string, handler HandlerF s.setHandler(pattern, &handlerItem{ itemType: handlerTypeHook, itemName: gdebug.FuncPath(handler), - itemFunc: handler, + itemInfo: handlerFuncInfo{ + Func: handler, + Type: reflect.TypeOf(handler), + }, hookName: hook, source: source, }) @@ -43,7 +47,7 @@ func (s *Server) callHookHandler(hook string, r *Request) { // DO NOT USE the router of the hook handler, // which can overwrite the router of serving handler. // r.Router = item.handler.router - if err := s.niceCallHookHandler(item.handler.itemFunc, r); err != nil { + if err := s.niceCallHookHandler(item.handler.itemInfo.Func, r); err != nil { switch err { case exceptionExit: break diff --git a/net/ghttp/ghttp_server_router_middleware.go b/net/ghttp/ghttp_server_router_middleware.go index 699ccd9d5..09faf7d1e 100644 --- a/net/ghttp/ghttp_server_router_middleware.go +++ b/net/ghttp/ghttp_server_router_middleware.go @@ -8,11 +8,12 @@ package ghttp import ( "github.com/gogf/gf/debug/gdebug" + "reflect" ) const ( // The default route pattern for global middleware. - gDEFAULT_MIDDLEWARE_PATTERN = "/*" + defaultMiddlewarePattern = "/*" ) // BindMiddleware registers one or more global middleware to the server. @@ -24,7 +25,10 @@ func (s *Server) BindMiddleware(pattern string, handlers ...HandlerFunc) { s.setHandler(pattern, &handlerItem{ itemType: handlerTypeMiddleware, itemName: gdebug.FuncPath(handler), - itemFunc: handler, + itemInfo: handlerFuncInfo{ + Func: handler, + Type: reflect.TypeOf(handler), + }, }) } } @@ -34,10 +38,13 @@ func (s *Server) BindMiddleware(pattern string, handlers ...HandlerFunc) { // before or after service handler. func (s *Server) BindMiddlewareDefault(handlers ...HandlerFunc) { for _, handler := range handlers { - s.setHandler(gDEFAULT_MIDDLEWARE_PATTERN, &handlerItem{ + s.setHandler(defaultMiddlewarePattern, &handlerItem{ itemType: handlerTypeMiddleware, itemName: gdebug.FuncPath(handler), - itemFunc: handler, + itemInfo: handlerFuncInfo{ + Func: handler, + Type: reflect.TypeOf(handler), + }, }) } } diff --git a/net/ghttp/ghttp_server_service_controller.go b/net/ghttp/ghttp_server_service_controller.go index 94f0940c1..b2df34ec4 100644 --- a/net/ghttp/ghttp_server_service_controller.go +++ b/net/ghttp/ghttp_server_service_controller.go @@ -73,18 +73,20 @@ func (s *Server) doBindController( pattern = s.serveHandlerKey("", path, domain) } // Retrieve a list of methods, create construct corresponding URI. - m := make(map[string]*handlerItem) - v := reflect.ValueOf(controller) - t := v.Type() - pkgPath := t.Elem().PkgPath() - pkgName := gfile.Basename(pkgPath) - structName := t.Elem().Name() + var ( + m = make(map[string]*handlerItem) + v = reflect.ValueOf(controller) + t = v.Type() + pkgPath = t.Elem().PkgPath() + pkgName = gfile.Basename(pkgPath) + structName = t.Elem().Name() + ) for i := 0; i < v.NumMethod(); i++ { methodName := t.Method(i).Name if methodMap != nil && !methodMap[methodName] { continue } - if methodName == "Init" || methodName == "Shut" || methodName == "Exit" { + if methodName == methodNameInit || methodName == methodNameShut || methodName == methodNameExit { continue } ctlName := gstr.Replace(t.String(), fmt.Sprintf(`%s.`, pkgName), "") @@ -153,12 +155,14 @@ func (s *Server) doBindControllerMethod( middleware []HandlerFunc, source string, ) { - m := make(map[string]*handlerItem) - v := reflect.ValueOf(controller) - t := v.Type() - structName := t.Elem().Name() - methodName := strings.TrimSpace(method) - methodValue := v.MethodByName(methodName) + var ( + m = make(map[string]*handlerItem) + v = reflect.ValueOf(controller) + t = v.Type() + structName = t.Elem().Name() + methodName = strings.TrimSpace(method) + methodValue = v.MethodByName(methodName) + ) if !methodValue.IsValid() { s.Logger().Fatal("invalid method name: " + methodName) return @@ -194,11 +198,13 @@ func (s *Server) doBindControllerRest( pattern string, controller Controller, middleware []HandlerFunc, source string, ) { - m := make(map[string]*handlerItem) - v := reflect.ValueOf(controller) - t := v.Type() - pkgPath := t.Elem().PkgPath() - structName := t.Elem().Name() + var ( + m = make(map[string]*handlerItem) + v = reflect.ValueOf(controller) + t = v.Type() + pkgPath = t.Elem().PkgPath() + structName = t.Elem().Name() + ) for i := 0; i < v.NumMethod(); i++ { methodName := t.Method(i).Name if _, ok := methodsMap[strings.ToUpper(methodName)]; !ok { diff --git a/net/ghttp/ghttp_server_service_handler.go b/net/ghttp/ghttp_server_service_handler.go index c81ee3900..36a1a1c66 100644 --- a/net/ghttp/ghttp_server_service_handler.go +++ b/net/ghttp/ghttp_server_service_handler.go @@ -9,27 +9,37 @@ package ghttp import ( "bytes" "github.com/gogf/gf/debug/gdebug" + "github.com/gogf/gf/errors/gerror" + "reflect" "strings" "github.com/gogf/gf/text/gstr" ) // BindHandler registers a handler function to server with given pattern. -func (s *Server) BindHandler(pattern string, handler HandlerFunc) { - s.doBindHandler(pattern, handler, nil, "") +// The parameter `handler` can be type of: +// func(*ghttp.Request) +// func(context.Context) +// func(context.Context,TypeRequest) +// func(context.Context,TypeRequest) error +// func(context.Context,TypeRequest)(TypeResponse,error) +func (s *Server) BindHandler(pattern string, handler interface{}) { + funcInfo, err := s.checkAndCreateFuncInfo(handler, "", "", "") + if err != nil { + s.Logger().Error(err.Error()) + return + } + s.doBindHandler(pattern, funcInfo, nil, "") } // doBindHandler registers a handler function to server with given pattern. // The parameter is like: // /user/list, put:/user, delete:/user, post:/user@goframe.org -func (s *Server) doBindHandler( - pattern string, handler HandlerFunc, - middleware []HandlerFunc, source string, -) { +func (s *Server) doBindHandler(pattern string, funcInfo handlerFuncInfo, middleware []HandlerFunc, source string) { s.setHandler(pattern, &handlerItem{ - itemName: gdebug.FuncPath(handler), + itemName: gdebug.FuncPath(funcInfo.Func), itemType: handlerTypeHandler, - itemFunc: handler, + itemInfo: funcInfo, middleware: middleware, source: source, }) @@ -77,13 +87,13 @@ func (s *Server) mergeBuildInNameToPattern(pattern string, structName, methodNam // Rule 3: Use camel case naming. func (s *Server) nameToUri(name string) string { switch s.config.NameToUriType { - case URI_TYPE_FULLNAME: + case UriTypeFullName: return name - case URI_TYPE_ALLLOWER: + case UriTypeAllLower: return strings.ToLower(name) - case URI_TYPE_CAMEL: + case UriTypeCamel: part := bytes.NewBuffer(nil) if gstr.IsLetterUpper(name[0]) { part.WriteByte(name[0] + 32) @@ -93,8 +103,9 @@ func (s *Server) nameToUri(name string) string { part.WriteString(name[1:]) return part.String() - case URI_TYPE_DEFAULT: + case UriTypeDefault: fallthrough + default: part := bytes.NewBuffer(nil) for i := 0; i < len(name); i++ { @@ -110,3 +121,44 @@ func (s *Server) nameToUri(name string) string { return part.String() } } + +func (s *Server) checkAndCreateFuncInfo(f interface{}, pkgPath, objName, methodName string) (info handlerFuncInfo, err error) { + handlerFunc, ok := f.(HandlerFunc) + if !ok { + reflectType := reflect.TypeOf(f) + if reflectType.NumIn() == 0 || reflectType.NumIn() > 2 || reflectType.NumOut() > 2 { + if pkgPath != "" { + err = gerror.Newf( + `invalid handler: %s.%s.%s defined as "%s", but "func(*ghttp.Request)" or "func(context.Context)/func(context.Context,Request)/func(context.Context,Request) error/func(context.Context,Request)(Response,error)" is required`, + pkgPath, objName, methodName, reflect.TypeOf(f).String(), + ) + } else { + err = gerror.Newf( + `invalid handler: defined as "%s", but "func(*ghttp.Request)" or "func(context.Context)/func(context.Context,Request)/func(context.Context,Request) error/func(context.Context,Request)(Response,error)" is required`, + reflect.TypeOf(f).String(), + ) + } + return + } + + if reflectType.In(0).String() != "context.Context" { + err = gerror.Newf( + `invalid handler: defined as "%s", but the first input parameter should be type of "context.Context"`, + reflect.TypeOf(f).String(), + ) + return + } + + if reflectType.NumOut() > 0 && reflectType.Out(reflectType.NumOut()-1).String() != "error" { + err = gerror.Newf( + `invalid handler: defined as "%s", but the last output parameter should be type of "error"`, + reflect.TypeOf(f).String(), + ) + return + } + } + info.Func = handlerFunc + info.Type = reflect.TypeOf(f) + info.Value = reflect.ValueOf(f) + return +} diff --git a/net/ghttp/ghttp_server_service_object.go b/net/ghttp/ghttp_server_service_object.go index d86bf24e2..872cf179c 100644 --- a/net/ghttp/ghttp_server_service_object.go +++ b/net/ghttp/ghttp_server_service_object.go @@ -46,10 +46,7 @@ func (s *Server) BindObjectRest(pattern string, object interface{}) { s.doBindObjectRest(pattern, object, nil, "") } -func (s *Server) doBindObject( - pattern string, object interface{}, method string, - middleware []HandlerFunc, source string, -) { +func (s *Server) doBindObject(pattern string, object interface{}, method string, middleware []HandlerFunc, source string) { // Convert input method to map for convenience and high performance searching purpose. var methodMap map[string]bool if len(method) > 0 { @@ -104,26 +101,18 @@ func (s *Server) doBindObject( if objName[0] == '*' { objName = fmt.Sprintf(`(%s)`, objName) } - itemFunc, ok := v.Method(i).Interface().(func(*Request)) - if !ok { - if len(methodMap) > 0 { - s.Logger().Errorf( - `invalid route method: %s.%s.%s defined as "%s", but "func(*ghttp.Request)" is required for object registry`, - pkgPath, objName, methodName, v.Method(i).Type().String(), - ) - } else { - s.Logger().Debugf( - `ignore route method: %s.%s.%s defined as "%s", no match "func(*ghttp.Request)" for object registry`, - pkgPath, objName, methodName, v.Method(i).Type().String(), - ) - } - continue + + funcInfo, err := s.checkAndCreateFuncInfo(v.Method(i).Interface(), pkgPath, objName, methodName) + if err != nil { + s.Logger().Error(err.Error()) + return } + key := s.mergeBuildInNameToPattern(pattern, structName, methodName, true) m[key] = &handlerItem{ itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, methodName), itemType: handlerTypeObject, - itemFunc: itemFunc, + itemInfo: funcInfo, initFunc: initFunc, shutFunc: shutFunc, middleware: middleware, @@ -145,7 +134,7 @@ func (s *Server) doBindObject( m[k] = &handlerItem{ itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, methodName), itemType: handlerTypeObject, - itemFunc: itemFunc, + itemInfo: funcInfo, initFunc: initFunc, shutFunc: shutFunc, middleware: middleware, @@ -194,19 +183,18 @@ func (s *Server) doBindObjectMethod( if objName[0] == '*' { objName = fmt.Sprintf(`(%s)`, objName) } - itemFunc, ok := methodValue.Interface().(func(*Request)) - if !ok { - s.Logger().Errorf( - `invalid route method: %s.%s.%s defined as "%s", but "func(*ghttp.Request)" is required for object registry`, - pkgPath, objName, methodName, methodValue.Type().String(), - ) + + funcInfo, err := s.checkAndCreateFuncInfo(methodValue.Interface(), pkgPath, objName, methodName) + if err != nil { + s.Logger().Error(err.Error()) return } + key := s.mergeBuildInNameToPattern(pattern, structName, methodName, false) m[key] = &handlerItem{ itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, methodName), itemType: handlerTypeObject, - itemFunc: itemFunc, + itemInfo: funcInfo, initFunc: initFunc, shutFunc: shutFunc, middleware: middleware, @@ -216,10 +204,7 @@ func (s *Server) doBindObjectMethod( s.bindHandlerByMap(m) } -func (s *Server) doBindObjectRest( - pattern string, object interface{}, - middleware []HandlerFunc, source string, -) { +func (s *Server) doBindObjectRest(pattern string, object interface{}, middleware []HandlerFunc, source string) { var ( m = make(map[string]*handlerItem) v = reflect.ValueOf(object) @@ -236,11 +221,11 @@ func (s *Server) doBindObjectRest( t = v.Type() } structName := t.Elem().Name() - if v.MethodByName("Init").IsValid() { - initFunc = v.MethodByName("Init").Interface().(func(*Request)) + if v.MethodByName(methodNameInit).IsValid() { + initFunc = v.MethodByName(methodNameInit).Interface().(func(*Request)) } - if v.MethodByName("Shut").IsValid() { - shutFunc = v.MethodByName("Shut").Interface().(func(*Request)) + if v.MethodByName(methodNameShut).IsValid() { + shutFunc = v.MethodByName(methodNameShut).Interface().(func(*Request)) } pkgPath := t.Elem().PkgPath() for i := 0; i < v.NumMethod(); i++ { @@ -253,19 +238,18 @@ func (s *Server) doBindObjectRest( if objName[0] == '*' { objName = fmt.Sprintf(`(%s)`, objName) } - itemFunc, ok := v.Method(i).Interface().(func(*Request)) - if !ok { - s.Logger().Errorf( - `invalid route method: %s.%s.%s defined as "%s", but "func(*ghttp.Request)" is required for object registry`, - pkgPath, objName, methodName, v.Method(i).Type().String(), - ) - continue + + funcInfo, err := s.checkAndCreateFuncInfo(v.Method(i).Interface(), pkgPath, objName, methodName) + if err != nil { + s.Logger().Error(err.Error()) + return } + key := s.mergeBuildInNameToPattern(methodName+":"+pattern, structName, methodName, false) m[key] = &handlerItem{ itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, methodName), itemType: handlerTypeObject, - itemFunc: itemFunc, + itemInfo: funcInfo, initFunc: initFunc, shutFunc: shutFunc, middleware: middleware, diff --git a/net/ghttp/ghttp_unit_router_handler_extended_test.go b/net/ghttp/ghttp_unit_router_handler_extended_test.go new file mode 100644 index 000000000..ddcf4da3d --- /dev/null +++ b/net/ghttp/ghttp_unit_router_handler_extended_test.go @@ -0,0 +1,82 @@ +// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package ghttp_test + +import ( + "context" + "fmt" + "github.com/gogf/gf/errors/gerror" + "testing" + "time" + + "github.com/gogf/gf/frame/g" + "github.com/gogf/gf/net/ghttp" + "github.com/gogf/gf/test/gtest" +) + +func Test_Router_Handler_Extended_Handler_Basic(t *testing.T) { + p, _ := ports.PopRand() + s := g.Server(p) + s.BindHandler("/test", func(ctx context.Context) { + r := g.RequestFromCtx(ctx) + r.Response.Write("test") + }) + s.SetPort(p) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + + time.Sleep(100 * time.Millisecond) + gtest.C(t, func(t *gtest.T) { + client := g.Client() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + + t.Assert(client.GetContent("/test"), "test") + }) +} + +func Test_Router_Handler_Extended_Handler_WithObject(t *testing.T) { + type TestReq struct { + Age int + Name string + } + type TestRes struct { + Id int + Age int + Name string + } + p, _ := ports.PopRand() + s := g.Server(p) + s.Use(ghttp.MiddlewareHandlerResponse) + s.BindHandler("/test", func(ctx context.Context, req *TestReq) (res *TestRes, err error) { + return &TestRes{ + Id: 1, + Age: req.Age, + Name: req.Name, + }, nil + }) + s.BindHandler("/test/error", func(ctx context.Context, req *TestReq) (res *TestRes, err error) { + return &TestRes{ + Id: 1, + Age: req.Age, + Name: req.Name, + }, gerror.New("error") + }) + s.SetPort(p) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + + time.Sleep(100 * time.Millisecond) + gtest.C(t, func(t *gtest.T) { + client := g.Client() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + + t.Assert(client.GetContent("/test?age=18&name=john"), `{"Id":1,"Age":18,"Name":"john"}`) + t.Assert(client.GetContent("/test/error"), `{"code":-1,"message":"error"}`) + }) +} diff --git a/net/ghttp/ghttp_unit_router_names_test.go b/net/ghttp/ghttp_unit_router_names_test.go index bf6f40e8d..5c79cd9f8 100644 --- a/net/ghttp/ghttp_unit_router_names_test.go +++ b/net/ghttp/ghttp_unit_router_names_test.go @@ -88,7 +88,7 @@ func Test_NameToUri_Camel(t *testing.T) { func Test_NameToUri_Default(t *testing.T) { p, _ := ports.PopRand() s := g.Server(p) - s.SetNameToUriType(ghttp.URI_TYPE_DEFAULT) + s.SetNameToUriType(ghttp.UriTypeDefault) s.BindObject("/{.struct}/{.method}", new(NamesObject)) s.SetPort(p) s.SetDumpRouterMap(false)