From 5fa33411fc2853b81c9212f18ad2dacaf769ba86 Mon Sep 17 00:00:00 2001 From: wanna Date: Thu, 5 Dec 2024 15:49:50 +0800 Subject: [PATCH] chore: add example for openapi/swagger authentication (#4004) --- example/httpserver/swagger/main.go | 9 +++++++ net/ghttp/ghttp_server_config_api.go | 5 ++++ ...ttp_z_unit_feature_openapi_swagger_test.go | 25 +++++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/example/httpserver/swagger/main.go b/example/httpserver/swagger/main.go index d453588d3..3d934352c 100644 --- a/example/httpserver/swagger/main.go +++ b/example/httpserver/swagger/main.go @@ -45,5 +45,14 @@ func main() { new(Hello), ) }) + // if api.json requires authentication, add openApiBasicAuth handler + s.BindHookHandler(s.GetOpenApiPath(), ghttp.HookBeforeServe, openApiBasicAuth) s.Run() } + +func openApiBasicAuth(r *ghttp.Request) { + if !r.BasicAuth("OpenApiAuthUserName", "OpenApiAuthPass", "Restricted") { + r.ExitAll() + return + } +} diff --git a/net/ghttp/ghttp_server_config_api.go b/net/ghttp/ghttp_server_config_api.go index d9a66d3cb..a89888b1f 100644 --- a/net/ghttp/ghttp_server_config_api.go +++ b/net/ghttp/ghttp_server_config_api.go @@ -21,3 +21,8 @@ func (s *Server) SetSwaggerUITemplate(swaggerUITemplate string) { func (s *Server) SetOpenApiPath(path string) { s.config.OpenApiPath = path } + +// GetOpenApiPath returns the configuration of `OpenApiPath` of server. +func (s *Server) GetOpenApiPath() string { + return s.config.OpenApiPath +} diff --git a/net/ghttp/ghttp_z_unit_feature_openapi_swagger_test.go b/net/ghttp/ghttp_z_unit_feature_openapi_swagger_test.go index e46db1ddf..2ce097ac9 100644 --- a/net/ghttp/ghttp_z_unit_feature_openapi_swagger_test.go +++ b/net/ghttp/ghttp_z_unit_feature_openapi_swagger_test.go @@ -185,3 +185,28 @@ func Test_OpenApi_Method_All_Swagger(t *testing.T) { t.Assert(gstr.Contains(c.GetContent(ctx, "/api.json"), `/test/error`), true) }) } + +func Test_OpenApi_Auth(t *testing.T) { + s := g.Server(guid.S()) + apiPath := "/api.json" + s.SetOpenApiPath(apiPath) + s.BindHookHandler(s.GetOpenApiPath(), ghttp.HookBeforeServe, openApiBasicAuth) + s.Start() + defer s.Shutdown() + gtest.C(t, func(t *gtest.T) { + t.Assert(s.GetOpenApiPath(), apiPath) + c := g.Client() + c.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", s.GetListenedPort())) + t.Assert(c.GetContent(ctx, apiPath), "Unauthorized") + c.SetBasicAuth("OpenApiAuthUserName", "OpenApiAuthPass") + cc := c.GetContent(ctx, apiPath) + t.AssertNE(cc, "Unauthorized") + }) +} + +func openApiBasicAuth(r *ghttp.Request) { + if !r.BasicAuth("OpenApiAuthUserName", "OpenApiAuthPass", "Restricted") { + r.ExitAll() + return + } +}