diff --git a/net/ghttp/ghttp_server_config.go b/net/ghttp/ghttp_server_config.go index 437b662e1..d70aa299d 100644 --- a/net/ghttp/ghttp_server_config.go +++ b/net/ghttp/ghttp_server_config.go @@ -15,6 +15,7 @@ import ( "net" "net/http" "strconv" + "strings" "time" "github.com/gogf/gf/v2/internal/intlog" @@ -55,8 +56,7 @@ type ServerConfig struct { HTTPSAddr string `json:"httpsAddr"` // Listeners specifies the custom listeners. - // Listeners is a map, the key of map must specify the port of Address or HTTPSAddr. - Listeners map[int]net.Listener `json:"listeners"` + Listeners []net.Listener `json:"listeners"` // HTTPSCertPath specifies certification file path for HTTPS service. HTTPSCertPath string `json:"httpsCertPath"` @@ -418,32 +418,21 @@ func (s *Server) SetHTTPSPort(port ...int) { } // SetListener set the custom listener for the server. -// It will overwrite the address you specified before. -func (s *Server) SetListener(l net.Listener) error { - if l == nil { - return gerror.NewCodef(gcode.CodeInvalidParameter, "listener is nil") +func (s *Server) SetListener(listeners ...net.Listener) error { + if listeners == nil { + return gerror.NewCodef(gcode.CodeInvalidParameter, "listener can not be nil") } - port := (l.Addr().(*net.TCPAddr)).Port - s.config.Address = fmt.Sprintf(":%d", port) - s.config.Listeners = map[int]net.Listener{port: l} - return nil -} - -// SetListeners set the custom listeners for the server. -// The key of map should specify the port like: SetListeners(map[int]net.Listener{80: ln}). -// If the listener's port not match the port provided in map, the method will return error. -func (s *Server) SetListeners(listeners map[int]net.Listener) error { - for k, v := range listeners { - portIndeed := (v.Addr().(*net.TCPAddr)).Port - if portIndeed != k { - return gerror.NewCodef( - gcode.CodeInvalidParameter, - "listener specified by port %d listen at port %d indeed", - k, portIndeed, - ) + if len(listeners) > 0 { + ports := make([]string, len(listeners)) + for k, v := range listeners { + if v == nil { + return gerror.NewCodef(gcode.CodeInvalidParameter, "listener can not be nil") + } + ports[k] = fmt.Sprintf(":%d", (v.Addr().(*net.TCPAddr)).Port) } + s.config.Address = strings.Join(ports, ",") + s.config.Listeners = listeners } - s.config.Listeners = listeners return nil } diff --git a/net/ghttp/ghttp_server_graceful.go b/net/ghttp/ghttp_server_graceful.go index a62acd7e1..492ef6324 100644 --- a/net/ghttp/ghttp_server_graceful.go +++ b/net/ghttp/ghttp_server_graceful.go @@ -52,9 +52,14 @@ func (s *Server) newGracefulServer(address string, fd ...int) *gracefulServer { } if s.config.Listeners != nil { addrArray := gstr.SplitAndTrim(address, ":") - port, err := strconv.Atoi(addrArray[len(addrArray)-1]) + addrPort, err := strconv.Atoi(addrArray[len(addrArray)-1]) if err == nil { - gs.rawListener = s.config.Listeners[port] + for _, v := range s.config.Listeners { + listenerPort := (v.Addr().(*net.TCPAddr)).Port + if listenerPort == addrPort { + gs.rawListener = v + } + } } } return gs diff --git a/net/ghttp/ghttp_z_unit_feature_config_test.go b/net/ghttp/ghttp_z_unit_feature_config_test.go index f24db1506..a67cd0a16 100644 --- a/net/ghttp/ghttp_z_unit_feature_config_test.go +++ b/net/ghttp/ghttp_z_unit_feature_config_test.go @@ -8,6 +8,7 @@ package ghttp_test import ( "fmt" + "github.com/gogf/gf/v2/net/gtcp" "net" "testing" "time" @@ -24,12 +25,14 @@ import ( func Test_ConfigFromMap(t *testing.T) { gtest.C(t, func(t *gtest.T) { - ln, err := net.Listen("tcp", ":8199") + p, _ := gtcp.GetFreePort() + addr := fmt.Sprintf(":%d", p) + ln, err := net.Listen("tcp", addr) t.AssertNil(err) - listeners := map[int]net.Listener{8199: ln} + listeners := []net.Listener{ln} m := g.Map{ - "address": ":8199", + "address": addr, "listeners": listeners, "readTimeout": "60s", "indexFiles": g.Slice{"index.php", "main.php"}, diff --git a/net/ghttp/ghttp_z_unit_feature_custom_listeners_test.go b/net/ghttp/ghttp_z_unit_feature_custom_listeners_test.go index 8846db62d..c21c16469 100644 --- a/net/ghttp/ghttp_z_unit_feature_custom_listeners_test.go +++ b/net/ghttp/ghttp_z_unit_feature_custom_listeners_test.go @@ -7,6 +7,8 @@ package ghttp_test import ( + "fmt" + "github.com/gogf/gf/v2/net/gtcp" "github.com/gogf/gf/v2/test/gtest" "net" "testing" @@ -16,15 +18,19 @@ import ( "github.com/gogf/gf/v2/net/ghttp" ) -func Test_SetCustomListener(t *testing.T) { +func Test_SetSingleCustomListener(t *testing.T) { gtest.C(t, func(t *gtest.T) { - s := g.Server() + p, _ := gtcp.GetFreePort() + addr := fmt.Sprintf(":%d", p) + s := g.Server(g.Map{ + "address": addr, + }) s.Group("/", func(group *ghttp.RouterGroup) { group.GET("/test", func(r *ghttp.Request) { r.Response.Write("test") }) }) - ln, err := net.Listen("tcp", ":8199") + ln, err := net.Listen("tcp", addr) t.AssertNil(err) err = s.SetListener(ln) t.AssertNil(err) @@ -33,13 +39,10 @@ func Test_SetCustomListener(t *testing.T) { defer s.Shutdown() time.Sleep(100 * time.Millisecond) - s.GetListenedPort() - - t.AssertEQ(s.GetListenedPort(), 8199) }) } -func Test_SetRightCustomListeners(t *testing.T) { +func Test_SetMultipleCustomListeners(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := g.Server() s.Group("/", func(group *ghttp.RouterGroup) { @@ -47,21 +50,22 @@ func Test_SetRightCustomListeners(t *testing.T) { r.Response.Write("test") }) }) - s.SetAddr(":8199") - ln, err := net.Listen("tcp", ":8199") - t.AssertNil(err) - err = s.SetListeners(map[int]net.Listener{8299: ln}) - t.AssertNE(err, nil) - err = s.SetListeners(map[int]net.Listener{8199: ln}) - t.AssertNil(err) + p1, _ := gtcp.GetFreePort() + p2, _ := gtcp.GetFreePort() + + ln1, err := net.Listen("tcp", fmt.Sprintf(":%d", p1)) + ln2, err := net.Listen("tcp", fmt.Sprintf(":%d", p2)) + err = s.SetListener(ln1, ln2) + t.AssertEQ(err, nil) s.Start() defer s.Shutdown() time.Sleep(100 * time.Millisecond) - s.GetListenedPort() - - t.AssertEQ(s.GetListenedPort(), 8199) + ports := []int{p1, p2} + for _, p := range s.GetListenedPorts() { + t.AssertIN(p, ports) + } }) } @@ -73,17 +77,7 @@ func Test_SetWrongCustomListeners(t *testing.T) { r.Response.Write("test") }) }) - s.SetAddr(":8199") - ln, err := net.Listen("tcp", ":8299") - t.AssertNil(err) - err = s.SetListeners(map[int]net.Listener{8199: ln}) + err := s.SetListener(nil) t.AssertNQ(err, nil) - s.Start() - defer s.Shutdown() - - time.Sleep(100 * time.Millisecond) - s.GetListenedPort() - - t.AssertEQ(s.GetListenedPort(), 8199) }) }