diff --git a/g/net/ghttp/ghttp_client_request_client.go b/g/net/ghttp/ghttp_client_request_client.go index 35ff9c85d..567bdc869 100644 --- a/g/net/ghttp/ghttp_client_request_client.go +++ b/g/net/ghttp/ghttp_client_request_client.go @@ -24,11 +24,13 @@ import ( // http客户端 type Client struct { - http.Client // 底层http client对象 - header map[string]string // HEADER信息Map - prefix string // 设置请求的URL前缀 - authUser string // HTTP基本权限设置:名称 - authPass string // HTTP基本权限设置:密码 + http.Client // 底层http client对象 + header map[string]string // HEADER信息Map + cookies map[string]string // 自定义COOKIE + prefix string // 设置请求的URL前缀 + authUser string // HTTP基本权限设置:名称 + authPass string // HTTP基本权限设置:密码 + browserMode bool // 是否模拟浏览器模式(自动保存提交COOKIE) } // http客户端对象指针 @@ -39,10 +41,16 @@ func NewClient() (*Client) { DisableKeepAlives: true, }, }, - header : make(map[string]string), + header : make(map[string]string), + cookies : make(map[string]string), } } +// 是否模拟浏览器模式(自动保存提交COOKIE) +func (c *Client) SetBrowserMode(enabled bool) { + c.browserMode = enabled +} + // 设置HTTP Header func (c *Client) SetHeader(key, value string) { c.header[key] = value @@ -58,6 +66,18 @@ func (c *Client) SetHeaderRaw(header string) { } } +// 设置COOKIE +func (c *Client) SetCookie(key, value string) { + c.cookies[key] = value +} + +// 使用Map设置COOKIE +func (c *Client) SetCookieMap(cookieMap map[string]string) { + for k, v := range cookieMap { + c.cookies[k] = v + } +} + // 设置请求的URL前缀 func (c *Client) SetPrefix(prefix string) { c.prefix = prefix @@ -143,6 +163,19 @@ func (c *Client) Post(url string, data...string) (*ClientResponse, error) { req.Header.Set(k, v) } } + // COOKIE + if len(c.cookies) > 0 { + headerCookie := "" + for k, v := range c.cookies { + if len(headerCookie) > 0 { + headerCookie += ";" + } + headerCookie += k + "=" + v + } + if len(headerCookie) > 0 { + req.Header.Set("Cookie", headerCookie) + } + } // HTTP账号密码 if len(c.authUser) > 0 { req.SetBasicAuth(c.authUser, c.authPass) @@ -152,8 +185,10 @@ func (c *Client) Post(url string, data...string) (*ClientResponse, error) { if err != nil { return nil, err } - r := &ClientResponse{} - r.Response = *resp + r := &ClientResponse{ + cookies : make(map[string]string), + } + r.Response = resp return r, nil } @@ -254,13 +289,40 @@ func (c *Client) DoRequest(method, url string, data...string) (*ClientResponse, req.Header.Set(k, v) } } + // COOKIE + if len(c.cookies) > 0 { + headerCookie := "" + for k, v := range c.cookies { + if len(headerCookie) > 0 { + headerCookie += ";" + } + headerCookie += k + "=" + v + } + if len(headerCookie) > 0 { + req.Header.Set("Cookie", headerCookie) + } + } // 执行请求 resp, err := c.Do(req) if err != nil { return nil, err } - r := &ClientResponse{} - r.Response = *resp + r := &ClientResponse{ + cookies : make(map[string]string), + } + r.Response = resp + // 浏览器模式 + if c.browserMode { + now := time.Now() + for _, v := range r.Cookies() { + if v.Expires.UnixNano() < now.UnixNano() { + delete(c.cookies, v.Name) + } else { + c.cookies[v.Name] = v.Value + } + } + } + //fmt.Println(url, c.cookies) return r, nil } diff --git a/g/net/ghttp/ghttp_client_response.go b/g/net/ghttp/ghttp_client_response.go index 21db3029f..c9dd587d4 100644 --- a/g/net/ghttp/ghttp_client_response.go +++ b/g/net/ghttp/ghttp_client_response.go @@ -10,11 +10,27 @@ package ghttp import ( "io/ioutil" "net/http" + "time" ) // 客户端请求结果对象 type ClientResponse struct { - http.Response + *http.Response + cookies map[string]string +} + +// 获得返回的指定COOKIE值 +func (r *ClientResponse) GetCookie(key string) string { + if r.cookies == nil { + now := time.Now() + for _, v := range r.Cookies() { + if v.Expires.UnixNano() < now.UnixNano() { + continue + } + r.cookies[v.Name] = v.Value + } + } + return r.cookies[key] } // 获取返回的数据(二进制). diff --git a/g/net/ghttp/ghttp_server_cookie.go b/g/net/ghttp/ghttp_server_cookie.go index ae608a25d..e01ba78f3 100644 --- a/g/net/ghttp/ghttp_server_cookie.go +++ b/g/net/ghttp/ghttp_server_cookie.go @@ -42,6 +42,7 @@ func GetCookie(r *Request) *Cookie { } return &Cookie { request : r, + server : r.Server, } } @@ -52,7 +53,6 @@ func (c *Cookie) init() { c.path = c.request.Server.GetCookiePath() c.domain = c.request.Server.GetCookieDomain() c.maxage = c.request.Server.GetCookieMaxAge() - c.server = c.request.Server c.response = c.request.Response // 如果没有设置COOKIE有效域名,那么设置HOST为默认有效域名 if c.domain == "" { @@ -138,9 +138,14 @@ func (c *Cookie) Get(key string) string { return "" } +// 删除COOKIE,使用默认的domain&path +func (c *Cookie) Remove(key string) { + c.SetCookie(key, "", c.domain, c.path, -86400) +} + // 标记该cookie在对应的域名和路径失效 // 删除cookie的重点是需要通知浏览器客户端cookie已过期 -func (c *Cookie) Remove(key, domain, path string) { +func (c *Cookie) RemoveCookie(key, domain, path string) { c.SetCookie(key, "", domain, path, -86400) } diff --git a/g/net/ghttp/ghttp_server_session.go b/g/net/ghttp/ghttp_server_session.go index bbd386e8e..f2bf95351 100644 --- a/g/net/ghttp/ghttp_server_session.go +++ b/g/net/ghttp/ghttp_server_session.go @@ -38,7 +38,6 @@ func GetSession(r *Request) *Session { } return &Session { request : r, - server : r.Server, } } @@ -46,6 +45,7 @@ func GetSession(r *Request) *Session { func (s *Session) init() { if len(s.id) == 0 { s.id = s.request.Cookie.SessionId() + s.server = s.request.Server s.data = s.server.sessions.GetOrSetFuncLock(s.id, func() interface{} { return gmap.NewStringInterfaceMap() }, s.server.GetSessionMaxAge()).(*gmap.StringInterfaceMap) @@ -87,7 +87,7 @@ func (s *Session) BatchSet(m map[string]interface{}) { // 判断键名是否存在 func (s *Session) Contains (key string) bool { - if len(s.id) > 0 || s.request.Cookie.Contains(s.server.GetSessionIdName()) { + if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" { s.init() return s.data.Contains(key) } @@ -96,7 +96,7 @@ func (s *Session) Contains (key string) bool { // 获取SESSION func (s *Session) Get (key string) interface{} { - if len(s.id) > 0 || s.request.Cookie.Contains(s.server.GetSessionIdName()) { + if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" { s.init() return s.data.Get(key) } @@ -110,7 +110,7 @@ func (s *Session) GetVar(key string) gvar.VarRead { // 删除session func (s *Session) Remove(key string) { - if len(s.id) > 0 || s.request.Cookie.Contains(s.server.GetSessionIdName()) { + if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" { s.init() s.data.Remove(key) } @@ -118,7 +118,7 @@ func (s *Session) Remove(key string) { // 清空session func (s *Session) Clear() { - if len(s.id) > 0 || s.request.Cookie.Contains(s.server.GetSessionIdName()) { + if len(s.id) > 0 || s.request.Cookie.GetSessionId() != "" { s.init() s.data.Clear() } @@ -126,7 +126,7 @@ func (s *Session) Clear() { // 更新过期时间(如果用在守护进程中长期使用,需要手动调用进行更新,防止超时被清除) func (s *Session) UpdateExpire() { - if len(s.id) > 0 { + if len(s.id) > 0 && s.data.Size() > 0 { s.server.sessions.Set(s.id, s.data, s.server.GetSessionMaxAge()*1000) } } diff --git a/g/net/ghttp/ghttp_unit_cookie_test.go b/g/net/ghttp/ghttp_unit_cookie_test.go new file mode 100644 index 000000000..8dc323cdb --- /dev/null +++ b/g/net/ghttp/ghttp_unit_cookie_test.go @@ -0,0 +1,62 @@ +// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +// COOKIE测试 +package ghttp_test + +import ( + "github.com/gogf/gf/g" + "github.com/gogf/gf/g/net/ghttp" + "github.com/gogf/gf/g/os/gtime" + "github.com/gogf/gf/g/test/gtest" + "testing" + "time" +) + +func Test_Cookie(t *testing.T) { + s := g.Server(gtime.Nanosecond()) + s.BindHandler("/set", func(r *ghttp.Request){ + r.Cookie.Set(r.Get("k"), r.Get("v")) + }) + s.BindHandler("/get", func(r *ghttp.Request){ + //fmt.Println(r.Cookie.Map()) + r.Response.Write(r.Cookie.Get(r.Get("k"))) + }) + s.BindHandler("/remove", func(r *ghttp.Request){ + r.Cookie.Remove(r.Get("k")) + }) + s.SetPort(8500) + s.SetDumpRouteMap(false) + go s.Run() + defer func() { + s.Shutdown() + time.Sleep(time.Second) + }() + // 等待启动完成 + time.Sleep(time.Second) + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetBrowserMode(true) + client.SetPrefix("http://127.0.0.1:8500") + r1, e1 := client.Get("/set?k=key1&v=100") + if r1 != nil { + defer r1.Close() + } + gtest.Assert(e1, nil) + gtest.Assert(r1.ReadAllString(), "") + + gtest.Assert(client.GetContent("/set?k=key2&v=200"), "") + + gtest.Assert(client.GetContent("/get?k=key1"), "100") + gtest.Assert(client.GetContent("/get?k=key2"), "200") + gtest.Assert(client.GetContent("/get?k=key3"), "") + gtest.Assert(client.GetContent("/remove?k=key1"), "") + gtest.Assert(client.GetContent("/remove?k=key3"), "") + gtest.Assert(client.GetContent("/remove?k=key4"), "") + gtest.Assert(client.GetContent("/get?k=key1"), "") + gtest.Assert(client.GetContent("/get?k=key2"), "200") + }) +} diff --git a/g/net/ghttp/ghttp_unit_session_test.go b/g/net/ghttp/ghttp_unit_session_test.go new file mode 100644 index 000000000..6ceffd80e --- /dev/null +++ b/g/net/ghttp/ghttp_unit_session_test.go @@ -0,0 +1,66 @@ +// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +// SESSION测试 +package ghttp_test + +import ( + "github.com/gogf/gf/g" + "github.com/gogf/gf/g/net/ghttp" + "github.com/gogf/gf/g/os/gtime" + "github.com/gogf/gf/g/test/gtest" + "testing" + "time" +) + +func Test_Session(t *testing.T) { + s := g.Server(gtime.Nanosecond()) + s.BindHandler("/set", func(r *ghttp.Request){ + r.Session.Set(r.Get("k"), r.Get("v")) + }) + s.BindHandler("/get", func(r *ghttp.Request){ + r.Response.Write(r.Session.Get(r.Get("k"))) + }) + s.BindHandler("/remove", func(r *ghttp.Request){ + r.Session.Remove(r.Get("k")) + }) + s.BindHandler("/clear", func(r *ghttp.Request){ + r.Session.Clear() + }) + s.SetPort(8600) + s.SetDumpRouteMap(false) + go s.Run() + defer func() { + s.Shutdown() + time.Sleep(time.Second) + }() + // 等待启动完成 + time.Sleep(time.Second) + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetBrowserMode(true) + client.SetPrefix("http://127.0.0.1:8600") + r1, e1 := client.Get("/set?k=key1&v=100") + if r1 != nil { + defer r1.Close() + } + gtest.Assert(e1, nil) + gtest.Assert(r1.ReadAllString(), "") + + gtest.Assert(client.GetContent("/set?k=key2&v=200"), "") + + gtest.Assert(client.GetContent("/get?k=key1"), "100") + gtest.Assert(client.GetContent("/get?k=key2"), "200") + gtest.Assert(client.GetContent("/get?k=key3"), "") + gtest.Assert(client.GetContent("/remove?k=key1"), "") + gtest.Assert(client.GetContent("/remove?k=key3"), "") + gtest.Assert(client.GetContent("/remove?k=key4"), "") + gtest.Assert(client.GetContent("/get?k=key1"), "") + gtest.Assert(client.GetContent("/get?k=key2"), "200") + gtest.Assert(client.GetContent("/clear"), "") + gtest.Assert(client.GetContent("/get?k=key2"), "") + }) +}