diff --git a/net/ghttp/ghttp_server_config.go b/net/ghttp/ghttp_server_config.go index d03dbe999..85ed565a2 100644 --- a/net/ghttp/ghttp_server_config.go +++ b/net/ghttp/ghttp_server_config.go @@ -9,6 +9,9 @@ package ghttp import ( "context" "crypto/tls" + "fmt" + "github.com/gogf/gf/v2/errors/gcode" + "github.com/gogf/gf/v2/errors/gerror" "net" "net/http" "strconv" @@ -414,25 +417,39 @@ func (s *Server) SetHTTPSPort(port ...int) { } } -// SetListeners set the custom listener for the server. +// SetListener set the custom listener for the server. +// It will overwrite the address you specified before. +func (s *Server) SetListener(l net.Listener) error { + addrArray := gstr.SplitAndTrim(l.Addr().String(), ":") + port, err := strconv.Atoi(addrArray[len(addrArray)-1]) + if err != nil { + return err + } + s.config.Address = fmt.Sprintf(":%d", port) + s.config.Listeners = map[int]net.Listener{port: l} + return nil +} + +// SetListeners set the custom listeners for the server. // The key of map should specify the port like: SetListeners(map[int]net.Listener{80: ln}). -// If no port found, the listener will be ignored. -// If the listener's port not match the port provided in map, the listener will be ignored. -func (s *Server) SetListeners(listeners map[int]net.Listener) { +// If the listener's port not match the port provided in map, the method will return error. +func (s *Server) SetListeners(listeners map[int]net.Listener) error { for k, v := range listeners { addrArray := gstr.SplitAndTrim(v.Addr().String(), ":") port, err := strconv.Atoi(addrArray[len(addrArray)-1]) if err != nil { - intlog.Printf(context.TODO(), `ignore the listener with port %d`, k) - delete(listeners, k) - continue + return err } if port != k { - intlog.Printf(context.TODO(), `ignore the listener with port %d`, k) - delete(listeners, k) + return gerror.NewCodef( + gcode.CodeInvalidParameter, + "listener specified by port %d should listen at port %d, but got port: %d", + k, k, port, + ) } } s.config.Listeners = listeners + return nil } // EnableHTTPS enables HTTPS with given certification and key files for the server. diff --git a/net/ghttp/ghttp_z_unit_feature_custom_listeners_test.go b/net/ghttp/ghttp_z_unit_feature_custom_listeners_test.go index 0e6f58824..3efa2c193 100644 --- a/net/ghttp/ghttp_z_unit_feature_custom_listeners_test.go +++ b/net/ghttp/ghttp_z_unit_feature_custom_listeners_test.go @@ -16,6 +16,29 @@ import ( "github.com/gogf/gf/v2/net/ghttp" ) +func Test_SetCustomListener(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + s := g.Server() + s.Group("/", func(group *ghttp.RouterGroup) { + group.GET("/test", func(r *ghttp.Request) { + r.Response.Write("test") + }) + }) + ln, err := net.Listen("tcp", ":8199") + t.AssertNil(err) + err = s.SetListener(ln) + t.AssertNil(err) + + s.Start() + defer s.Shutdown() + + time.Sleep(100 * time.Millisecond) + s.GetListenedPort() + + t.AssertEQ(s.GetListenedPort(), 8199) + }) +} + func Test_SetRightCustomListeners(t *testing.T) { gtest.C(t, func(t *gtest.T) { s := g.Server() @@ -27,7 +50,10 @@ func Test_SetRightCustomListeners(t *testing.T) { s.SetAddr(":8199") ln, err := net.Listen("tcp", ":8199") t.AssertNil(err) - s.SetListeners(map[int]net.Listener{8199: ln}) + err = s.SetListeners(map[int]net.Listener{8299: ln}) + t.AssertNE(err, nil) + err = s.SetListeners(map[int]net.Listener{8199: ln}) + t.AssertNil(err) s.Start() defer s.Shutdown()