improve cookie feature for ghttp.Server

This commit is contained in:
Jack
2020-08-03 20:00:00 +08:00
parent 3e3b5557f7
commit 6d68277db8
3 changed files with 97 additions and 56 deletions

View File

@ -9,28 +9,17 @@ package ghttp
import (
"net/http"
"time"
"github.com/gogf/gf/os/gtime"
)
// Cookie for HTTP COOKIE management.
type Cookie struct {
data map[string]CookieItem // Underlying cookie items.
path string // The default cookie path.
domain string // The default cookie domain
maxage time.Duration // The default cookie maxage.
server *Server // Belonged HTTP server
request *Request // Belonged HTTP request.
response *Response // Belonged HTTP response.
}
// CookieItem is cookie item stored in Cookie management object.
type CookieItem struct {
value string // Cookie value.
domain string // Cookie domain.
path string // Cookie path.
expireAt int64 // Cookie expiration timestamp.
httpOnly bool
data map[string]*http.Cookie // Underlying cookie items.
path string // The default cookie path.
domain string // The default cookie domain
maxAge time.Duration // The default cookie max age.
server *Server // Belonged HTTP server
request *Request // Belonged HTTP request.
response *Response // Belonged HTTP response.
}
// GetCookie creates or retrieves a cookie object with given request.
@ -48,21 +37,20 @@ func GetCookie(r *Request) *Cookie {
// init does lazy initialization for cookie object.
func (c *Cookie) init() {
if c.data == nil {
c.data = make(map[string]CookieItem)
c.path = c.request.Server.GetCookiePath()
c.domain = c.request.Server.GetCookieDomain()
c.maxage = c.request.Server.GetCookieMaxAge()
c.response = c.request.Response
// DO NOT ADD ANY DEFAULT COOKIE DOMAIN!
//if c.domain == "" {
// c.domain = c.request.GetHost()
//}
for _, v := range c.request.Cookies() {
c.data[v.Name] = CookieItem{
v.Value, v.Domain, v.Path, int64(v.Expires.Second()), v.HttpOnly,
}
}
if c.data != nil {
return
}
c.data = make(map[string]*http.Cookie)
c.path = c.request.Server.GetCookiePath()
c.domain = c.request.Server.GetCookieDomain()
c.maxAge = c.request.Server.GetCookieMaxAge()
c.response = c.request.Response
// DO NOT ADD ANY DEFAULT COOKIE DOMAIN!
//if c.domain == "" {
// c.domain = c.request.GetHost()
//}
for _, v := range c.request.Cookies() {
c.data[v.Name] = v
}
}
@ -71,7 +59,7 @@ func (c *Cookie) Map() map[string]string {
c.init()
m := make(map[string]string)
for k, v := range c.data {
m[k] = v.value
m[k] = v.Value
}
return m
}
@ -80,7 +68,7 @@ func (c *Cookie) Map() map[string]string {
func (c *Cookie) Contains(key string) bool {
c.init()
if r, ok := c.data[key]; ok {
if r.expireAt >= 0 {
if r.Expires.IsZero() || r.Expires.After(time.Now()) {
return true
}
}
@ -89,7 +77,7 @@ func (c *Cookie) Contains(key string) bool {
// Set sets cookie item with default domain, path and expiration age.
func (c *Cookie) Set(key, value string) {
c.SetCookie(key, value, c.domain, c.path, c.server.GetCookieMaxAge())
c.SetCookie(key, value, c.domain, c.path, c.maxAge)
}
// SetCookie sets cookie item given given domain, path and expiration age.
@ -101,11 +89,25 @@ func (c *Cookie) SetCookie(key, value, domain, path string, maxAge time.Duration
if len(httpOnly) > 0 {
isHttpOnly = httpOnly[0]
}
c.data[key] = CookieItem{
value, domain, path, gtime.Timestamp() + int64(maxAge.Seconds()), isHttpOnly,
c.data[key] = &http.Cookie{
Name: key,
Value: value,
Path: path,
Domain: domain,
Expires: time.Now().Add(maxAge),
HttpOnly: isHttpOnly,
}
}
// SetHttpCookie sets cookie with *http.Cookie.
func (c *Cookie) SetHttpCookie(cookie *http.Cookie) {
c.init()
if cookie.Expires.IsZero() {
cookie.Expires = time.Now().Add(c.maxAge)
}
c.data[cookie.Name] = cookie
}
// GetSessionId retrieves and returns the session id from cookie.
func (c *Cookie) GetSessionId() string {
return c.Get(c.server.GetSessionIdName())
@ -121,8 +123,8 @@ func (c *Cookie) SetSessionId(id string) {
func (c *Cookie) Get(key string, def ...string) string {
c.init()
if r, ok := c.data[key]; ok {
if r.expireAt >= 0 {
return r.value
if r.Expires.IsZero() || r.Expires.After(time.Now()) {
return r.Value
}
}
if len(def) > 0 {
@ -148,22 +150,12 @@ func (c *Cookie) Flush() {
if len(c.data) == 0 {
return
}
for k, v := range c.data {
// Cookie item matches expire != 0 means it is set in this request,
for _, v := range c.data {
// If cookie item is v.Expires.IsZero() means it is set in this request,
// which should be outputted to client.
if v.expireAt == 0 {
if v.Expires.IsZero() {
continue
}
http.SetCookie(
c.response.Writer,
&http.Cookie{
Name: k,
Value: v.value,
Domain: v.domain,
Path: v.path,
Expires: time.Unix(v.expireAt, 0),
HttpOnly: v.httpOnly,
},
)
http.SetCookie(c.response.Writer, v)
}
}

View File

@ -8,6 +8,7 @@ package ghttp_test
import (
"fmt"
"net/http"
"testing"
"time"
@ -33,6 +34,52 @@ func Test_Cookie(t *testing.T) {
s.Start()
defer s.Shutdown()
time.Sleep(100 * time.Millisecond)
gtest.C(t, func(t *gtest.T) {
client := ghttp.NewClient()
client.SetBrowserMode(true)
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
r1, e1 := client.Get("/set?k=key1&v=100")
if r1 != nil {
defer r1.Close()
}
t.Assert(e1, nil)
t.Assert(r1.ReadAllString(), "")
t.Assert(client.GetContent("/set?k=key2&v=200"), "")
t.Assert(client.GetContent("/get?k=key1"), "100")
t.Assert(client.GetContent("/get?k=key2"), "200")
t.Assert(client.GetContent("/get?k=key3"), "")
t.Assert(client.GetContent("/remove?k=key1"), "")
t.Assert(client.GetContent("/remove?k=key3"), "")
t.Assert(client.GetContent("/remove?k=key4"), "")
t.Assert(client.GetContent("/get?k=key1"), "")
t.Assert(client.GetContent("/get?k=key2"), "200")
})
}
func Test_SetHttpCookie(t *testing.T) {
p, _ := ports.PopRand()
s := g.Server(p)
s.BindHandler("/set", func(r *ghttp.Request) {
r.Cookie.SetHttpCookie(&http.Cookie{
Name: r.GetString("k"),
Value: r.GetString("v"),
})
})
s.BindHandler("/get", func(r *ghttp.Request) {
r.Response.Write(r.Cookie.Get(r.GetString("k")))
})
s.BindHandler("/remove", func(r *ghttp.Request) {
r.Cookie.Remove(r.GetString("k"))
})
s.SetPort(p)
s.SetDumpRouterMap(false)
s.Start()
defer s.Shutdown()
time.Sleep(100 * time.Millisecond)
gtest.C(t, func(t *gtest.T) {
client := ghttp.NewClient()

View File

@ -104,8 +104,10 @@ func (l *Logger) print(std io.Writer, lead string, values ...interface{}) {
// It here uses CAP for performance and concurrent safety.
if !l.init.Val() && l.init.Cas(false, true) {
// It just initializes once for each logger.
gtimer.AddOnce(l.config.RotateCheckInterval, l.rotateChecksTimely)
intlog.Printf("logger initialized: every %s", l.config.RotateCheckInterval.String())
if l.config.RotateSize > 0 || l.config.RotateExpire > 0 {
gtimer.AddOnce(l.config.RotateCheckInterval, l.rotateChecksTimely)
intlog.Printf("logger rotation initialized: every %s", l.config.RotateCheckInterval.String())
}
}
var (