From 368312c8164609b486b949b2a339b7d4b37b6996 Mon Sep 17 00:00:00 2001 From: John Guo Date: Wed, 24 May 2023 17:21:28 +0800 Subject: [PATCH] add multiple methods support for object route (#2663) --- net/ghttp/ghttp_func.go | 6 + net/ghttp/ghttp_server_openapi.go | 30 ++--- net/ghttp/ghttp_server_router.go | 49 +++++++- net/ghttp/ghttp_server_router_serve.go | 2 +- net/ghttp/ghttp_server_service_handler.go | 7 +- net/ghttp/ghttp_server_service_object.go | 2 +- ...ttp_z_unit_feature_openapi_swagger_test.go | 119 ++++++++++++++++++ net/goai/goai_path.go | 18 ++- net/goai/goai_shema.go | 10 ++ net/goai/goai_shemas.go | 6 + net/goai/goai_z_unit_test.go | 6 +- 11 files changed, 225 insertions(+), 30 deletions(-) diff --git a/net/ghttp/ghttp_func.go b/net/ghttp/ghttp_func.go index 07c8fa8f1..cb550ba10 100644 --- a/net/ghttp/ghttp_func.go +++ b/net/ghttp/ghttp_func.go @@ -10,8 +10,14 @@ import ( "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/internal/httputil" + "github.com/gogf/gf/v2/text/gstr" ) +// SupportedMethods returns all supported HTTP methods. +func SupportedMethods() []string { + return gstr.SplitAndTrim(supportedHttpMethods, ",") +} + // BuildParams builds the request string for the http client. The `params` can be type of: // string/[]byte/map/struct/*struct. // diff --git a/net/ghttp/ghttp_server_openapi.go b/net/ghttp/ghttp_server_openapi.go index 0ca79a390..0ae5c67ba 100644 --- a/net/ghttp/ghttp_server_openapi.go +++ b/net/ghttp/ghttp_server_openapi.go @@ -19,27 +19,29 @@ func (s *Server) initOpenApi() { return } var ( - ctx = context.TODO() - err error - method string + ctx = context.TODO() + err error + methods []string ) for _, item := range s.GetRoutes() { switch item.Type { case HandlerTypeMiddleware, HandlerTypeHook: continue } - method = item.Method - if gstr.Equal(method, defaultMethod) { - method = "" - } if item.Handler.Info.Func == nil { - err = s.openapi.Add(goai.AddInput{ - Path: item.Route, - Method: method, - Object: item.Handler.Info.Value.Interface(), - }) - if err != nil { - s.Logger().Fatalf(ctx, `%+v`, err) + methods = []string{item.Method} + if gstr.Equal(item.Method, defaultMethod) { + methods = SupportedMethods() + } + for _, method := range methods { + err = s.openapi.Add(goai.AddInput{ + Path: item.Route, + Method: method, + Object: item.Handler.Info.Value.Interface(), + }) + if err != nil { + s.Logger().Fatalf(ctx, `%+v`, err) + } } } } diff --git a/net/ghttp/ghttp_server_router.go b/net/ghttp/ghttp_server_router.go index 3c4e9dded..383001189 100644 --- a/net/ghttp/ghttp_server_router.go +++ b/net/ghttp/ghttp_server_router.go @@ -82,7 +82,6 @@ func (s *Server) setHandler(ctx context.Context, in setHandlerInput) { if handler.Name == "" { handler.Name = runtime.FuncForPC(handler.Info.Value.Pointer()).Name() } - handler.Id = handlerIdGenerator.Add(1) if handler.Source == "" { _, file, line := gdebug.CallerWithFilter([]string{consts.StackFilterKeyForGoFrame}) handler.Source = fmt.Sprintf(`%s:%d`, file, line) @@ -92,21 +91,50 @@ func (s *Server) setHandler(ctx context.Context, in setHandlerInput) { s.Logger().Fatalf(ctx, `invalid pattern "%s", %+v`, pattern, err) return } - + // ==================================================================================== // Change the registered route according to meta info from its request structure. + // It supports multiple methods that are joined using char `,`. + // ==================================================================================== if handler.Info.Type != nil && handler.Info.Type.NumIn() == 2 { var objectReq = reflect.New(handler.Info.Type.In(1)) if v := gmeta.Get(objectReq, gtag.Path); !v.IsEmpty() { uri = v.String() } - if v := gmeta.Get(objectReq, gtag.Method); !v.IsEmpty() { - method = v.String() - } if v := gmeta.Get(objectReq, gtag.Domain); !v.IsEmpty() { domain = v.String() } + if v := gmeta.Get(objectReq, gtag.Method); !v.IsEmpty() { + method = v.String() + } + // Multiple methods registering, which are joined using char `,`. + if gstr.Contains(method, ",") { + methods := gstr.SplitAndTrim(method, ",") + for _, v := range methods { + // Each method has it own handler. + clonedHandler := *handler + s.doSetHandler(ctx, &clonedHandler, prefix, uri, pattern, v, domain) + } + return + } + // Converts `all` to `ALL`. + if gstr.Equal(method, defaultMethod) { + method = defaultMethod + } } + s.doSetHandler(ctx, handler, prefix, uri, pattern, method, domain) +} +func (s *Server) doSetHandler( + ctx context.Context, handler *HandlerItem, + prefix, uri, pattern, method, domain string, +) { + if !s.isValidMethod(method) { + s.Logger().Fatalf( + ctx, + `invalid method value "%s", should be in "%s" or "%s"`, + method, supportedHttpMethods, defaultMethod, + ) + } // Prefix for URI feature. if prefix != "" { uri = prefix + "/" + strings.TrimLeft(uri, "/") @@ -118,7 +146,6 @@ func (s *Server) setHandler(ctx context.Context, in setHandlerInput) { if len(uri) == 0 || uri[0] != '/' { s.Logger().Fatalf(ctx, `invalid pattern "%s", URI should lead with '/'`, pattern) - return } // Repeated router checks, this feature can be disabled by server configuration. @@ -145,6 +172,8 @@ func (s *Server) setHandler(ctx context.Context, in setHandlerInput) { } } } + // Unique id for each handler. + handler.Id = handlerIdGenerator.Add(1) // Create a new router by given parameter. handler.Router = &Router{ Uri: uri, @@ -248,6 +277,14 @@ func (s *Server) setHandler(ctx context.Context, in setHandlerInput) { s.routesMap[routerKey] = append(s.routesMap[routerKey], handler) } +func (s *Server) isValidMethod(method string) bool { + if gstr.Equal(method, defaultMethod) { + return true + } + _, ok := methodsMap[strings.ToUpper(method)] + return ok +} + // compareRouterPriority compares the priority between `newItem` and `oldItem`. It returns true // if `newItem`'s priority is higher than `oldItem`, else it returns false. The higher priority // item will be inserted into the router list before the other one. diff --git a/net/ghttp/ghttp_server_router_serve.go b/net/ghttp/ghttp_server_router_serve.go index 76b9f02b3..4163d36aa 100644 --- a/net/ghttp/ghttp_server_router_serve.go +++ b/net/ghttp/ghttp_server_router_serve.go @@ -157,7 +157,7 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*Han item := e.Value.(*HandlerItem) // Filter repeated handler items, especially the middleware and hook handlers. // It is necessary, do not remove this checks logic unless you really know how it is necessary. - if _, ok := repeatHandlerCheckMap[item.Id]; ok { + if _, isRepeatedHandler := repeatHandlerCheckMap[item.Id]; isRepeatedHandler { continue } else { repeatHandlerCheckMap[item.Id] = struct{}{} diff --git a/net/ghttp/ghttp_server_service_handler.go b/net/ghttp/ghttp_server_service_handler.go index 06ced8a00..8422654a3 100644 --- a/net/ghttp/ghttp_server_service_handler.go +++ b/net/ghttp/ghttp_server_service_handler.go @@ -91,9 +91,10 @@ func (s *Server) mergeBuildInNameToPattern(pattern string, structName, methodNam return pattern } // Check domain parameter. - array := strings.Split(pattern, "@") - uri := array[0] - uri = strings.TrimRight(uri, "/") + "/" + methodName + var ( + array = strings.Split(pattern, "@") + uri = strings.TrimRight(array[0], "/") + "/" + methodName + ) // Append the domain parameter to URI. if len(array) > 1 { return uri + "@" + array[1] diff --git a/net/ghttp/ghttp_server_service_object.go b/net/ghttp/ghttp_server_service_object.go index 8842841ab..f497ab309 100644 --- a/net/ghttp/ghttp_server_service_object.go +++ b/net/ghttp/ghttp_server_service_object.go @@ -88,7 +88,7 @@ func (s *Server) doBindObject(ctx context.Context, in doBindObjectInput) { s.Logger().Fatalf(ctx, `%+v`, err) return } - if strings.EqualFold(method, defaultMethod) { + if gstr.Equal(method, defaultMethod) { in.Pattern = s.serveHandlerKey("", path, domain) } var ( diff --git a/net/ghttp/ghttp_z_unit_feature_openapi_swagger_test.go b/net/ghttp/ghttp_z_unit_feature_openapi_swagger_test.go index d271cd9bd..e46db1ddf 100644 --- a/net/ghttp/ghttp_z_unit_feature_openapi_swagger_test.go +++ b/net/ghttp/ghttp_z_unit_feature_openapi_swagger_test.go @@ -66,3 +66,122 @@ func Test_OpenApi_Swagger(t *testing.T) { t.Assert(gstr.Contains(c.GetContent(ctx, "/api.json"), `/test/error`), true) }) } + +func Test_OpenApi_Multiple_Methods_Swagger(t *testing.T) { + type TestReq struct { + gmeta.Meta `method:"get,post" summary:"Test summary" tags:"Test"` + Age int + Name string + } + type TestRes struct { + Id int + Age int + Name string + } + s := g.Server(guid.S()) + s.SetSwaggerPath("/swagger") + s.SetOpenApiPath("/api.json") + s.Use(ghttp.MiddlewareHandlerResponse) + s.BindHandler("/test", func(ctx context.Context, req *TestReq) (res *TestRes, err error) { + return &TestRes{ + Id: 1, + Age: req.Age, + Name: req.Name, + }, nil + }) + s.BindHandler("/test/error", func(ctx context.Context, req *TestReq) (res *TestRes, err error) { + return &TestRes{ + Id: 1, + Age: req.Age, + Name: req.Name, + }, gerror.New("error") + }) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + + time.Sleep(100 * time.Millisecond) + gtest.C(t, func(t *gtest.T) { + openapi := s.GetOpenApi() + t.AssertNE(openapi.Paths["/test"].Get, nil) + t.AssertNE(openapi.Paths["/test"].Post, nil) + t.AssertNE(openapi.Paths["/test/error"].Get, nil) + t.AssertNE(openapi.Paths["/test/error"].Post, nil) + + t.Assert(len(openapi.Paths["/test"].Get.Parameters), 2) + t.Assert(len(openapi.Paths["/test/error"].Get.Parameters), 2) + t.Assert(len(openapi.Components.Schemas.Get(`github.com.gogf.gf.v2.net.ghttp_test.TestReq`).Value.Properties.Map()), 2) + + c := g.Client() + c.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", s.GetListenedPort())) + + // Only works on GET & POST methods. + t.Assert(c.GetContent(ctx, "/test?age=18&name=john"), `{"code":0,"message":"","data":{"Id":1,"Age":18,"Name":"john"}}`) + t.Assert(c.GetContent(ctx, "/test/error"), `{"code":50,"message":"error","data":{"Id":1,"Age":0,"Name":""}}`) + t.Assert(c.PostContent(ctx, "/test?age=18&name=john"), `{"code":0,"message":"","data":{"Id":1,"Age":18,"Name":"john"}}`) + t.Assert(c.PostContent(ctx, "/test/error"), `{"code":50,"message":"error","data":{"Id":1,"Age":0,"Name":""}}`) + + // Not works on other methods. + t.Assert(c.PutContent(ctx, "/test?age=18&name=john"), `{"code":65,"message":"Not Found","data":null}`) + t.Assert(c.PutContent(ctx, "/test/error"), `{"code":65,"message":"Not Found","data":null}`) + + t.Assert(gstr.Contains(c.GetContent(ctx, "/swagger/"), `API Reference`), true) + t.Assert(gstr.Contains(c.GetContent(ctx, "/api.json"), `/test/error`), true) + }) +} + +func Test_OpenApi_Method_All_Swagger(t *testing.T) { + type TestReq struct { + gmeta.Meta `method:"all" summary:"Test summary" tags:"Test"` + Age int + Name string + } + type TestRes struct { + Id int + Age int + Name string + } + s := g.Server(guid.S()) + s.SetSwaggerPath("/swagger") + s.SetOpenApiPath("/api.json") + s.Use(ghttp.MiddlewareHandlerResponse) + s.BindHandler("/test", func(ctx context.Context, req *TestReq) (res *TestRes, err error) { + return &TestRes{ + Id: 1, + Age: req.Age, + Name: req.Name, + }, nil + }) + s.BindHandler("/test/error", func(ctx context.Context, req *TestReq) (res *TestRes, err error) { + return &TestRes{ + Id: 1, + Age: req.Age, + Name: req.Name, + }, gerror.New("error") + }) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + + time.Sleep(100 * time.Millisecond) + gtest.C(t, func(t *gtest.T) { + openapi := s.GetOpenApi() + t.AssertNE(openapi.Paths["/test"].Get, nil) + t.AssertNE(openapi.Paths["/test"].Post, nil) + t.AssertNE(openapi.Paths["/test"].Delete, nil) + t.AssertNE(openapi.Paths["/test/error"].Get, nil) + t.AssertNE(openapi.Paths["/test/error"].Post, nil) + t.AssertNE(openapi.Paths["/test/error"].Delete, nil) + + c := g.Client() + c.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", s.GetListenedPort())) + + t.Assert(c.GetContent(ctx, "/test?age=18&name=john"), `{"code":0,"message":"","data":{"Id":1,"Age":18,"Name":"john"}}`) + t.Assert(c.GetContent(ctx, "/test/error"), `{"code":50,"message":"error","data":{"Id":1,"Age":0,"Name":""}}`) + t.Assert(c.PostContent(ctx, "/test?age=18&name=john"), `{"code":0,"message":"","data":{"Id":1,"Age":18,"Name":"john"}}`) + t.Assert(c.PostContent(ctx, "/test/error"), `{"code":50,"message":"error","data":{"Id":1,"Age":0,"Name":""}}`) + + t.Assert(gstr.Contains(c.GetContent(ctx, "/swagger/"), `API Reference`), true) + t.Assert(gstr.Contains(c.GetContent(ctx, "/api.json"), `/test/error`), true) + }) +} diff --git a/net/goai/goai_path.go b/net/goai/goai_path.go index 76a0bb493..4ffcc8415 100644 --- a/net/goai/goai_path.go +++ b/net/goai/goai_path.go @@ -310,6 +310,11 @@ func (oai *OpenApiV3) addPath(in addPathInput) error { } func (oai *OpenApiV3) removeOperationDuplicatedProperties(operation Operation) { + if len(operation.Parameters) == 0 { + // Nothing to do. + return + } + var ( duplicatedParameterNames []interface{} dataField string @@ -332,10 +337,15 @@ func (oai *OpenApiV3) removeOperationDuplicatedProperties(operation Operation) { } // Check request body schema ref. - if schema := oai.Components.Schemas.Get(requestBodyContent.Schema.Ref); schema != nil { - schema.Value.Required = oai.removeItemsFromArray(schema.Value.Required, duplicatedParameterNames) - schema.Value.Properties.Removes(duplicatedParameterNames) - continue + if requestBodyContent.Schema.Ref != "" { + if schema := oai.Components.Schemas.Get(requestBodyContent.Schema.Ref); schema != nil { + newSchema := schema.Value.Clone() + requestBodyContent.Schema.Ref = "" + requestBodyContent.Schema.Value = newSchema + newSchema.Required = oai.removeItemsFromArray(newSchema.Required, duplicatedParameterNames) + newSchema.Properties.Removes(duplicatedParameterNames) + continue + } } // Check the Value public field for the request body. diff --git a/net/goai/goai_shema.go b/net/goai/goai_shema.go index 62532e287..3dcf223aa 100644 --- a/net/goai/goai_shema.go +++ b/net/goai/goai_shema.go @@ -63,6 +63,16 @@ type Schema struct { ValidationRules string `json:"-"` } +// Clone only clones necessary attributes. +// TODO clone all attributes, or improve package deepcopy. +func (s *Schema) Clone() *Schema { + newSchema := *s + newSchema.Required = make([]string, len(s.Required)) + copy(newSchema.Required, s.Required) + newSchema.Properties = s.Properties.Clone() + return &newSchema +} + func (s Schema) MarshalJSON() ([]byte, error) { var ( b []byte diff --git a/net/goai/goai_shemas.go b/net/goai/goai_shemas.go index c8f6c6458..87fb4995c 100644 --- a/net/goai/goai_shemas.go +++ b/net/goai/goai_shemas.go @@ -26,6 +26,12 @@ func (s *Schemas) init() { } } +func (s *Schemas) Clone() Schemas { + newSchemas := createSchemas() + newSchemas.refs = s.refs.Clone() + return newSchemas +} + func (s *Schemas) Get(name string) *SchemaRef { s.init() value := s.refs.Get(name) diff --git a/net/goai/goai_z_unit_test.go b/net/goai/goai_z_unit_test.go index 69fbbfc80..86393119c 100644 --- a/net/goai/goai_z_unit_test.go +++ b/net/goai/goai_z_unit_test.go @@ -114,7 +114,11 @@ func TestOpenApiV3_Add(t *testing.T) { // Schema asserts. t.Assert(len(oai.Components.Schemas.Map()), 3) t.Assert(oai.Components.Schemas.Get(`github.com.gogf.gf.v2.net.goai_test.CreateResourceReq`).Value.Type, goai.TypeObject) - t.Assert(len(oai.Components.Schemas.Get(`github.com.gogf.gf.v2.net.goai_test.CreateResourceReq`).Value.Properties.Map()), 5) + + t.Assert(len(oai.Components.Schemas.Get(`github.com.gogf.gf.v2.net.goai_test.CreateResourceReq`).Value.Properties.Map()), 7) + t.Assert(len(oai.Paths["/test1/{appId}"].Put.RequestBody.Value.Content["application/json"].Schema.Value.Properties.Map()), 5) + t.Assert(len(oai.Paths["/test1/{appId}"].Post.RequestBody.Value.Content["application/json"].Schema.Value.Properties.Map()), 5) + t.Assert(oai.Paths["/test1/{appId}"].Post.Parameters[0].Value.Schema.Value.Type, goai.TypeInteger) t.Assert(oai.Paths["/test1/{appId}"].Post.Parameters[1].Value.Schema.Value.Type, goai.TypeString)