From cd00ac446b0d35bb72f273f231f892e8ba6582d6 Mon Sep 17 00:00:00 2001 From: John Date: Tue, 3 Sep 2019 17:18:16 +0800 Subject: [PATCH] improve CORS feature for ghttp.Server --- .example/other/test2.go | 19 ++++++++--- database/gdb/gdb_model.go | 5 ++- encoding/gcompress/gcompress_file.go | 14 +++++--- net/ghttp/ghttp_request.go | 45 +++++++++++++------------- net/ghttp/ghttp_request_middleware.go | 2 +- net/ghttp/ghttp_response.go | 19 ++--------- net/ghttp/ghttp_response_cors.go | 12 ++++++- net/ghttp/ghttp_server_handler.go | 4 +-- net/ghttp/ghttp_server_router_serve.go | 27 ++++++++-------- os/glog/glog_logger.go | 2 +- 10 files changed, 83 insertions(+), 66 deletions(-) diff --git a/.example/other/test2.go b/.example/other/test2.go index 798ec99a1..ceb8ff72b 100644 --- a/.example/other/test2.go +++ b/.example/other/test2.go @@ -1,11 +1,20 @@ package main -import "github.com/gogf/gf/os/glog" +import ( + "fmt" -func Test() { - -} + "github.com/gogf/gf/encoding/gcompress" + "github.com/gogf/gf/os/gfile" +) func main() { - glog.Line().Println("123") + fmt.Println(gfile.Basename("/dir/*")) + return + err := gcompress.ZipPath( + "/Users/john/Workspace/Go/GOPATH/src/github.com/gogf/gf/.example/other", + "/Users/john/Temp/test.zip", + ) + if err != nil { + panic(err) + } } diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index ef6a13d3e..30ce8a2e7 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "reflect" + "strings" "github.com/gogf/gf/util/gconv" ) @@ -216,7 +217,9 @@ func (md *Model) GroupBy(groupBy string) *Model { // 链式操作,order by func (md *Model) OrderBy(orderBy string) *Model { model := md.getModel() - model.orderBy = orderBy + array := strings.Split(orderBy, " ") + array[0] = md.db.quoteWord(array[0]) + model.orderBy = strings.Join(array, " ") return model } diff --git a/encoding/gcompress/gcompress_file.go b/encoding/gcompress/gcompress_file.go index 1a3f6e854..0beaa5907 100644 --- a/encoding/gcompress/gcompress_file.go +++ b/encoding/gcompress/gcompress_file.go @@ -9,12 +9,13 @@ package gcompress import ( "archive/zip" "bytes" - "github.com/gogf/gf/os/gfile" - "github.com/gogf/gf/text/gstr" "io" "os" "path/filepath" "strings" + + "github.com/gogf/gf/os/gfile" + "github.com/gogf/gf/text/gstr" ) // ZipPath compresses to using zip compressing algorithm. @@ -31,7 +32,7 @@ func ZipPath(path, dest string, prefix ...string) error { // ZipPathWriter compresses to using zip compressing algorithm. // The unnecessary parameter indicates the path prefix for zip file. func ZipPathWriter(path string, writer io.Writer, prefix ...string) error { - pathRealPath, err := gfile.Search(path) + realPath, err := gfile.Search(path) if err != nil { return err } @@ -45,8 +46,13 @@ func ZipPathWriter(path string, writer io.Writer, prefix ...string) error { if len(prefix) > 0 { headerPrefix = prefix[0] } + headerPrefix = strings.Trim(headerPrefix, "\\/") + // If path is a directory, add it to the zip prefix. + if gfile.IsDir(realPath) { + headerPrefix = headerPrefix + "/" + gfile.Basename(realPath) + } for _, file := range files { - err := zipFile(file, headerPrefix+gfile.Dir(file[len(pathRealPath):]), zipWriter) + err := zipFile(file, headerPrefix+gfile.Dir(file[len(realPath):]), zipWriter) if err != nil { return err } diff --git a/net/ghttp/ghttp_request.go b/net/ghttp/ghttp_request.go index d673df9af..1158a9934 100644 --- a/net/ghttp/ghttp_request.go +++ b/net/ghttp/ghttp_request.go @@ -21,28 +21,29 @@ import ( // 请求对象 type Request struct { *http.Request - Id int // 请求ID(当前Server对象唯一) - Server *Server // 请求关联的服务器对象 - Cookie *Cookie // 与当前请求绑定的Cookie对象(并发安全) - Session *Session // 与当前请求绑定的Session对象(并发安全) - Response *Response // 对应请求的返回数据操作对象 - Router *Router // 匹配到的路由对象 - EnterTime int64 // 请求进入时间(微秒) - LeaveTime int64 // 请求完成时间(微秒) - Middleware *Middleware // 中间件功能调用对象 - handlers []*handlerParsedItem // 请求执行服务函数列表(包含中间件、路由函数、钩子函数) - handlerIndex int // 当前执行函数的索引号 - hasHookHandler bool // 是否注册有钩子函数(用于请求时提高钩子函数功能启用判断效率) - parsedGet bool // GET参数是否已经解析 - parsedPost bool // POST参数是否已经解析 - queryVars map[string][]string // GET参数 - routerVars map[string][]string // 路由解析参数 - exit bool // 是否退出当前请求流程执行 - params map[string]interface{} // 开发者自定义参数(请求流程中有效) - parsedHost string // 解析过后不带端口号的服务器域名名称 - clientIp string // 解析过后的客户端IP地址 - rawContent []byte // 客户端提交的原始参数 - isFileRequest bool // 是否为静态文件请求(非服务请求,当静态文件存在时,优先级会被服务请求高,被识别为文件请求) + Id int // 请求ID(当前Server对象唯一) + Server *Server // 请求关联的服务器对象 + Cookie *Cookie // 与当前请求绑定的Cookie对象(并发安全) + Session *Session // 与当前请求绑定的Session对象(并发安全) + Response *Response // 对应请求的返回数据操作对象 + Router *Router // 匹配到的路由对象 + EnterTime int64 // 请求进入时间(微秒) + LeaveTime int64 // 请求完成时间(微秒) + Middleware *Middleware // 中间件功能调用对象 + handlers []*handlerParsedItem // 请求执行服务函数列表(包含中间件、路由函数、钩子函数) + handlerIndex int // 当前执行函数的索引号 + hasHookHandler bool // 是否检索到钩子函数(用于请求时提高钩子函数功能启用判断效率) + hasServeHandler bool // 是否检索到服务函数 + parsedGet bool // GET参数是否已经解析 + parsedPost bool // POST参数是否已经解析 + queryVars map[string][]string // GET参数 + routerVars map[string][]string // 路由解析参数 + exit bool // 是否退出当前请求流程执行 + params map[string]interface{} // 开发者自定义参数(请求流程中有效) + parsedHost string // 解析过后不带端口号的服务器域名名称 + clientIp string // 解析过后的客户端IP地址 + rawContent []byte // 客户端提交的原始参数 + isFileRequest bool // 是否为静态文件请求(非服务请求,当静态文件存在时,优先级会被服务请求高,被识别为文件请求) } // 创建一个Request对象 diff --git a/net/ghttp/ghttp_request_middleware.go b/net/ghttp/ghttp_request_middleware.go index 7cb7febbc..f3778a546 100644 --- a/net/ghttp/ghttp_request_middleware.go +++ b/net/ghttp/ghttp_request_middleware.go @@ -24,7 +24,7 @@ func (m *Middleware) Next() { } item = m.request.handlers[m.request.handlerIndex] m.request.handlerIndex++ - // 通过中间件模式不执行钩子函数 + // 中间件执行时不执行钩子函数,由另外的逻辑进行控制 if item.handler.itemType == gHANDLER_TYPE_HOOK { continue } diff --git a/net/ghttp/ghttp_response.go b/net/ghttp/ghttp_response.go index 6570233e2..f86a8f08e 100644 --- a/net/ghttp/ghttp_response.go +++ b/net/ghttp/ghttp_response.go @@ -9,9 +9,9 @@ package ghttp import ( "bytes" + "encoding/json" "fmt" "net/http" - "strconv" "github.com/gogf/gf/os/gres" @@ -81,7 +81,7 @@ func (r *Response) Writefln(format string, params ...interface{}) { // 返回JSON func (r *Response) WriteJson(content interface{}) error { - if b, err := gparser.VarToJson(content); err != nil { + if b, err := json.Marshal(content); err != nil { return err } else { r.Header().Set("Content-Type", "application/json") @@ -92,7 +92,7 @@ func (r *Response) WriteJson(content interface{}) error { // 返回JSONP func (r *Response) WriteJsonP(content interface{}) error { - if b, err := gparser.VarToJson(content); err != nil { + if b, err := json.Marshal(content); err != nil { return err } else { //r.Header().Set("Content-Type", "application/json") @@ -120,19 +120,6 @@ func (r *Response) WriteXml(content interface{}, rootTag ...string) error { return nil } -// Deprecated, please use CORSDefault instead. -// -// (已废弃,请使用CORSDefault)允许AJAX跨域访问. -func (r *Response) SetAllowCrossDomainRequest(allowOrigin string, allowMethods string, maxAge ...int) { - age := 3628800 - if len(maxAge) > 0 { - age = maxAge[0] - } - r.Header().Set("Access-Control-Allow-Origin", allowOrigin) - r.Header().Set("Access-Control-Allow-Methods", allowMethods) - r.Header().Set("Access-Control-Max-Age", strconv.Itoa(age)) -} - // 返回HTTP Code状态码 func (r *Response) WriteStatus(status int, content ...interface{}) { if r.buffer.Len() == 0 { diff --git a/net/ghttp/ghttp_response_cors.go b/net/ghttp/ghttp_response_cors.go index b2df3f5d0..81bd7ad29 100644 --- a/net/ghttp/ghttp_response_cors.go +++ b/net/ghttp/ghttp_response_cors.go @@ -8,6 +8,7 @@ package ghttp import ( + "github.com/gogf/gf/text/gstr" "github.com/gogf/gf/util/gconv" ) @@ -24,12 +25,21 @@ type CORSOptions struct { // 默认的CORS配置 func (r *Response) DefaultCORSOptions() CORSOptions { - return CORSOptions{ + options := CORSOptions{ AllowOrigin: "*", AllowMethods: HTTP_METHODS, AllowCredentials: "true", + AllowHeaders: "Origin, X-Requested-With, Content-Type, Accept, Key", MaxAge: 3628800, } + if referer := r.request.Referer(); referer != "" { + if p := gstr.PosR(referer, "/", 6); p != -1 { + options.AllowOrigin = referer[:p] + } else { + options.AllowOrigin = referer + } + } + return options } // See https://www.w3.org/TR/cors/ . diff --git a/net/ghttp/ghttp_server_handler.go b/net/ghttp/ghttp_server_handler.go index 160044c7d..b1d53a49f 100644 --- a/net/ghttp/ghttp_server_handler.go +++ b/net/ghttp/ghttp_server_handler.go @@ -109,7 +109,7 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { // 动态服务检索 if serveFile == nil || serveFile.dir { - request.handlers, request.hasHookHandler = s.getHandlersWithCache(request) + request.handlers, request.hasHookHandler, request.hasServeHandler = s.getHandlersWithCache(request) } // 判断最终对该请求提供的服务方式 @@ -126,7 +126,7 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { // 静态服务 s.serveFile(request, serveFile) } else { - if len(request.handlers) > 0 { + if request.hasServeHandler { // 动态服务 request.Middleware.Next() } else { diff --git a/net/ghttp/ghttp_server_router_serve.go b/net/ghttp/ghttp_server_router_serve.go index e624fb8bf..23a43ca59 100644 --- a/net/ghttp/ghttp_server_router_serve.go +++ b/net/ghttp/ghttp_server_router_serve.go @@ -19,28 +19,30 @@ import ( type handlerCacheItem struct { parsedItems []*handlerParsedItem hasHook bool + hasServe bool } // 查询请求处理方法. // 内部带锁机制,可以并发读,但是不能并发写;并且有缓存机制,按照Host、Method、Path进行缓存. -func (s *Server) getHandlersWithCache(r *Request) (parsedItems []*handlerParsedItem, hasHook bool) { - cacheKey := s.serveHandlerKey(r.Method, r.URL.Path, r.GetHost()) - if v := s.serveCache.Get(cacheKey); v == nil { - parsedItems, hasHook = s.searchHandlers(r.Method, r.URL.Path, r.GetHost()) +func (s *Server) getHandlersWithCache(r *Request) (parsedItems []*handlerParsedItem, hasHook, hasServe bool) { + value := s.serveCache.GetOrSetFunc(s.serveHandlerKey(r.Method, r.URL.Path, r.GetHost()), func() interface{} { + parsedItems, hasHook, hasServe = s.searchHandlers(r.Method, r.URL.Path, r.GetHost()) if parsedItems != nil { - s.serveCache.Set(cacheKey, &handlerCacheItem{parsedItems, hasHook}, s.config.RouterCacheExpire*1000) + return &handlerCacheItem{parsedItems, hasHook, hasServe} } - } else { - item := v.(*handlerCacheItem) - return item.parsedItems, item.hasHook + return nil + }, s.config.RouterCacheExpire*1000) + if value != nil { + item := value.(*handlerCacheItem) + return item.parsedItems, item.hasHook, item.hasServe } return } // 路由注册方法检索,返回所有该路由的注册函数,构造成数组返回 -func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*handlerParsedItem, hasHook bool) { +func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*handlerParsedItem, hasHook, hasServe bool) { if len(path) == 0 { - return nil, false + return nil, false, false } // 遍历检索的域名列表,优先遍历默认域名 domains := []string{gDEFAULT_DOMAIN} @@ -56,7 +58,6 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han } parsedItemList := glist.New() lastMiddlewareItem := (*glist.Element)(nil) - isServeHandlerAdded := false for _, domain := range domains { p, ok := s.serveTree[domain] if !ok { @@ -98,7 +99,7 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han for e := lists[i].Front(); e != nil; e = e.Next() { item := e.Value.(*handlerItem) // 服务路由函数只能添加一次 - if isServeHandlerAdded { + if hasServe { switch item.itemType { case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER: continue @@ -126,7 +127,7 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han switch item.itemType { // 服务路由函数只能添加一次 case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER: - isServeHandlerAdded = true + hasServe = true parsedItemList.PushBack(parsedItem) // 中间件需要排序 diff --git a/os/glog/glog_logger.go b/os/glog/glog_logger.go index 393538c12..e62611934 100644 --- a/os/glog/glog_logger.go +++ b/os/glog/glog_logger.go @@ -185,7 +185,7 @@ func (l *Logger) SetPath(path string) error { } if !gfile.Exists(path) { if err := gfile.Mkdir(path); err != nil { - fmt.Fprintln(os.Stderr, fmt.Sprintf(`[glog] mkdir "%s" failed: %s`, path, err.Error())) + //fmt.Fprintln(os.Stderr, fmt.Sprintf(`[glog] mkdir "%s" failed: %s`, path, err.Error())) return err } }