improve status handler by supporting multiple status handler for package ghttp

This commit is contained in:
Jack
2020-11-25 16:37:41 +08:00
parent fc215ef0b2
commit 33292f54e0
6 changed files with 43 additions and 11 deletions

View File

@ -30,7 +30,7 @@ type (
serveTree map[string]interface{} // The route map tree.
serveCache *gcache.Cache // Server cache for internal usage.
routesMap map[string][]registeredRouteItem // Route map mainly for route dumps and repeated route checks.
statusHandlerMap map[string]HandlerFunc // Custom status handler map.
statusHandlerMap map[string][]HandlerFunc // Custom status handler map.
sessionManager *gsession.Manager // Session manager.
}

View File

@ -100,13 +100,13 @@ func GetServer(name ...interface{}) *Server {
servers: make([]*gracefulServer, 0),
closeChan: make(chan struct{}, 10000),
serverCount: gtype.NewInt(),
statusHandlerMap: make(map[string]HandlerFunc),
statusHandlerMap: make(map[string][]HandlerFunc),
serveTree: make(map[string]interface{}),
serveCache: gcache.New(),
routesMap: make(map[string][]registeredRouteItem),
}
// Initialize the server using default configurations.
if err := s.SetConfig(Config()); err != nil {
if err := s.SetConfig(NewConfig()); err != nil {
panic(err)
}
// Record the server to internal server mapping by name.

View File

@ -153,7 +153,7 @@ func (d *Domain) BindHookHandlerByMap(pattern string, hookmap map[string]Handler
func (d *Domain) BindStatusHandler(status int, handler HandlerFunc) {
for domain, _ := range d.domains {
d.server.setStatusHandler(d.server.statusHandlerKey(status, domain), handler)
d.server.addStatusHandler(d.server.statusHandlerKey(status, domain), handler)
}
}

View File

@ -151,11 +151,15 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// HTTP status handler.
if request.Response.Status != http.StatusOK {
if f := s.getStatusHandler(request.Response.Status, request); f != nil {
statusFuncArray := s.getStatusHandler(request.Response.Status, request)
for _, f := range statusFuncArray {
// Call custom status handler.
niceCallFunc(func() {
f(request)
})
if request.IsExited() {
break
}
}
}

View File

@ -11,7 +11,7 @@ import (
)
// getStatusHandler retrieves and returns the handler for given status code.
func (s *Server) getStatusHandler(status int, r *Request) HandlerFunc {
func (s *Server) getStatusHandler(status int, r *Request) []HandlerFunc {
domains := []string{r.GetHost(), gDEFAULT_DOMAIN}
for _, domain := range domains {
if f, ok := s.statusHandlerMap[s.statusHandlerKey(status, domain)]; ok {
@ -21,10 +21,13 @@ func (s *Server) getStatusHandler(status int, r *Request) HandlerFunc {
return nil
}
// setStatusHandler sets the handler for given status code.
// addStatusHandler sets the handler for given status code.
// The parameter <pattern> is like: domain#status
func (s *Server) setStatusHandler(pattern string, handler HandlerFunc) {
s.statusHandlerMap[pattern] = handler
func (s *Server) addStatusHandler(pattern string, handler HandlerFunc) {
if s.statusHandlerMap[pattern] == nil {
s.statusHandlerMap[pattern] = make([]HandlerFunc, 0)
}
s.statusHandlerMap[pattern] = append(s.statusHandlerMap[pattern], handler)
}
// statusHandlerKey creates and returns key for given status and domain.
@ -34,7 +37,7 @@ func (s *Server) statusHandlerKey(status int, domain string) string {
// BindStatusHandler registers handler for given status code.
func (s *Server) BindStatusHandler(status int, handler HandlerFunc) {
s.setStatusHandler(s.statusHandlerKey(status, gDEFAULT_DOMAIN), handler)
s.addStatusHandler(s.statusHandlerKey(status, gDEFAULT_DOMAIN), handler)
}
// BindStatusHandlerByMap registers handler for given status code using map.

View File

@ -34,10 +34,35 @@ func Test_StatusHandler(t *testing.T) {
s.Start()
defer s.Shutdown()
time.Sleep(100 * time.Millisecond)
client := ghttp.NewClient()
client := g.Client()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
t.Assert(client.GetContent("/404"), "404")
t.Assert(client.GetContent("/502"), "502")
})
}
func Test_StatusHandler_Multi(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
p, _ := ports.PopRand()
s := g.Server(p)
s.BindStatusHandler(502, func(r *ghttp.Request) {
r.Response.WriteOver("1")
})
s.BindStatusHandler(502, func(r *ghttp.Request) {
r.Response.Write("2")
})
s.BindHandler("/502", func(r *ghttp.Request) {
r.Response.WriteStatusExit(502)
})
s.SetDumpRouterMap(false)
s.SetPort(p)
s.Start()
defer s.Shutdown()
time.Sleep(100 * time.Millisecond)
client := g.Client()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
t.Assert(client.GetContent("/502"), "12")
})
}