From d8a7e364780f039c08402392defd3629777f5292 Mon Sep 17 00:00:00 2001 From: John Date: Wed, 4 Mar 2020 17:29:23 +0800 Subject: [PATCH 01/26] improve router feature for ghttp.Server --- .example/other/test.go | 33 +-- DONATOR.MD | 2 +- container/garray/garray_sorted_any.go | 3 + container/garray/garray_sorted_int.go | 5 +- container/garray/garray_sorted_str.go | 5 +- net/ghttp/ghttp_func.go | 4 +- net/ghttp/ghttp_request_param_file.go | 11 +- net/ghttp/ghttp_response_cors.go | 31 ++- net/ghttp/ghttp_server.go | 26 +- net/ghttp/ghttp_server_plugin.go | 2 +- net/ghttp/ghttp_server_router.go | 90 +++++-- net/ghttp/ghttp_server_router_hook.go | 5 - net/ghttp/ghttp_server_router_serve.go | 59 ++--- .../ghttp_unit_router_controller_rest_test.go | 12 - net/ghttp/ghttp_unit_router_hook_test.go | 25 +- os/gfile/gfile.go | 14 +- text/gregex/gregex_cache.go | 36 ++- util/gpage/gpage.go | 247 +++++++----------- 18 files changed, 309 insertions(+), 301 deletions(-) diff --git a/.example/other/test.go b/.example/other/test.go index 774c42627..92e1cfb76 100644 --- a/.example/other/test.go +++ b/.example/other/test.go @@ -1,37 +1,10 @@ package main import ( - "net/http" - - "github.com/gogf/gf/frame/g" - "github.com/gogf/gf/net/ghttp" + "github.com/gogf/gf/container/garray" ) -func MiddlewareAuth(r *ghttp.Request) { - token := r.Get("token") - if token == "123456" { - r.Response.Writeln("auth") - r.Middleware.Next() - } else { - r.Response.WriteStatus(http.StatusForbidden) - } -} - -func MiddlewareCORS(r *ghttp.Request) { - r.Response.Writeln("cors") - r.Response.CORSDefault() - r.Middleware.Next() -} - func main() { - s := g.Server() - s.Use(MiddlewareCORS) - s.Group("/api.v2", func(group *ghttp.RouterGroup) { - group.Middleware(MiddlewareAuth) - group.ALL("/user/list", func(r *ghttp.Request) { - r.Response.Writeln("list") - }) - }) - s.SetPort(8199) - s.Run() + arr := garray.NewStrArray(false) + arr.Unique() } diff --git a/DONATOR.MD b/DONATOR.MD index 418951da5..d5327b5a9 100644 --- a/DONATOR.MD +++ b/DONATOR.MD @@ -15,7 +15,7 @@ We currently accept donation by Alipay/WechatPay, please note your github/gitee |[zhuhuan12](https://gitee.com/zhuhuan12)|gitee|¥50.00 | |[zfan_codes](https://gitee.com/zfan_codes)|gitee|¥10.00 | |[arden](https://github.com/arden)|alipay|¥10.00 | -|[macnie](https://www.macnie.com)|wechat|¥100.00 | +|[macnie](https://www.macnie.com)|wechat|¥110.00 | |lah|wechat|¥100.00 | |x*z|wechat|¥20.00 | |潘兄|wechat|¥100.00 | diff --git a/container/garray/garray_sorted_any.go b/container/garray/garray_sorted_any.go index dd6c1690b..32439e8f8 100644 --- a/container/garray/garray_sorted_any.go +++ b/container/garray/garray_sorted_any.go @@ -439,6 +439,9 @@ func (a *SortedArray) SetUnique(unique bool) *SortedArray { func (a *SortedArray) Unique() *SortedArray { a.mu.Lock() defer a.mu.Unlock() + if len(a.array) == 0 { + return a + } i := 0 for { if i == len(a.array)-1 { diff --git a/container/garray/garray_sorted_int.go b/container/garray/garray_sorted_int.go index 4d469aa51..fc92567dc 100644 --- a/container/garray/garray_sorted_int.go +++ b/container/garray/garray_sorted_int.go @@ -429,6 +429,10 @@ func (a *SortedIntArray) SetUnique(unique bool) *SortedIntArray { // Unique uniques the array, clear repeated items. func (a *SortedIntArray) Unique() *SortedIntArray { a.mu.Lock() + defer a.mu.Unlock() + if len(a.array) == 0 { + return a + } i := 0 for { if i == len(a.array)-1 { @@ -440,7 +444,6 @@ func (a *SortedIntArray) Unique() *SortedIntArray { i++ } } - a.mu.Unlock() return a } diff --git a/container/garray/garray_sorted_str.go b/container/garray/garray_sorted_str.go index 1b8054c09..c7c9f03aa 100644 --- a/container/garray/garray_sorted_str.go +++ b/container/garray/garray_sorted_str.go @@ -414,6 +414,10 @@ func (a *SortedStrArray) SetUnique(unique bool) *SortedStrArray { // Unique uniques the array, clear repeated items. func (a *SortedStrArray) Unique() *SortedStrArray { a.mu.Lock() + defer a.mu.Unlock() + if len(a.array) == 0 { + return a + } i := 0 for { if i == len(a.array)-1 { @@ -425,7 +429,6 @@ func (a *SortedStrArray) Unique() *SortedStrArray { i++ } } - a.mu.Unlock() return a } diff --git a/net/ghttp/ghttp_func.go b/net/ghttp/ghttp_func.go index 4f32302f4..e2d46a8ef 100644 --- a/net/ghttp/ghttp_func.go +++ b/net/ghttp/ghttp_func.go @@ -50,9 +50,7 @@ func niceCallFunc(f func()) { defer func() { if err := recover(); err != nil { switch err { - case gEXCEPTION_EXIT: - fallthrough - case gEXCEPTION_EXIT_ALL: + case gEXCEPTION_EXIT, gEXCEPTION_EXIT_ALL: return default: panic(err) diff --git a/net/ghttp/ghttp_request_param_file.go b/net/ghttp/ghttp_request_param_file.go index 6e54725b2..e91f5b3e7 100644 --- a/net/ghttp/ghttp_request_param_file.go +++ b/net/ghttp/ghttp_request_param_file.go @@ -8,12 +8,12 @@ package ghttp import ( "errors" + "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/os/gfile" "github.com/gogf/gf/os/gtime" "github.com/gogf/gf/util/grand" "io" "mime/multipart" - "os" "strconv" "strings" ) @@ -45,22 +45,21 @@ func (f *UploadFile) Save(path string, randomlyRename ...bool) error { } defer file.Close() - var newFile *os.File + filePath := path if gfile.IsDir(path) { filename := gfile.Basename(f.Filename) if len(randomlyRename) > 0 && randomlyRename[0] { filename = strings.ToLower(strconv.FormatInt(gtime.TimestampNano(), 36) + grand.S(6)) filename = filename + gfile.Ext(f.Filename) } - newFile, err = gfile.Create(gfile.Join(path, filename)) - } else { - newFile, err = gfile.Create(path) + filePath = gfile.Join(path, filename) } + newFile, err := gfile.Create(filePath) if err != nil { return err } defer newFile.Close() - + intlog.Printf(`save upload file: %s`, filePath) if _, err := io.Copy(newFile, file); err != nil { return err } diff --git a/net/ghttp/ghttp_response_cors.go b/net/ghttp/ghttp_response_cors.go index 1853239a7..835a1c8e2 100644 --- a/net/ghttp/ghttp_response_cors.go +++ b/net/ghttp/ghttp_response_cors.go @@ -27,6 +27,20 @@ type CORSOptions struct { AllowHeaders string // Access-Control-Allow-Headers } +var ( + // defaultAllowHeaders is the default allowed headers for CORS. + // It's defined as map for better header key searching performance. + defaultAllowHeaders = map[string]struct{}{ + "Origin": {}, + "Accept": {}, + "Cookie": {}, + "Authorization": {}, + "X-Auth-Token": {}, + "X-Requested-With": {}, + "Content-Type": {}, + } +) + // DefaultCORSOptions returns the default CORS options, // which allows any cross-domain request. func (r *Response) DefaultCORSOptions() CORSOptions { @@ -34,9 +48,24 @@ func (r *Response) DefaultCORSOptions() CORSOptions { AllowOrigin: "*", AllowMethods: HTTP_METHODS, AllowCredentials: "true", - AllowHeaders: "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With", MaxAge: 3628800, } + // Allow all client's custom headers in default. + if headers := r.Request.Header.Get("Access-Control-Request-Headers"); headers != "" { + array := gstr.SplitAndTrim(headers, ",") + for _, header := range array { + if _, ok := defaultAllowHeaders[header]; !ok { + options.AllowHeaders += header + "," + } + } + for header, _ := range defaultAllowHeaders { + if len(options.AllowHeaders) > 0 { + options.AllowHeaders += "," + } + options.AllowHeaders += header + } + } + // Allow all anywhere origin in default. if origin := r.Request.Header.Get("Origin"); origin != "" { options.AllowOrigin = origin } else if referer := r.Request.Referer(); referer != "" { diff --git a/net/ghttp/ghttp_server.go b/net/ghttp/ghttp_server.go index b6126fac6..d79aa80c7 100644 --- a/net/ghttp/ghttp_server.go +++ b/net/ghttp/ghttp_server.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "github.com/gogf/gf/debug/gdebug" + "github.com/gogf/gf/internal/intlog" "net/http" "os" "reflect" @@ -78,7 +79,6 @@ type ( // handlerItem is the registered handler for route handling, // including middleware and hook functions. handlerItem struct { - itemId int // Unique ID mark. itemName string // Handler name, which is automatically retrieved from runtime stack when registered. itemType int // Handler type: object/handler/controller/middleware/hook. itemFunc HandlerFunc // Handler address. @@ -143,7 +143,7 @@ var ( // it is used for quick HTTP method searching using map. methodsMap = make(map[string]struct{}) - // serverMapping stores more than one server instances. + // serverMapping stores more than one server instances for current process. // The key is the name of the server, and the value is its instance. serverMapping = gmap.NewStrAnyMap(true) @@ -444,14 +444,20 @@ func (s *Server) GetRouterArray() []RouterItem { } // Run starts server listening in blocking way. +// It's commonly used for single server situation. func (s *Server) Run() { if err := s.Start(); err != nil { s.Logger().Fatal(err) } - // Blocking using channel. <-s.closeChan - + // Remove plugins. + if len(s.plugins) > 0 { + for _, p := range s.plugins { + intlog.Printf(`remove plugin: %s`, p.Name()) + p.Remove() + } + } s.Logger().Printf("[ghttp] %d: all servers shutdown", gproc.Pid()) } @@ -459,7 +465,17 @@ func (s *Server) Run() { // It's commonly used in multiple servers situation. func Wait() { <-allDoneChan - + // Remove plugins. + serverMapping.Iterator(func(k string, v interface{}) bool { + s := v.(*Server) + if len(s.plugins) > 0 { + for _, p := range s.plugins { + intlog.Printf(`remove plugin: %s`, p.Name()) + p.Remove() + } + } + return true + }) glog.Printf("[ghttp] %d: all servers shutdown", gproc.Pid()) } diff --git a/net/ghttp/ghttp_server_plugin.go b/net/ghttp/ghttp_server_plugin.go index ff7b0ad69..d8d47aaf3 100644 --- a/net/ghttp/ghttp_server_plugin.go +++ b/net/ghttp/ghttp_server_plugin.go @@ -10,7 +10,7 @@ package ghttp type Plugin interface { Name() string // Name returns the name of the plugin. Author() string // Author returns the author of the plugin. - Version() string // Version returns the version of the plugin. + Version() string // Version returns the version of the plugin, like "v1.0.0". Description() string // Description returns the description of the plugin. Install(s *Server) error // Install installs the plugin before server starts. Remove() error // Remove removes the plugin. diff --git a/net/ghttp/ghttp_server_router.go b/net/ghttp/ghttp_server_router.go index 83bc9ef37..7e7c2f353 100644 --- a/net/ghttp/ghttp_server_router.go +++ b/net/ghttp/ghttp_server_router.go @@ -9,7 +9,7 @@ package ghttp import ( "errors" "fmt" - "github.com/gogf/gf/container/gtype" + "github.com/gogf/gf/util/gutil" "strings" "github.com/gogf/gf/debug/gdebug" @@ -23,12 +23,12 @@ const ( gFILTER_KEY = "/net/ghttp/ghttp" ) -var ( - // 用于服务函数的ID生成变量 - handlerIdGenerator = gtype.NewInt() -) +// handlerKey creates and returns an unique router key for given parameters. +func (s *Server) handlerKey(hook, method, path, domain string) string { + return hook + "%" + s.serveHandlerKey(method, path, domain) +} -// 解析pattern +// parsePattern parses the given pattern to domain, method and path variable. func (s *Server) parsePattern(pattern string) (domain, method, path string, err error) { path = strings.TrimSpace(pattern) domain = gDEFAULT_DOMAIN @@ -48,18 +48,17 @@ func (s *Server) parsePattern(pattern string) (domain, method, path string, err if path == "" { err = errors.New("invalid pattern: URI should not be empty") } - // 去掉末尾的"/"符号,与路由匹配时处理一致 if path != "/" { path = strings.TrimRight(path, "/") } return } -// 路由注册处理方法。 -// 非叶节点为哈希表检索节点,按照URI注册的层级进行高效检索,直至到叶子链表节点; -// 叶子节点是链表,按照优先级进行排序,优先级高的排前面,按照遍历检索,按照哈希表层级检索后的叶子链表数据量不会很大,所以效率比较高; +// setHandler creates router item with given handler and pattern and registers the handler to the router tree. +// The router tree can be treated as a multilayer hash table, please refer to the comment in following codes. +// This function is called during server starts up, which cares little about the performance. What really cares +// is the well designed router storage structure for router searching when the request is under serving. func (s *Server) setHandler(pattern string, handler *handlerItem) { - handler.itemId = handlerIdGenerator.Add(1) domain, method, uri, err := s.parsePattern(pattern) if err != nil { s.Logger().Fatal("invalid pattern:", pattern, err) @@ -69,7 +68,8 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { s.Logger().Fatal("invalid pattern:", pattern, "URI should lead with '/'") return } - // 注册地址记录及重复注册判断 + + // Repeated router checks, this feature can be disabled by server configuration. regKey := s.handlerKey(handler.hookName, method, uri, domain) if !s.config.RouteOverWrite { switch handler.itemType { @@ -80,11 +80,11 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { } } } - // 注册的路由信息对象 + // Create a new router by given parameter. handler.router = &Router{ Uri: uri, Domain: domain, - Method: method, + Method: strings.ToUpper(method), Priority: strings.Count(uri[1:], "/"), } handler.router.RegRule, handler.router.RegNames = s.patternToRegRule(uri) @@ -92,7 +92,8 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { if _, ok := s.serveTree[domain]; !ok { s.serveTree[domain] = make(map[string]interface{}) } - // 当前节点的规则链表 + // List array, very important for router register. + // There may be multiple lists adding into this array when searching from root to leaf. lists := make([]*glist.List, 0) array := ([]string)(nil) if strings.EqualFold("/", uri) { @@ -100,9 +101,57 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { } else { array = strings.Split(uri[1:], "/") } - // 键名"*fuzz"代表当前节点为模糊匹配节点,该节点也会有一个*list链表; - // 键名"*list"代表链表,叶子节点和模糊匹配节点都有该属性,优先级越高越排前; + // Multilayer hash table: + // 1. Each node of the table is separated by URI path which is split by char '/'. + // 2. The key "*fuzz" specifies this node is a fuzzy node, which has no certain name. + // 3. The key "*list" is the list item of the node, MOST OF THE NODES HAVE THIS ITEM, + // especially the fuzzy node. NOTE THAT the fuzzy node must have the "*list" item, + // and the leaf node also has "*list" item. If the node is not a fuzzy node either + // a leaf, it neither has "*list" item. + // 2. The "*list" item is a list containing registered router items ordered by their + // priorities from high to low. + // 3. There may be repeated router items in the router lists. The lists' priorities + // from root to leaf are from low to high. p := s.serveTree[domain] + for i, part := range array { + // Ignore empty URI part, like: /user//index + if part == "" { + continue + } + // Check if it's a fuzzy node. + if gregex.IsMatchString(`^[:\*]|\{[\w\.\-]+\}|\*`, part) { + part = "*fuzz" + // If it's a fuzzy node, it creates a "*list" item - which is a list - in the hash map. + // All the sub router items from this fuzzy node will also be added to its "*list" item. + if v, ok := p.(map[string]interface{})["*list"]; !ok { + newListForFuzzy := glist.New() + p.(map[string]interface{})["*list"] = newListForFuzzy + lists = append(lists, newListForFuzzy) + } else { + lists = append(lists, v.(*glist.List)) + } + } + // Make a new bucket for current node. + if _, ok := p.(map[string]interface{})[part]; !ok { + p.(map[string]interface{})[part] = make(map[string]interface{}) + } + // Loop to next bucket. + p = p.(map[string]interface{})[part] + // The leaf is a hash map and must have an item named "*list", which contains the router item. + // The leaf can be furthermore extended by adding more ket-value pairs into its map. + // Note that the `v != "*fuzz"` comparison is required as the list might be added in the former + // fuzzy checks. + if i == len(array)-1 && part != "*fuzz" { + if v, ok := p.(map[string]interface{})["*list"]; !ok { + list := glist.New() + p.(map[string]interface{})["*list"] = list + lists = append(lists, list) + } else { + lists = append(lists, v.(*glist.List)) + } + } + } + for k, v := range array { if len(v) == 0 { continue @@ -135,8 +184,8 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { } } - // 上面循环后得到的lists是该路由规则一路匹配下来相关的模糊匹配链表(注意不是这棵树所有的链表)。 - // 下面从头开始遍历每个节点的模糊匹配链表,将该路由项插入进去(按照优先级高的放在lists链表的前面) + // It iterates the list array of , compares priorities and inserts the new router item in + // the proper position of each list. The priority of the list is ordered from high to low. item := (*handlerItem)(nil) for _, l := range lists { pushed := false @@ -173,6 +222,7 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { // Append the route. s.routesMap[regKey] = append(s.routesMap[regKey], routeItem) } + gutil.Dump(s.serveTree) } // 对比两个handlerItem的优先级,需要非常注意的是,注意新老对比项的参数先后顺序。 @@ -312,7 +362,7 @@ func (s *Server) patternToRegRule(rule string) (regrule string, names []string) regrule += `/{0,1}.*` } default: - // 特殊字符替换 + // Special chars replacement. v = gstr.ReplaceByMap(v, map[string]string{ `.`: `\.`, `+`: `\+`, diff --git a/net/ghttp/ghttp_server_router_hook.go b/net/ghttp/ghttp_server_router_hook.go index fcce0aa66..63d170dba 100644 --- a/net/ghttp/ghttp_server_router_hook.go +++ b/net/ghttp/ghttp_server_router_hook.go @@ -66,8 +66,3 @@ func (s *Server) niceCallHookHandler(f HandlerFunc, r *Request) (err interface{} f(r) return } - -// 生成hook key,如果是hook key,那么使用'%'符号分隔 -func (s *Server) handlerKey(hook, method, path, domain string) string { - return hook + "%" + s.serveHandlerKey(method, path, domain) -} diff --git a/net/ghttp/ghttp_server_router_serve.go b/net/ghttp/ghttp_server_router_serve.go index a3b728ff9..7306a5614 100644 --- a/net/ghttp/ghttp_server_router_serve.go +++ b/net/ghttp/ghttp_server_router_serve.go @@ -15,15 +15,14 @@ import ( "github.com/gogf/gf/text/gregex" ) -// 缓存数据项 +// handlerCacheItem is a item for router cache. type handlerCacheItem struct { parsedItems []*handlerParsedItem hasHook bool hasServe bool } -// 查询请求处理方法. -// 内部带锁机制,可以并发读,但是不能并发写;并且有缓存机制,按照Host、Method、Path进行缓存. +// getHandlersWithCache searches the router item with cache feature for given request. func (s *Server) getHandlersWithCache(r *Request) (parsedItems []*handlerParsedItem, hasHook, hasServe bool) { value := s.serveCache.GetOrSetFunc(s.serveHandlerKey(r.Method, r.URL.Path, r.GetHost()), func() interface{} { parsedItems, hasHook, hasServe = s.searchHandlers(r.Method, r.URL.Path, r.GetHost()) @@ -39,18 +38,19 @@ func (s *Server) getHandlersWithCache(r *Request) (parsedItems []*handlerParsedI return } -// 路由注册方法检索,返回所有该路由的注册函数,构造成数组返回 +// searchHandlers retrieves and returns the routers with given parameters. +// Note that the returned routers contain serving handler, middleware handlers and hook handlers. func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*handlerParsedItem, hasHook, hasServe bool) { if len(path) == 0 { return nil, false, false } - // 遍历检索的域名列表,优先遍历默认域名 + // Default domain has the most priority when iteration. domains := []string{gDEFAULT_DOMAIN} if !strings.EqualFold(gDEFAULT_DOMAIN, domain) { domains = append(domains, domain) } - // URL.Path层级拆分 - array := ([]string)(nil) + // Split the URL.path to separate parts. + var array []string if strings.EqualFold("/", path) { array = []string{"/"} } else { @@ -58,42 +58,40 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han } parsedItemList := glist.New() lastMiddlewareElem := (*glist.Element)(nil) - repeatHandlerCheckMap := make(map[int]struct{}) for _, domain := range domains { p, ok := s.serveTree[domain] if !ok { continue } - // 多层链表(每个节点都有一个*list链表)的目的是当叶子节点未有任何规则匹配时,让父级模糊匹配规则继续处理 lists := make([]*glist.List, 0, 16) - for k, v := range array { + for i, part := range array { // In case of double '/' URI, eg: /user//index - if v == "" { + if part == "" { continue } - if _, ok := p.(map[string]interface{})["*list"]; ok { - lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List)) + if v, ok := p.(map[string]interface{})["*list"]; ok { + lists = append(lists, v.(*glist.List)) } - if _, ok := p.(map[string]interface{})[v]; ok { - p = p.(map[string]interface{})[v] - if k == len(array)-1 { - if _, ok := p.(map[string]interface{})["*list"]; ok { - lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List)) + if _, ok := p.(map[string]interface{})[part]; ok { + p = p.(map[string]interface{})[part] + if i == len(array)-1 { + if v, ok := p.(map[string]interface{})["*list"]; ok { + lists = append(lists, v.(*glist.List)) break } } } else { - if _, ok := p.(map[string]interface{})["*fuzz"]; ok { - p = p.(map[string]interface{})["*fuzz"] + if v, ok := p.(map[string]interface{})["*fuzz"]; ok { + p = v } } // 如果是叶子节点,同时判断当前层级的"*fuzz"键名,解决例如:/user/*action 匹配 /user 的规则 - if k == len(array)-1 { - if _, ok := p.(map[string]interface{})["*fuzz"]; ok { - p = p.(map[string]interface{})["*fuzz"] + if i == len(array)-1 { + if v, ok := p.(map[string]interface{})["*fuzz"]; ok { + p = v } - if _, ok := p.(map[string]interface{})["*list"]; ok { - lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List)) + if v, ok := p.(map[string]interface{})["*list"]; ok { + lists = append(lists, v.(*glist.List)) } } } @@ -102,12 +100,6 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han for i := len(lists) - 1; i >= 0; i-- { for e := lists[i].Front(); e != nil; e = e.Next() { item := e.Value.(*handlerItem) - // 主要是用于路由注册函数的重复添加判断(特别是中间件和钩子函数) - if _, ok := repeatHandlerCheckMap[item.itemId]; ok { - continue - } else { - repeatHandlerCheckMap[item.itemId] = struct{}{} - } // 服务路由函数只能添加一次,将重复判断放在这里提高检索效率 if hasServe { switch item.itemType { @@ -115,8 +107,7 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han continue } } - // 动态匹配规则带有gDEFAULT_METHOD的情况,不会像静态规则那样直接解析为所有的HTTP METHOD存储 - if strings.EqualFold(item.router.Method, gDEFAULT_METHOD) || strings.EqualFold(item.router.Method, method) { + if item.router.Method == gDEFAULT_METHOD || item.router.Method == method { // 注意当不带任何动态路由规则时,len(match) == 1 if match, err := gregex.MatchString(item.router.RegRule, path); err == nil && len(match) > 0 { parsedItem := &handlerParsedItem{item, nil} @@ -207,7 +198,7 @@ func (item *handlerParsedItem) MarshalJSON() ([]byte, error) { return json.Marshal(item.handler) } -// 生成回调方法查询的Key +// serveHandlerKey creates and returns a cache key for router. func (s *Server) serveHandlerKey(method, path, domain string) string { if len(domain) > 0 { domain = "@" + domain diff --git a/net/ghttp/ghttp_unit_router_controller_rest_test.go b/net/ghttp/ghttp_unit_router_controller_rest_test.go index a5b7e3cd1..149e7d11f 100644 --- a/net/ghttp/ghttp_unit_router_controller_rest_test.go +++ b/net/ghttp/ghttp_unit_router_controller_rest_test.go @@ -46,14 +46,6 @@ func (c *ControllerRest) Delete() { c.Response.Write("Controller Delete") } -func (c *ControllerRest) Patch() { - c.Response.Write("Controller Patch") -} - -func (c *ControllerRest) Options() { - c.Response.Write("Controller Options") -} - func (c *ControllerRest) Head() { c.Response.Header().Set("head-ok", "1") } @@ -78,8 +70,6 @@ func Test_Router_ControllerRest(t *testing.T) { gtest.Assert(client.PutContent("/"), "1Controller Put2") gtest.Assert(client.PostContent("/"), "1Controller Post2") gtest.Assert(client.DeleteContent("/"), "1Controller Delete2") - gtest.Assert(client.PatchContent("/"), "1Controller Patch2") - gtest.Assert(client.OptionsContent("/"), "1Controller Options2") resp1, err := client.Head("/") if err == nil { defer resp1.Close() @@ -91,8 +81,6 @@ func Test_Router_ControllerRest(t *testing.T) { gtest.Assert(client.PutContent("/controller-rest/put"), "1Controller Put2") gtest.Assert(client.PostContent("/controller-rest/post"), "1Controller Post2") gtest.Assert(client.DeleteContent("/controller-rest/delete"), "1Controller Delete2") - gtest.Assert(client.PatchContent("/controller-rest/patch"), "1Controller Patch2") - gtest.Assert(client.OptionsContent("/controller-rest/options"), "1Controller Options2") resp2, err := client.Head("/controller-rest/head") if err == nil { defer resp2.Close() diff --git a/net/ghttp/ghttp_unit_router_hook_test.go b/net/ghttp/ghttp_unit_router_hook_test.go index e7a2daca2..1edc604fe 100644 --- a/net/ghttp/ghttp_unit_router_hook_test.go +++ b/net/ghttp/ghttp_unit_router_hook_test.go @@ -50,6 +50,7 @@ func Test_Router_Hook_Fuzzy_Router(t *testing.T) { pattern1 := "/:name/info" s.BindHookHandlerByMap(pattern1, map[string]ghttp.HandlerFunc{ ghttp.HOOK_BEFORE_SERVE: func(r *ghttp.Request) { + fmt.Println("called") r.SetParam("uid", i) i++ }, @@ -58,19 +59,19 @@ func Test_Router_Hook_Fuzzy_Router(t *testing.T) { r.Response.Write(r.Get("uid")) }) - pattern2 := "/{object}/list/{page}.java" - s.BindHookHandlerByMap(pattern2, map[string]ghttp.HandlerFunc{ - ghttp.HOOK_BEFORE_OUTPUT: func(r *ghttp.Request) { - r.Response.SetBuffer([]byte( - fmt.Sprint(r.Get("object"), "&", r.Get("page"), "&", i), - )) - }, - }) - s.BindHandler(pattern2, func(r *ghttp.Request) { - r.Response.Write(r.Router.Uri) - }) + //pattern2 := "/{object}/list/{page}.java" + //s.BindHookHandlerByMap(pattern2, map[string]ghttp.HandlerFunc{ + // ghttp.HOOK_BEFORE_OUTPUT: func(r *ghttp.Request) { + // r.Response.SetBuffer([]byte( + // fmt.Sprint(r.Get("object"), "&", r.Get("page"), "&", i), + // )) + // }, + //}) + //s.BindHandler(pattern2, func(r *ghttp.Request) { + // r.Response.Write(r.Router.Uri) + //}) s.SetPort(p) - s.SetDumpRouterMap(false) + //s.SetDumpRouterMap(false) s.Start() defer s.Shutdown() diff --git a/os/gfile/gfile.go b/os/gfile/gfile.go index 94c222819..5c101e632 100644 --- a/os/gfile/gfile.go +++ b/os/gfile/gfile.go @@ -10,6 +10,7 @@ package gfile import ( "bytes" "errors" + "github.com/gogf/gf/text/gstr" "os" "os/exec" "os/user" @@ -58,7 +59,9 @@ func Mkdir(path string) error { func Create(path string) (*os.File, error) { dir := Dir(path) if !Exists(dir) { - Mkdir(dir) + if err := Mkdir(dir); err != nil { + return nil, err + } } return os.Create(path) } @@ -93,7 +96,14 @@ func OpenWithFlagPerm(path string, flag int, perm os.FileMode) (*os.File, error) // Join joins string array paths with file separator of current system. func Join(paths ...string) string { - return strings.Join(paths, Separator) + var s string + for _, path := range paths { + if s != "" { + s += Separator + } + s += gstr.TrimRight(path, Separator) + } + return s } // Exists checks whether given exist. diff --git a/text/gregex/gregex_cache.go b/text/gregex/gregex_cache.go index c87168e63..979ec19cd 100644 --- a/text/gregex/gregex_cache.go +++ b/text/gregex/gregex_cache.go @@ -14,7 +14,7 @@ import ( var ( regexMu = sync.RWMutex{} // Cache for regex object. - // TODO There's no expiring logic for this map. + // Note that there's no expiring logic for this map. regexMap = make(map[string]*regexp.Regexp) ) @@ -22,29 +22,25 @@ var ( // It uses cache to enhance the performance for compiling regular expression pattern, // which means, it will return the same *regexp.Regexp object with the same regular // expression pattern. -func getRegexp(pattern string) (*regexp.Regexp, error) { - if r := getCache(pattern); r != nil { - return r, nil - } - if r, err := regexp.Compile(pattern); err == nil { - setCache(pattern, r) - return r, nil - } else { - return nil, err - } -} - -// getCache returns *regexp.Regexp object from cache by given , for internal usage. -func getCache(pattern string) (regex *regexp.Regexp) { +// +// It is concurrent-safe for multiple goroutines. +func getRegexp(pattern string) (regex *regexp.Regexp, err error) { + // Retrieve the regular expression object using reading lock. regexMu.RLock() regex = regexMap[pattern] regexMu.RUnlock() - return -} - -// setCache stores *regexp.Regexp object into cache, for internal usage. -func setCache(pattern string, regex *regexp.Regexp) { + if regex != nil { + return + } + // If it does not exist in the cache, + // it compiles the pattern and creates one. + regex, err = regexp.Compile(pattern) + if err != nil { + return + } + // Cache the result object using writing lock. regexMu.Lock() regexMap[pattern] = regex regexMu.Unlock() + return } diff --git a/util/gpage/gpage.go b/util/gpage/gpage.go index a231c74bc..01b4e7459 100644 --- a/util/gpage/gpage.go +++ b/util/gpage/gpage.go @@ -10,7 +10,7 @@ package gpage import ( "fmt" "math" - url2 "net/url" + "net/url" "strings" "github.com/gogf/gf/net/ghttp" @@ -19,30 +19,27 @@ import ( "github.com/gogf/gf/util/gconv" ) -// 分页对象 +// Page is the pagination implementer. type Page struct { - Url *url2.URL // 当前页面的URL对象 - Router *ghttp.Router // 当前页面的路由对象(与gf框架耦合,在静态分页下有效) - UrlTemplate string // URL生成规则,内部可使用{.page}变量指定页码 - TotalSize int // 总共数据条数 - TotalPage int // 总页数 - CurrentPage int // 当前页码 - PageName string // 分页参数名称(GET参数) - NextPageTag string // 下一页标签 - PrevPageTag string // 上一页标签 - FirstPageTag string // 首页标签 - LastPageTag string // 尾页标签 - PrevBar string // 上一分页条 - NextBar string // 下一分页条 - PageBarNum int // 控制分页条的数量 - AjaxActionName string // AJAX方法名,当该属性有值时,表示使用AJAX分页 + UrlTemplate string // Custom url template for page url producing. + TotalSize int // Total size. + TotalPage int // Total page, which is automatically calculated. + CurrentPage int // Current page number >= 1. + PageName string // Page variable name. It's "page" in default. + NextPageTag string // Tag name for next p. + PrevPageTag string // Tag name for prev p. + FirstPageTag string // Tag name for first p. + LastPageTag string // Tag name for last p. + PrevBar string // Tag string for prev bar. + NextBar string // Tag string for next bar. + PageBarNum int // Page bar number for displaying. + AjaxActionName string // Ajax function name. Ajax is enabled if this attribute is not empty. } // 创建一个分页对象,输入参数分别为: // 总数量、每页数量、当前页码、当前的URL(URI+QUERY)、(可选)路由规则(例如: /user/list/:page、/order/list/*page、/order/list/{page}.html) -func New(TotalSize, perPage int, CurrentPage interface{}, url string, router ...*ghttp.Router) *Page { - u, _ := url2.Parse(url) - page := &Page{ +func New(totalSize, pageSize, currentPage int, urlTemplate string) *Page { + p := &Page{ PageName: "page", PrevPageTag: "<", NextPageTag: ">", @@ -50,34 +47,20 @@ func New(TotalSize, perPage int, CurrentPage interface{}, url string, router ... LastPageTag: ">|", PrevBar: "<<", NextBar: ">>", - TotalSize: TotalSize, - TotalPage: int(math.Ceil(float64(TotalSize) / float64(perPage))), + TotalSize: totalSize, + TotalPage: int(math.Ceil(float64(totalSize) / float64(pageSize))), CurrentPage: 1, PageBarNum: 10, - Url: u, + UrlTemplate: urlTemplate, } - curPage := gconv.Int(CurrentPage) - if curPage > 0 { - page.CurrentPage = curPage + if currentPage > 0 { + p.CurrentPage = currentPage } - if len(router) > 0 { - page.Router = router[0] - } - return page -} - -// 启用AJAX分页 -func (page *Page) EnableAjax(actionName string) { - page.AjaxActionName = actionName -} - -// 设置URL生成规则模板,模板中可使用{.page}变量指定页码位置 -func (page *Page) SetUrlTemplate(template string) { - page.UrlTemplate = template + return p } // 获取显示"下一页"的内容. -func (page *Page) NextPage(styles ...string) string { +func (p *Page) NextPage(styles ...string) string { var curStyle, style string if len(styles) > 0 { curStyle = styles[0] @@ -85,14 +68,14 @@ func (page *Page) NextPage(styles ...string) string { if len(styles) > 1 { style = styles[0] } - if page.CurrentPage < page.TotalPage { - return page.GetLink(page.GetUrl(page.CurrentPage+1), page.NextPageTag, "下一页", style) + if p.CurrentPage < p.TotalPage { + return p.GetLink(p.GetUrl(p.CurrentPage+1), p.NextPageTag, "下一页", style) } - return fmt.Sprintf(`%s`, curStyle, page.NextPageTag) + return fmt.Sprintf(`%s`, curStyle, p.NextPageTag) } // 获取显示“上一页”的内容 -func (page *Page) PrevPage(styles ...string) string { +func (p *Page) PrevPage(styles ...string) string { var curStyle, style string if len(styles) > 0 { curStyle = styles[0] @@ -100,14 +83,14 @@ func (page *Page) PrevPage(styles ...string) string { if len(styles) > 1 { style = styles[0] } - if page.CurrentPage > 1 { - return page.GetLink(page.GetUrl(page.CurrentPage-1), page.PrevPageTag, "上一页", style) + if p.CurrentPage > 1 { + return p.GetLink(p.GetUrl(p.CurrentPage-1), p.PrevPageTag, "上一页", style) } - return fmt.Sprintf(`%s`, curStyle, page.PrevPageTag) + return fmt.Sprintf(`%s`, curStyle, p.PrevPageTag) } // 获取显示“首页”的代码 -func (page *Page) FirstPage(styles ...string) string { +func (p *Page) FirstPage(styles ...string) string { var curStyle, style string if len(styles) > 0 { curStyle = styles[0] @@ -115,14 +98,14 @@ func (page *Page) FirstPage(styles ...string) string { if len(styles) > 1 { style = styles[0] } - if page.CurrentPage == 1 { - return fmt.Sprintf(`%s`, curStyle, page.FirstPageTag) + if p.CurrentPage == 1 { + return fmt.Sprintf(`%s`, curStyle, p.FirstPageTag) } - return page.GetLink(page.GetUrl(1), page.FirstPageTag, "第一页", style) + return p.GetLink(p.GetUrl(1), p.FirstPageTag, "第一页", style) } // 获取显示“尾页”的内容 -func (page *Page) LastPage(styles ...string) string { +func (p *Page) LastPage(styles ...string) string { var curStyle, style string if len(styles) > 0 { curStyle = styles[0] @@ -130,14 +113,14 @@ func (page *Page) LastPage(styles ...string) string { if len(styles) > 1 { style = styles[0] } - if page.CurrentPage == page.TotalPage { - return fmt.Sprintf(`%s`, curStyle, page.LastPageTag) + if p.CurrentPage == p.TotalPage { + return fmt.Sprintf(`%s`, curStyle, p.LastPageTag) } - return page.GetLink(page.GetUrl(page.TotalPage), page.LastPageTag, "最后页", style) + return p.GetLink(p.GetUrl(p.TotalPage), p.LastPageTag, "最后页", style) } // 获得分页条列表内容 -func (page *Page) PageBar(styles ...string) string { +func (p *Page) PageBar(styles ...string) string { var curStyle, style string if len(styles) > 0 { curStyle = styles[0] @@ -145,19 +128,19 @@ func (page *Page) PageBar(styles ...string) string { if len(styles) > 1 { style = styles[0] } - plus := int(math.Ceil(float64(page.PageBarNum / 2))) - if page.PageBarNum-plus+page.CurrentPage > page.TotalPage { - plus = page.PageBarNum - page.TotalPage + page.CurrentPage + plus := int(math.Ceil(float64(p.PageBarNum / 2))) + if p.PageBarNum-plus+p.CurrentPage > p.TotalPage { + plus = p.PageBarNum - p.TotalPage + p.CurrentPage } - begin := page.CurrentPage - plus + 1 + begin := p.CurrentPage - plus + 1 if begin < 1 { begin = 1 } ret := "" - for i := begin; i < begin+page.PageBarNum; i++ { - if i <= page.TotalPage { - if i != page.CurrentPage { - ret += page.GetLink(page.GetUrl(i), gconv.String(i), style, "") + for i := begin; i < begin+p.PageBarNum; i++ { + if i <= p.TotalPage { + if i != p.CurrentPage { + ret += p.GetLink(p.GetUrl(i), gconv.String(i), style, "") } else { ret += fmt.Sprintf(`%d`, curStyle, i) } @@ -169,13 +152,13 @@ func (page *Page) PageBar(styles ...string) string { } // 获取基于select标签的显示跳转按钮的代码 -func (page *Page) SelectBar() string { - ret := `` + for i := 1; i <= p.TotalPage; i++ { + if i == p.CurrentPage { + ret += fmt.Sprintf(``, p.GetUrl(i), i) } else { - ret += fmt.Sprintf(``, page.GetUrl(i), i) + ret += fmt.Sprintf(``, p.GetUrl(i), i) } } ret += "" @@ -183,117 +166,87 @@ func (page *Page) SelectBar() string { } // 预定义的分页显示风格内容 -func (page *Page) GetContent(mode int) string { +func (p *Page) GetContent(mode int) string { switch mode { case 1: - page.NextPageTag = "下一页" - page.PrevPageTag = "上一页" + p.NextPageTag = "下一页" + p.PrevPageTag = "上一页" return fmt.Sprintf( `%s %d %s`, - page.PrevPage(), - page.CurrentPage, - page.NextPage(), + p.PrevPage(), + p.CurrentPage, + p.NextPage(), ) case 2: - page.NextPageTag = "下一页>>" - page.PrevPageTag = "<<上一页" - page.FirstPageTag = "首页" - page.LastPageTag = "尾页" + p.NextPageTag = "下一页>>" + p.PrevPageTag = "<<上一页" + p.FirstPageTag = "首页" + p.LastPageTag = "尾页" return fmt.Sprintf( `%s%s[第%d页]%s%s第%s页`, - page.FirstPage(), - page.PrevPage(), - page.CurrentPage, - page.NextPage(), - page.LastPage(), - page.SelectBar(), + p.FirstPage(), + p.PrevPage(), + p.CurrentPage, + p.NextPage(), + p.LastPage(), + p.SelectBar(), ) case 3: - page.NextPageTag = "下一页" - page.PrevPageTag = "上一页" - page.FirstPageTag = "首页" - page.LastPageTag = "尾页" - pageStr := page.FirstPage() - pageStr += page.PrevPage() - pageStr += page.PageBar("current") - pageStr += page.NextPage() - pageStr += page.LastPage() + p.NextPageTag = "下一页" + p.PrevPageTag = "上一页" + p.FirstPageTag = "首页" + p.LastPageTag = "尾页" + pageStr := p.FirstPage() + pageStr += p.PrevPage() + pageStr += p.PageBar("current") + pageStr += p.NextPage() + pageStr += p.LastPage() pageStr += fmt.Sprintf( `当前页%d/%d 共%d条`, - page.CurrentPage, - page.TotalPage, - page.TotalSize, + p.CurrentPage, + p.TotalPage, + p.TotalSize, ) return pageStr case 4: - page.NextPageTag = "下一页" - page.PrevPageTag = "上一页" - page.FirstPageTag = "首页" - page.LastPageTag = "尾页" - pageStr := page.FirstPage() - pageStr += page.PrevPage() - pageStr += page.PageBar("current") - pageStr += page.NextPage() - pageStr += page.LastPage() + p.NextPageTag = "下一页" + p.PrevPageTag = "上一页" + p.FirstPageTag = "首页" + p.LastPageTag = "尾页" + pageStr := p.FirstPage() + pageStr += p.PrevPage() + pageStr += p.PageBar("current") + pageStr += p.NextPage() + pageStr += p.LastPage() return pageStr } return "" } // 为指定的页面返回地址值 -func (page *Page) GetUrl(pageNo int) string { - // 复制一个URL对象 - url := *page.Url - if len(page.UrlTemplate) == 0 && page.Router != nil { - page.UrlTemplate = page.makeUrlTemplate(url.Path, page.Router) - } - if len(page.UrlTemplate) > 0 { - // 指定URL生成模板 - url.Path = gstr.Replace(page.UrlTemplate, "{.page}", gconv.String(pageNo)) +func (p *Page) GetUrl(pageNo int) string { + pattern := fmt.Sprintf(`(:%s|\*%s|\.%s)`, p.PageName, p.PageName, p.PageName) + result, _ := gregex.ReplaceString(pattern, pageNo, p.UrlTemplate) + url.Path = gstr.Replace(p.UrlTemplate, "{.page}", gconv.String(pageNo)) return url.String() } - values := page.Url.Query() - values.Set(page.PageName, gconv.String(pageNo)) + values := p.Url.Query() + values.Set(p.PageName, gconv.String(pageNo)) url.RawQuery = values.Encode() return url.String() } -// 根据当前URL以及注册路由信息计算出对应的URL模板 -func (page *Page) makeUrlTemplate(url string, router *ghttp.Router) (tpl string) { - if page.Router != nil && len(router.RegNames) > 0 { - if match, err := gregex.MatchString(router.RegRule, url); err == nil && len(match) > 0 { - if len(match) > len(router.RegNames) { - tpl = router.Uri - hasPageName := false - for i, name := range router.RegNames { - rule := fmt.Sprintf(`[:\*]%s|\{%s\}`, name, name) - if !hasPageName && strings.Compare(name, page.PageName) == 0 { - hasPageName = true - tpl, _ = gregex.ReplaceString(rule, `{.page}`, tpl) - } else { - tpl, _ = gregex.ReplaceString(rule, match[i+1], tpl) - } - } - if !hasPageName { - tpl = "" - } - } - } - } - return -} - // 获取链接地址 -func (page *Page) GetLink(url, text, title, style string) string { +func (p *Page) GetLink(url, text, title, style string) string { if len(style) > 0 { style = fmt.Sprintf(`class="%s" `, style) } - if len(page.AjaxActionName) > 0 { - return fmt.Sprintf(`%s`, style, page.AjaxActionName, url, text) + if len(p.AjaxActionName) > 0 { + return fmt.Sprintf(`%s`, style, p.AjaxActionName, url, text) } else { return fmt.Sprintf(`%s`, style, url, title, text) } From 4e7c6c1fb4fffe97fe158904720e589429334668 Mon Sep 17 00:00:00 2001 From: John Date: Wed, 4 Mar 2020 22:52:56 +0800 Subject: [PATCH 02/26] improve CORS feature for ghttp.Server --- net/ghttp/ghttp_request_middleware.go | 1 + net/ghttp/ghttp_response_cors.go | 53 ++++++++----- net/ghttp/ghttp_server.go | 1 + net/ghttp/ghttp_server_router.go | 61 +++++---------- net/ghttp/ghttp_server_router_serve.go | 35 ++++++--- ...go => ghttp_unit_middleware_basic_test.go} | 35 ++++++++- net/ghttp/ghttp_unit_middleware_cors_test.go | 76 +++++++++++++++++++ net/ghttp/ghttp_unit_router_hook_test.go | 23 +++--- 8 files changed, 196 insertions(+), 89 deletions(-) rename net/ghttp/{ghttp_unit_middleware_test.go => ghttp_unit_middleware_basic_test.go} (95%) create mode 100644 net/ghttp/ghttp_unit_middleware_cors_test.go diff --git a/net/ghttp/ghttp_request_middleware.go b/net/ghttp/ghttp_request_middleware.go index f4e4907e9..050c16f0e 100644 --- a/net/ghttp/ghttp_request_middleware.go +++ b/net/ghttp/ghttp_request_middleware.go @@ -123,6 +123,7 @@ func (m *Middleware) Next() { }, func(exception interface{}) { m.request.error = gerror.Newf("%v", exception) m.request.Response.WriteStatus(http.StatusInternalServerError, exception) + loop = false }) } // Check the http status code after all handler and middleware done. diff --git a/net/ghttp/ghttp_response_cors.go b/net/ghttp/ghttp_response_cors.go index 835a1c8e2..59fc35d5d 100644 --- a/net/ghttp/ghttp_response_cors.go +++ b/net/ghttp/ghttp_response_cors.go @@ -8,11 +8,10 @@ package ghttp import ( - "net/http" - "net/url" - "github.com/gogf/gf/text/gstr" "github.com/gogf/gf/util/gconv" + "net/http" + "net/url" ) // CORSOptions is the options for CORS feature. @@ -29,18 +28,18 @@ type CORSOptions struct { var ( // defaultAllowHeaders is the default allowed headers for CORS. - // It's defined as map for better header key searching performance. - defaultAllowHeaders = map[string]struct{}{ - "Origin": {}, - "Accept": {}, - "Cookie": {}, - "Authorization": {}, - "X-Auth-Token": {}, - "X-Requested-With": {}, - "Content-Type": {}, - } + // It's defined another map for better header key searching performance. + defaultAllowHeaders = "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With" + defaultAllowHeadersMap = make(map[string]struct{}) ) +func init() { + array := gstr.SplitAndTrim(defaultAllowHeaders, ",") + for _, header := range array { + defaultAllowHeadersMap[header] = struct{}{} + } +} + // DefaultCORSOptions returns the default CORS options, // which allows any cross-domain request. func (r *Response) DefaultCORSOptions() CORSOptions { @@ -48,22 +47,17 @@ func (r *Response) DefaultCORSOptions() CORSOptions { AllowOrigin: "*", AllowMethods: HTTP_METHODS, AllowCredentials: "true", + AllowHeaders: defaultAllowHeaders, MaxAge: 3628800, } // Allow all client's custom headers in default. if headers := r.Request.Header.Get("Access-Control-Request-Headers"); headers != "" { array := gstr.SplitAndTrim(headers, ",") for _, header := range array { - if _, ok := defaultAllowHeaders[header]; !ok { + if _, ok := defaultAllowHeadersMap[header]; !ok { options.AllowHeaders += header + "," } } - for header, _ := range defaultAllowHeaders { - if len(options.AllowHeaders) > 0 { - options.AllowHeaders += "," - } - options.AllowHeaders += header - } } // Allow all anywhere origin in default. if origin := r.Request.Header.Get("Origin"); origin != "" { @@ -101,8 +95,25 @@ func (r *Response) CORS(options CORSOptions) { } // No continue service handling if it's OPTIONS request. if gstr.Equal(r.Request.Method, "OPTIONS") { + // Request method's handler searching. + // It here uses Server.routesMap attribute enhancing the searching performance. + if method := r.Request.Header.Get("Access-Control-Request-Method"); method != "" { + routerKey := "" + for _, domain := range []string{gDEFAULT_DOMAIN, r.Request.GetHost()} { + for _, v := range []string{gDEFAULT_METHOD, method} { + routerKey = r.Server.handlerKey("", v, r.Request.URL.Path, domain) + if r.Server.routesMap[routerKey] != nil { + if r.Status == 0 { + r.Status = http.StatusOK + } + r.Request.ExitAll() + } + } + } + } + // Cannot find the request handler. if r.Status == 0 { - r.Status = http.StatusOK + r.Status = http.StatusNotFound } r.Request.ExitAll() } diff --git a/net/ghttp/ghttp_server.go b/net/ghttp/ghttp_server.go index d79aa80c7..b446a17e8 100644 --- a/net/ghttp/ghttp_server.go +++ b/net/ghttp/ghttp_server.go @@ -79,6 +79,7 @@ type ( // handlerItem is the registered handler for route handling, // including middleware and hook functions. handlerItem struct { + itemId int // Unique handler item id mark. itemName string // Handler name, which is automatically retrieved from runtime stack when registered. itemType int // Handler type: object/handler/controller/middleware/hook. itemFunc HandlerFunc // Handler address. diff --git a/net/ghttp/ghttp_server_router.go b/net/ghttp/ghttp_server_router.go index 7e7c2f353..29426863c 100644 --- a/net/ghttp/ghttp_server_router.go +++ b/net/ghttp/ghttp_server_router.go @@ -9,7 +9,7 @@ package ghttp import ( "errors" "fmt" - "github.com/gogf/gf/util/gutil" + "github.com/gogf/gf/container/gtype" "strings" "github.com/gogf/gf/debug/gdebug" @@ -23,6 +23,11 @@ const ( gFILTER_KEY = "/net/ghttp/ghttp" ) +var ( + // handlerIdGenerator is handler item id generator. + handlerIdGenerator = gtype.NewInt() +) + // handlerKey creates and returns an unique router key for given parameters. func (s *Server) handlerKey(hook, method, path, domain string) string { return hook + "%" + s.serveHandlerKey(method, path, domain) @@ -59,6 +64,7 @@ func (s *Server) parsePattern(pattern string) (domain, method, path string, err // This function is called during server starts up, which cares little about the performance. What really cares // is the well designed router storage structure for router searching when the request is under serving. func (s *Server) setHandler(pattern string, handler *handlerItem) { + handler.itemId = handlerIdGenerator.Add(1) domain, method, uri, err := s.parsePattern(pattern) if err != nil { s.Logger().Fatal("invalid pattern:", pattern, err) @@ -70,11 +76,11 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { } // Repeated router checks, this feature can be disabled by server configuration. - regKey := s.handlerKey(handler.hookName, method, uri, domain) + routerKey := s.handlerKey(handler.hookName, method, uri, domain) if !s.config.RouteOverWrite { switch handler.itemType { case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER: - if item, ok := s.routesMap[regKey]; ok { + if item, ok := s.routesMap[routerKey]; ok { s.Logger().Fatalf(`duplicated route registry "%s", already registered at %s`, pattern, item[0].file) return } @@ -143,47 +149,14 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { // fuzzy checks. if i == len(array)-1 && part != "*fuzz" { if v, ok := p.(map[string]interface{})["*list"]; !ok { - list := glist.New() - p.(map[string]interface{})["*list"] = list - lists = append(lists, list) + leafList := glist.New() + p.(map[string]interface{})["*list"] = leafList + lists = append(lists, leafList) } else { lists = append(lists, v.(*glist.List)) } } } - - for k, v := range array { - if len(v) == 0 { - continue - } - // 判断是否模糊匹配规则 - if gregex.IsMatchString(`^[:\*]|\{[\w\.\-]+\}|\*`, v) { - v = "*fuzz" - // 由于是模糊规则,因此这里会有一个*list,用以将后续的路由规则加进来, - // 检索会从叶子节点的链表往根节点按照优先级进行检索 - if v, ok := p.(map[string]interface{})["*list"]; !ok { - p.(map[string]interface{})["*list"] = glist.New() - lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List)) - } else { - lists = append(lists, v.(*glist.List)) - } - } - // 属性层级数据写入 - if _, ok := p.(map[string]interface{})[v]; !ok { - p.(map[string]interface{})[v] = make(map[string]interface{}) - } - p = p.(map[string]interface{})[v] - // 到达叶子节点,往list中增加匹配规则(条件 v != "*fuzz" 是因为模糊节点的话在前面已经添加了*list链表) - if k == len(array)-1 && v != "*fuzz" { - if v, ok := p.(map[string]interface{})["*list"]; !ok { - p.(map[string]interface{})["*list"] = glist.New() - lists = append(lists, p.(map[string]interface{})["*list"].(*glist.List)) - } else { - lists = append(lists, v.(*glist.List)) - } - } - } - // It iterates the list array of , compares priorities and inserts the new router item in // the proper position of each list. The priority of the list is ordered from high to low. item := (*handlerItem)(nil) @@ -206,8 +179,8 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { } } // Initialize the route map item. - if _, ok := s.routesMap[regKey]; !ok { - s.routesMap[regKey] = make([]registeredRouteItem, 0) + if _, ok := s.routesMap[routerKey]; !ok { + s.routesMap[routerKey] = make([]registeredRouteItem, 0) } _, file, line := gdebug.CallerWithFilter(gFILTER_KEY) routeItem := registeredRouteItem{ @@ -217,12 +190,12 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { switch handler.itemType { case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER: // Overwrite the route. - s.routesMap[regKey] = []registeredRouteItem{routeItem} + s.routesMap[routerKey] = []registeredRouteItem{routeItem} default: // Append the route. - s.routesMap[regKey] = append(s.routesMap[regKey], routeItem) + s.routesMap[routerKey] = append(s.routesMap[routerKey], routeItem) } - gutil.Dump(s.serveTree) + //gutil.Dump(s.serveTree) } // 对比两个handlerItem的优先级,需要非常注意的是,注意新老对比项的参数先后顺序。 diff --git a/net/ghttp/ghttp_server_router_serve.go b/net/ghttp/ghttp_server_router_serve.go index 7306a5614..0fb563b8a 100644 --- a/net/ghttp/ghttp_server_router_serve.go +++ b/net/ghttp/ghttp_server_router_serve.go @@ -15,7 +15,7 @@ import ( "github.com/gogf/gf/text/gregex" ) -// handlerCacheItem is a item for router cache. +// handlerCacheItem is an item for router searching cache. type handlerCacheItem struct { parsedItems []*handlerParsedItem hasHook bool @@ -24,8 +24,17 @@ type handlerCacheItem struct { // getHandlersWithCache searches the router item with cache feature for given request. func (s *Server) getHandlersWithCache(r *Request) (parsedItems []*handlerParsedItem, hasHook, hasServe bool) { - value := s.serveCache.GetOrSetFunc(s.serveHandlerKey(r.Method, r.URL.Path, r.GetHost()), func() interface{} { - parsedItems, hasHook, hasServe = s.searchHandlers(r.Method, r.URL.Path, r.GetHost()) + method := r.Method + // Special http method OPTIONS handling. + // It searches the handler with the request method instead of OPTIONS method. + if method == "OPTIONS" { + if v := r.Request.Header.Get("Access-Control-Request-Method"); v != "" { + method = v + } + } + // Search and cache the router handlers. + value := s.serveCache.GetOrSetFunc(s.serveHandlerKey(method, r.URL.Path, r.GetHost()), func() interface{} { + parsedItems, hasHook, hasServe = s.searchHandlers(method, r.URL.Path, r.GetHost()) if parsedItems != nil { return &handlerCacheItem{parsedItems, hasHook, hasServe} } @@ -44,11 +53,6 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han if len(path) == 0 { return nil, false, false } - // Default domain has the most priority when iteration. - domains := []string{gDEFAULT_DOMAIN} - if !strings.EqualFold(gDEFAULT_DOMAIN, domain) { - domains = append(domains, domain) - } // Split the URL.path to separate parts. var array []string if strings.EqualFold("/", path) { @@ -58,11 +62,14 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han } parsedItemList := glist.New() lastMiddlewareElem := (*glist.Element)(nil) - for _, domain := range domains { + repeatHandlerCheckMap := make(map[int]struct{}, 16) + // Default domain has the most priority when iteration. + for _, domain := range []string{gDEFAULT_DOMAIN, domain} { p, ok := s.serveTree[domain] if !ok { continue } + // Make a list array with capacity of 16. lists := make([]*glist.List, 0, 16) for i, part := range array { // In case of double '/' URI, eg: /user//index @@ -72,8 +79,8 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han if v, ok := p.(map[string]interface{})["*list"]; ok { lists = append(lists, v.(*glist.List)) } - if _, ok := p.(map[string]interface{})[part]; ok { - p = p.(map[string]interface{})[part] + if v, ok := p.(map[string]interface{})[part]; ok { + p = v if i == len(array)-1 { if v, ok := p.(map[string]interface{})["*list"]; ok { lists = append(lists, v.(*glist.List)) @@ -100,6 +107,12 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han for i := len(lists) - 1; i >= 0; i-- { for e := lists[i].Front(); e != nil; e = e.Next() { item := e.Value.(*handlerItem) + // 主要是用于路由注册函数的重复添加判断(特别是中间件和钩子函数) + if _, ok := repeatHandlerCheckMap[item.itemId]; ok { + continue + } else { + repeatHandlerCheckMap[item.itemId] = struct{}{} + } // 服务路由函数只能添加一次,将重复判断放在这里提高检索效率 if hasServe { switch item.itemType { diff --git a/net/ghttp/ghttp_unit_middleware_test.go b/net/ghttp/ghttp_unit_middleware_basic_test.go similarity index 95% rename from net/ghttp/ghttp_unit_middleware_test.go rename to net/ghttp/ghttp_unit_middleware_basic_test.go index cfb56b4c1..0ff2c4beb 100644 --- a/net/ghttp/ghttp_unit_middleware_test.go +++ b/net/ghttp/ghttp_unit_middleware_basic_test.go @@ -594,8 +594,9 @@ func MiddlewareCORS(r *ghttp.Request) { func Test_Middleware_CORSAndAuth(t *testing.T) { p := ports.PopRand() s := g.Server(p) + s.Use(MiddlewareCORS) s.Group("/api.v2", func(group *ghttp.RouterGroup) { - group.Middleware(MiddlewareAuth, MiddlewareCORS) + group.Middleware(MiddlewareAuth) group.POST("/user/list", func(r *ghttp.Request) { r.Response.Write("list") }) @@ -680,3 +681,35 @@ func Test_Middleware_Scope(t *testing.T) { gtest.Assert(client.GetContent("/scope3"), "ae3fb") }) } + +func Test_Middleware_Panic(t *testing.T) { + p := ports.PopRand() + s := g.Server(p) + i := 0 + s.Group("/", func(group *ghttp.RouterGroup) { + group.Group("/", func(group *ghttp.RouterGroup) { + group.Middleware(func(r *ghttp.Request) { + i++ + panic("error") + r.Middleware.Next() + }, func(r *ghttp.Request) { + i++ + r.Middleware.Next() + }) + group.ALL("/", func(r *ghttp.Request) { + r.Response.Write(i) + }) + }) + }) + s.SetPort(p) + //s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + time.Sleep(100 * time.Millisecond) + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + + gtest.Assert(client.GetContent("/"), "error") + }) +} diff --git a/net/ghttp/ghttp_unit_middleware_cors_test.go b/net/ghttp/ghttp_unit_middleware_cors_test.go new file mode 100644 index 000000000..c99e16c97 --- /dev/null +++ b/net/ghttp/ghttp_unit_middleware_cors_test.go @@ -0,0 +1,76 @@ +// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package ghttp_test + +import ( + "fmt" + "github.com/gogf/gf/frame/g" + "github.com/gogf/gf/net/ghttp" + "github.com/gogf/gf/test/gtest" + "testing" + "time" +) + +func Test_Middleware_CORS(t *testing.T) { + p := ports.PopRand() + s := g.Server(p) + s.Group("/api.v2", func(group *ghttp.RouterGroup) { + group.Middleware(MiddlewareCORS) + group.POST("/user/list", func(r *ghttp.Request) { + r.Response.Write("list") + }) + }) + s.SetPort(p) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + time.Sleep(100 * time.Millisecond) + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + // Common Checks. + gtest.Assert(client.GetContent("/"), "Not Found") + gtest.Assert(client.GetContent("/api.v2"), "Not Found") + + // GET request does not any route. + resp, err := client.Get("/api.v2/user/list") + gtest.Assert(err, nil) + gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 0) + resp.Close() + + // POST request matches the route and CORS middleware. + resp, err = client.Post("/api.v2/user/list") + gtest.Assert(err, nil) + gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 1) + gtest.Assert(resp.Header["Access-Control-Allow-Headers"][0], "Origin,Content-Type,Accept,User-Agent,Cookie,Authorization,X-Auth-Token,X-Requested-With") + gtest.Assert(resp.Header["Access-Control-Allow-Methods"][0], "GET,PUT,POST,DELETE,PATCH,HEAD,CONNECT,OPTIONS,TRACE") + gtest.Assert(resp.Header["Access-Control-Allow-Origin"][0], "*") + gtest.Assert(resp.Header["Access-Control-Max-Age"][0], "3628800") + resp.Close() + }) + // OPTIONS GET + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + client.SetHeader("Access-Control-Request-Method", "GET") + resp, err := client.Options("/api.v2/user/list") + gtest.Assert(err, nil) + gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 0) + gtest.Assert(resp.ReadAllString(), "Not Found") + resp.Close() + }) + // OPTIONS POST + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + client.SetHeader("Access-Control-Request-Method", "POST") + resp, err := client.Options("/api.v2/user/list") + gtest.Assert(err, nil) + gtest.Assert(len(resp.Header["Access-Control-Allow-Headers"]), 1) + resp.Close() + }) +} diff --git a/net/ghttp/ghttp_unit_router_hook_test.go b/net/ghttp/ghttp_unit_router_hook_test.go index 1edc604fe..c3b592b79 100644 --- a/net/ghttp/ghttp_unit_router_hook_test.go +++ b/net/ghttp/ghttp_unit_router_hook_test.go @@ -50,7 +50,6 @@ func Test_Router_Hook_Fuzzy_Router(t *testing.T) { pattern1 := "/:name/info" s.BindHookHandlerByMap(pattern1, map[string]ghttp.HandlerFunc{ ghttp.HOOK_BEFORE_SERVE: func(r *ghttp.Request) { - fmt.Println("called") r.SetParam("uid", i) i++ }, @@ -59,17 +58,17 @@ func Test_Router_Hook_Fuzzy_Router(t *testing.T) { r.Response.Write(r.Get("uid")) }) - //pattern2 := "/{object}/list/{page}.java" - //s.BindHookHandlerByMap(pattern2, map[string]ghttp.HandlerFunc{ - // ghttp.HOOK_BEFORE_OUTPUT: func(r *ghttp.Request) { - // r.Response.SetBuffer([]byte( - // fmt.Sprint(r.Get("object"), "&", r.Get("page"), "&", i), - // )) - // }, - //}) - //s.BindHandler(pattern2, func(r *ghttp.Request) { - // r.Response.Write(r.Router.Uri) - //}) + pattern2 := "/{object}/list/{page}.java" + s.BindHookHandlerByMap(pattern2, map[string]ghttp.HandlerFunc{ + ghttp.HOOK_BEFORE_OUTPUT: func(r *ghttp.Request) { + r.Response.SetBuffer([]byte( + fmt.Sprint(r.Get("object"), "&", r.Get("page"), "&", i), + )) + }, + }) + s.BindHandler(pattern2, func(r *ghttp.Request) { + r.Response.Write(r.Router.Uri) + }) s.SetPort(p) //s.SetDumpRouterMap(false) s.Start() From f68b66e60684673c23b4697c0935a90f0e5bc06e Mon Sep 17 00:00:00 2001 From: John Date: Wed, 4 Mar 2020 23:32:27 +0800 Subject: [PATCH 03/26] update comment for ghttp.Server --- net/ghttp/ghttp_server_router.go | 67 +++++++++++++++++--------------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/net/ghttp/ghttp_server_router.go b/net/ghttp/ghttp_server_router.go index 29426863c..343d060b7 100644 --- a/net/ghttp/ghttp_server_router.go +++ b/net/ghttp/ghttp_server_router.go @@ -93,7 +93,7 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { Method: strings.ToUpper(method), Priority: strings.Count(uri[1:], "/"), } - handler.router.RegRule, handler.router.RegNames = s.patternToRegRule(uri) + handler.router.RegRule, handler.router.RegNames = s.patternToRegular(uri) if _, ok := s.serveTree[domain]; !ok { s.serveTree[domain] = make(map[string]interface{}) @@ -195,31 +195,34 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { // Append the route. s.routesMap[routerKey] = append(s.routesMap[routerKey], routeItem) } - //gutil.Dump(s.serveTree) } -// 对比两个handlerItem的优先级,需要非常注意的是,注意新老对比项的参数先后顺序。 -// 返回值true表示newItem优先级比oldItem高,会被添加链表中oldRouter的前面;否则后面。 -// 优先级比较规则: -// 1、中间件优先级最高,按照添加顺序优先级执行; -// 2、其他路由注册类型,层级越深优先级越高(对比/数量); -// 3、模糊规则优先级:{xxx} > :xxx > *xxx; +// compareRouterPriority compares the priority between and . It returns true +// if 's priority is higher than , else it returns false. The higher priority +// item will be insert into the router list before the other one. +// +// Comparison rules: +// 1. The middleware has the most high priority. +// 2. URI: The deeper the higher (simply check the count of char '/' in the URI). +// 3. Route type: {xxx} > :xxx > *xxx. func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerItem) bool { - // 中间件优先级最高,按照添加顺序优先级执行 + // If they're all type of middleware, the priority is according their registered sequence. if newItem.itemType == gHANDLER_TYPE_MIDDLEWARE && oldItem.itemType == gHANDLER_TYPE_MIDDLEWARE { return false } + // The middleware has the most high priority. if newItem.itemType == gHANDLER_TYPE_MIDDLEWARE && oldItem.itemType != gHANDLER_TYPE_MIDDLEWARE { return true } - // 优先比较层级,层级越深优先级越高 + // URI: The deeper the higher (simply check the count of char '/' in the URI). if newItem.router.Priority > oldItem.router.Priority { return true } if newItem.router.Priority < oldItem.router.Priority { return false } - // 精准匹配比模糊匹配规则优先级高,例如:/name/act 比 /{name}/:act 优先级高 + // Route type: {xxx} > :xxx > *xxx. + // Eg: /name/act > /{name}/:act var fuzzyCountFieldNew, fuzzyCountFieldOld int var fuzzyCountNameNew, fuzzyCountNameOld int var fuzzyCountAnyNew, fuzzyCountAnyOld int @@ -253,16 +256,16 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte return false } - /** 如果模糊规则数量相等,那么执行分别的数量判断 **/ + // If the counts of their fuzzy rules equal. - // 例如:/name/{act} 比 /name/:act 优先级高 + // Eg: /name/{act} > /name/:act if fuzzyCountFieldNew > fuzzyCountFieldOld { return true } if fuzzyCountFieldNew < fuzzyCountFieldOld { return false } - // 例如: /name/:act 比 /name/*act 优先级高 + // Eg: /name/:act > /name/*act if fuzzyCountNameNew > fuzzyCountNameOld { return true } @@ -270,9 +273,10 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte return false } - /** 比较路由规则长度,越长的规则优先级越高,模糊/命名规则不算长度 **/ + // It then compares the length of their URI, + // but the fuzzy and named parts of the URI are not calculated to the result. - // 例如:/admin-goods-{page} 比 /admin-{page} 优先级高 + // Eg: /admin-goods-{page} > /admin-{page} var uriNew, uriOld string uriNew, _ = gregex.ReplaceString(`\{[^/]+\}`, "", newItem.router.Uri) uriNew, _ = gregex.ReplaceString(`:[^/]+`, "", uriNew) @@ -287,9 +291,8 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte return false } - /* 模糊规则数量相等,后续不用再判断*规则的数量比较了 */ - - // 比较HTTP METHOD,更精准的优先级更高 + // It then compares the accuracy of their http method, + // the more accurate the more priority. if newItem.router.Method != gDEFAULT_METHOD { return true } @@ -297,23 +300,25 @@ func (s *Server) compareRouterPriority(newItem *handlerItem, oldItem *handlerIte return true } - // 如果是服务路由,那么新的规则比旧的规则优先级高(路由覆盖) + // If they have different router type, + // the new router item has more priority than the other one. if newItem.itemType == gHANDLER_TYPE_HANDLER || newItem.itemType == gHANDLER_TYPE_OBJECT || newItem.itemType == gHANDLER_TYPE_CONTROLLER { return true } - // 如果是其他路由(HOOK/中间件),那么新的规则比旧的规则优先级低,使得注册相同路由则顺序执行 + // Other situations, like HOOK items, + // the old router item has more priority than the other one. return false } -// 将pattern(不带method和domain)解析成正则表达式匹配以及对应的query字符串 -func (s *Server) patternToRegRule(rule string) (regrule string, names []string) { +// patternToRegular converts route rule to according regular expression. +func (s *Server) patternToRegular(rule string) (regular string, names []string) { if len(rule) < 2 { return rule, nil } - regrule = "^" + regular = "^" array := strings.Split(rule[1:], "/") for _, v := range array { if len(v) == 0 { @@ -322,17 +327,17 @@ func (s *Server) patternToRegRule(rule string) (regrule string, names []string) switch v[0] { case ':': if len(v) > 1 { - regrule += `/([^/]+)` + regular += `/([^/]+)` names = append(names, v[1:]) } else { - regrule += `/[^/]+` + regular += `/[^/]+` } case '*': if len(v) > 1 { - regrule += `/{0,1}(.*)` + regular += `/{0,1}(.*)` names = append(names, v[1:]) } else { - regrule += `/{0,1}.*` + regular += `/{0,1}.*` } default: // Special chars replacement. @@ -346,12 +351,12 @@ func (s *Server) patternToRegRule(rule string) (regrule string, names []string) return `([^/]+)` }) if strings.EqualFold(s, v) { - regrule += "/" + v + regular += "/" + v } else { - regrule += "/" + s + regular += "/" + s } } } - regrule += `$` + regular += `$` return } From 70722444201e0ed335280662baa2e69fb60e5847 Mon Sep 17 00:00:00 2001 From: John Date: Thu, 5 Mar 2020 16:08:55 +0800 Subject: [PATCH 04/26] improve comment for ghttp.Server --- net/ghttp/ghttp_response_cors.go | 9 ++-- net/ghttp/ghttp_server_router.go | 10 ++-- net/ghttp/ghttp_server_router_serve.go | 60 ++++++++++++++---------- net/ghttp/ghttp_server_service_object.go | 14 ++++-- 4 files changed, 54 insertions(+), 39 deletions(-) diff --git a/net/ghttp/ghttp_response_cors.go b/net/ghttp/ghttp_response_cors.go index 59fc35d5d..824ed5500 100644 --- a/net/ghttp/ghttp_response_cors.go +++ b/net/ghttp/ghttp_response_cors.go @@ -95,23 +95,24 @@ func (r *Response) CORS(options CORSOptions) { } // No continue service handling if it's OPTIONS request. if gstr.Equal(r.Request.Method, "OPTIONS") { - // Request method's handler searching. - // It here uses Server.routesMap attribute enhancing the searching performance. + // Request method handler searching. + // It here simply uses Server.routesMap attribute enhancing the searching performance. if method := r.Request.Header.Get("Access-Control-Request-Method"); method != "" { routerKey := "" for _, domain := range []string{gDEFAULT_DOMAIN, r.Request.GetHost()} { for _, v := range []string{gDEFAULT_METHOD, method} { - routerKey = r.Server.handlerKey("", v, r.Request.URL.Path, domain) + routerKey = r.Server.routerMapKey("", v, r.Request.URL.Path, domain) if r.Server.routesMap[routerKey] != nil { if r.Status == 0 { r.Status = http.StatusOK } + // No continue serving. r.Request.ExitAll() } } } } - // Cannot find the request handler. + // Cannot find the request serving handler, it then responses 404. if r.Status == 0 { r.Status = http.StatusNotFound } diff --git a/net/ghttp/ghttp_server_router.go b/net/ghttp/ghttp_server_router.go index 343d060b7..e27767893 100644 --- a/net/ghttp/ghttp_server_router.go +++ b/net/ghttp/ghttp_server_router.go @@ -28,8 +28,10 @@ var ( handlerIdGenerator = gtype.NewInt() ) -// handlerKey creates and returns an unique router key for given parameters. -func (s *Server) handlerKey(hook, method, path, domain string) string { +// routerMapKey creates and returns an unique router key for given parameters. +// This key is used for Server.routerMap attribute, which is mainly for checks for +// repeated router registering. +func (s *Server) routerMapKey(hook, method, path, domain string) string { return hook + "%" + s.serveHandlerKey(method, path, domain) } @@ -76,7 +78,7 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { } // Repeated router checks, this feature can be disabled by server configuration. - routerKey := s.handlerKey(handler.hookName, method, uri, domain) + routerKey := s.routerMapKey(handler.hookName, method, uri, domain) if !s.config.RouteOverWrite { switch handler.itemType { case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER: @@ -98,7 +100,7 @@ func (s *Server) setHandler(pattern string, handler *handlerItem) { if _, ok := s.serveTree[domain]; !ok { s.serveTree[domain] = make(map[string]interface{}) } - // List array, very important for router register. + // List array, very important for router registering. // There may be multiple lists adding into this array when searching from root to leaf. lists := make([]*glist.List, 0) array := ([]string)(nil) diff --git a/net/ghttp/ghttp_server_router_serve.go b/net/ghttp/ghttp_server_router_serve.go index 0fb563b8a..0ad031eba 100644 --- a/net/ghttp/ghttp_server_router_serve.go +++ b/net/ghttp/ghttp_server_router_serve.go @@ -15,13 +15,24 @@ import ( "github.com/gogf/gf/text/gregex" ) -// handlerCacheItem is an item for router searching cache. +// handlerCacheItem is an item just for internal router searching cache. type handlerCacheItem struct { parsedItems []*handlerParsedItem hasHook bool hasServe bool } +// serveHandlerKey creates and returns a handler key for router. +func (s *Server) serveHandlerKey(method, path, domain string) string { + if len(domain) > 0 { + domain = "@" + domain + } + if method == "" { + return path + strings.ToLower(domain) + } + return strings.ToUpper(method) + ":" + path + strings.ToLower(domain) +} + // getHandlersWithCache searches the router item with cache feature for given request. func (s *Server) getHandlersWithCache(r *Request) (parsedItems []*handlerParsedItem, hasHook, hasServe bool) { method := r.Method @@ -76,10 +87,12 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han if part == "" { continue } + // Add all list of each node to the list array. if v, ok := p.(map[string]interface{})["*list"]; ok { lists = append(lists, v.(*glist.List)) } if v, ok := p.(map[string]interface{})[part]; ok { + // Loop to the next node by certain key name. p = v if i == len(array)-1 { if v, ok := p.(map[string]interface{})["*list"]; ok { @@ -87,33 +100,36 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han break } } - } else { - if v, ok := p.(map[string]interface{})["*fuzz"]; ok { - p = v - } + } else if v, ok := p.(map[string]interface{})["*fuzz"]; ok { + // Loop to the next node by fuzzy node item. + p = v } - // 如果是叶子节点,同时判断当前层级的"*fuzz"键名,解决例如:/user/*action 匹配 /user 的规则 if i == len(array)-1 { + // It here also checks the fuzzy item, + // for rule case like: "/user/*action" matches to "/user". if v, ok := p.(map[string]interface{})["*fuzz"]; ok { p = v } + // The leaf must have a list item. It adds the list to the list array. if v, ok := p.(map[string]interface{})["*list"]; ok { lists = append(lists, v.(*glist.List)) } } } - // 多层链表遍历检索,从数组末尾的链表开始遍历,末尾的深度高优先级也高 + // OK, let's loop the result list array, adding the handler item to the result handler result array. + // As the tail of the list array has the most priority, it iterates the list array from its tail to head. for i := len(lists) - 1; i >= 0; i-- { for e := lists[i].Front(); e != nil; e = e.Next() { item := e.Value.(*handlerItem) - // 主要是用于路由注册函数的重复添加判断(特别是中间件和钩子函数) + // Filter repeated handler item, especially the middleware and hook handlers. + // It is necessary, do not remove this checks logic unless you really know how it is necessary. if _, ok := repeatHandlerCheckMap[item.itemId]; ok { continue } else { repeatHandlerCheckMap[item.itemId] = struct{}{} } - // 服务路由函数只能添加一次,将重复判断放在这里提高检索效率 + // Serving handler can only be added to the handler array just once. if hasServe { switch item.itemType { case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER: @@ -121,26 +137,29 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han } } if item.router.Method == gDEFAULT_METHOD || item.router.Method == method { - // 注意当不带任何动态路由规则时,len(match) == 1 + // Note the rule having no fuzzy rules: len(match) == 1 if match, err := gregex.MatchString(item.router.RegRule, path); err == nil && len(match) > 0 { parsedItem := &handlerParsedItem{item, nil} - // 如果需要路由规则中带有URI名称匹配,那么需要重新正则解析URL + // If the rule contains fuzzy names, + // it needs paring the URL to retrieve the values for the names. if len(item.router.RegNames) > 0 { if len(match) > len(item.router.RegNames) { parsedItem.values = make(map[string]string) - // 如果存在存在同名路由参数名称,那么执行覆盖 + // It there repeated names, it just overwrites the same one. for i, name := range item.router.RegNames { parsedItem.values[name] = match[i+1] } } } switch item.itemType { - // 服务路由函数只能添加一次 + // The serving handler can be only added just once. case gHANDLER_TYPE_HANDLER, gHANDLER_TYPE_OBJECT, gHANDLER_TYPE_CONTROLLER: hasServe = true parsedItemList.PushBack(parsedItem) - // 中间件需要排序在链表中服务函数之前,并且多个中间件按照顺序添加以便于后续执行 + // The middleware is inserted before the serving handler. + // If there're multiple middlewares, they're inserted into the result list by their registering order. + // The middlewares are also executed by their registering order. case gHANDLER_TYPE_MIDDLEWARE: if lastMiddlewareElem == nil { lastMiddlewareElem = parsedItemList.PushFront(parsedItem) @@ -148,7 +167,7 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han lastMiddlewareElem = parsedItemList.InsertAfter(lastMiddlewareElem, parsedItem) } - // 钩子函数存在性判断 + // HOOK handler, just push it back to the list. case gHANDLER_TYPE_HOOK: hasHook = true parsedItemList.PushBack(parsedItem) @@ -210,14 +229,3 @@ func (item *handlerItem) MarshalJSON() ([]byte, error) { func (item *handlerParsedItem) MarshalJSON() ([]byte, error) { return json.Marshal(item.handler) } - -// serveHandlerKey creates and returns a cache key for router. -func (s *Server) serveHandlerKey(method, path, domain string) string { - if len(domain) > 0 { - domain = "@" + domain - } - if method == "" { - return path + strings.ToLower(domain) - } - return strings.ToUpper(method) + ":" + path + strings.ToLower(domain) -} diff --git a/net/ghttp/ghttp_server_service_object.go b/net/ghttp/ghttp_server_service_object.go index a83c2b377..4fd135822 100644 --- a/net/ghttp/ghttp_server_service_object.go +++ b/net/ghttp/ghttp_server_service_object.go @@ -39,7 +39,7 @@ func (s *Server) BindObjectRest(pattern string, object interface{}) { } func (s *Server) doBindObject(pattern string, object interface{}, method string, middleware []HandlerFunc) { - // Convert input method to map for convenience and high performance searching. + // Convert input method to map for convenience and high performance searching purpose. var methodMap map[string]bool if len(method) > 0 { methodMap = make(map[string]bool) @@ -86,12 +86,16 @@ func (s *Server) doBindObject(pattern string, object interface{}, method string, if !ok { if len(methodMap) > 0 { // 指定的方法名称注册,那么需要使用错误提示 - s.Logger().Errorf(`invalid route method: %s.%s.%s defined as "%s", but "func(*ghttp.Request)" is required for object registry`, - pkgPath, objName, methodName, v.Method(i).Type().String()) + s.Logger().Errorf( + `invalid route method: %s.%s.%s defined as "%s", but "func(*ghttp.Request)" is required for object registry`, + pkgPath, objName, methodName, v.Method(i).Type().String(), + ) } else { // 否则只是Debug提示 - s.Logger().Debugf(`ignore route method: %s.%s.%s defined as "%s", no match "func(*ghttp.Request)"`, - pkgPath, objName, methodName, v.Method(i).Type().String()) + s.Logger().Debugf( + `ignore route method: %s.%s.%s defined as "%s", no match "func(*ghttp.Request)"`, + pkgPath, objName, methodName, v.Method(i).Type().String(), + ) } continue } From a161b44cc774ea63ff3723bb12207af03a9619cc Mon Sep 17 00:00:00 2001 From: John Date: Thu, 5 Mar 2020 18:07:07 +0800 Subject: [PATCH 05/26] improve package gpage --- util/gpage/gpage.go | 173 ++++++++++++++-------------------- util/gpage/gpage_unit_test.go | 116 +++++++++++++++++++++++ 2 files changed, 186 insertions(+), 103 deletions(-) create mode 100644 util/gpage/gpage_unit_test.go diff --git a/util/gpage/gpage.go b/util/gpage/gpage.go index 01b4e7459..e448ce080 100644 --- a/util/gpage/gpage.go +++ b/util/gpage/gpage.go @@ -9,125 +9,93 @@ package gpage import ( "fmt" - "math" - "net/url" - "strings" - - "github.com/gogf/gf/net/ghttp" - "github.com/gogf/gf/text/gregex" "github.com/gogf/gf/text/gstr" "github.com/gogf/gf/util/gconv" + "math" ) // Page is the pagination implementer. +// All the attributes are public, you can change them when necessary. type Page struct { - UrlTemplate string // Custom url template for page url producing. TotalSize int // Total size. TotalPage int // Total page, which is automatically calculated. CurrentPage int // Current page number >= 1. - PageName string // Page variable name. It's "page" in default. + UrlTemplate string // Custom url template for page url producing. + LinkStyle string // CSS style name for HTML link tag . + SpanStyle string // CSS style name for HTML span tag , which is used for first, current and last page tag. + SelectStyle string // CSS style name for HTML select tag ` + barContent := fmt.Sprintf(`" - return ret + barContent += "" + return barContent } -// 预定义的分页显示风格内容 +// GetContent returns the page content for predefined mode. +// These predefined contents are mainly for chinese localization purpose. You can defines your own +// page function retrieving the page content according to the implementation of this function. func (p *Page) GetContent(mode int) string { switch mode { case 1: @@ -200,7 +171,7 @@ func (p *Page) GetContent(mode int) string { p.LastPageTag = "尾页" pageStr := p.FirstPage() pageStr += p.PrevPage() - pageStr += p.PageBar("current") + pageStr += p.PageBar() pageStr += p.NextPage() pageStr += p.LastPage() pageStr += fmt.Sprintf( @@ -218,7 +189,7 @@ func (p *Page) GetContent(mode int) string { p.LastPageTag = "尾页" pageStr := p.FirstPage() pageStr += p.PrevPage() - pageStr += p.PageBar("current") + pageStr += p.PageBar() pageStr += p.NextPage() pageStr += p.LastPage() return pageStr @@ -226,28 +197,24 @@ func (p *Page) GetContent(mode int) string { return "" } -// 为指定的页面返回地址值 -func (p *Page) GetUrl(pageNo int) string { - pattern := fmt.Sprintf(`(:%s|\*%s|\.%s)`, p.PageName, p.PageName, p.PageName) - result, _ := gregex.ReplaceString(pattern, pageNo, p.UrlTemplate) - url.Path = gstr.Replace(p.UrlTemplate, "{.page}", gconv.String(pageNo)) - return url.String() - } - - values := p.Url.Query() - values.Set(p.PageName, gconv.String(pageNo)) - url.RawQuery = values.Encode() - return url.String() +// GetUrl parses the UrlTemplate with given page number and returns the URL string. +// Note that the UrlTemplate attribute can be either an URL or a URI string with "{.page}" +// place holder specifying the page number position. +func (p *Page) GetUrl(page int) string { + return gstr.Replace(p.UrlTemplate, "{.page}", gconv.String(page)) } -// 获取链接地址 -func (p *Page) GetLink(url, text, title, style string) string { - if len(style) > 0 { - style = fmt.Sprintf(`class="%s" `, style) - } +// GetLink returns the HTML link tag content for given page number. +func (p *Page) GetLink(page int, text, title string) string { if len(p.AjaxActionName) > 0 { - return fmt.Sprintf(`%s`, style, p.AjaxActionName, url, text) + return fmt.Sprintf( + `%s`, + p.LinkStyle, p.AjaxActionName, p.GetUrl(page), title, text, + ) } else { - return fmt.Sprintf(`%s`, style, url, title, text) + return fmt.Sprintf( + `%s`, + p.LinkStyle, p.GetUrl(page), title, text, + ) } } diff --git a/util/gpage/gpage_unit_test.go b/util/gpage/gpage_unit_test.go new file mode 100644 index 000000000..f1a6d9196 --- /dev/null +++ b/util/gpage/gpage_unit_test.go @@ -0,0 +1,116 @@ +// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +// go test *.go -bench=".*" + +package gpage_test + +import ( + "github.com/gogf/gf/util/gpage" + "testing" + + "github.com/gogf/gf/test/gtest" +) + +func Test_New(t *testing.T) { + gtest.Case(t, func() { + page := gpage.New(9, 2, 1, `/user/list?page={.page}`) + gtest.Assert(page.TotalSize, 9) + gtest.Assert(page.TotalPage, 5) + gtest.Assert(page.CurrentPage, 1) + }) + gtest.Case(t, func() { + page := gpage.New(9, 2, 0, `/user/list?page={.page}`) + gtest.Assert(page.TotalSize, 9) + gtest.Assert(page.TotalPage, 5) + gtest.Assert(page.CurrentPage, 1) + }) +} + +func Test_Basic(t *testing.T) { + gtest.Case(t, func() { + page := gpage.New(9, 2, 1, `/user/list?page={.page}`) + gtest.Assert(page.NextPage(), `>`) + gtest.Assert(page.PrevPage(), `<`) + gtest.Assert(page.FirstPage(), `|<`) + gtest.Assert(page.LastPage(), `>|`) + gtest.Assert(page.PageBar(), `12345`) + }) + + gtest.Case(t, func() { + page := gpage.New(9, 2, 3, `/user/list?page={.page}`) + gtest.Assert(page.NextPage(), `>`) + gtest.Assert(page.PrevPage(), `<`) + gtest.Assert(page.FirstPage(), `|<`) + gtest.Assert(page.LastPage(), `>|`) + gtest.Assert(page.PageBar(), `12345`) + }) + + gtest.Case(t, func() { + page := gpage.New(9, 2, 5, `/user/list?page={.page}`) + gtest.Assert(page.NextPage(), `>`) + gtest.Assert(page.PrevPage(), `<`) + gtest.Assert(page.FirstPage(), `|<`) + gtest.Assert(page.LastPage(), `>|`) + gtest.Assert(page.PageBar(), `12345`) + }) +} + +func Test_CustomTag(t *testing.T) { + gtest.Case(t, func() { + page := gpage.New(5, 1, 2, `/user/list/{.page}`) + page.PrevPageTag = "《" + page.NextPageTag = "》" + page.FirstPageTag = "|《" + page.LastPageTag = "》|" + page.PrevBarTag = "《《" + page.NextBarTag = "》》" + gtest.Assert(page.NextPage(), ``) + gtest.Assert(page.PrevPage(), ``) + gtest.Assert(page.FirstPage(), `|《`) + gtest.Assert(page.LastPage(), `》|`) + gtest.Assert(page.PageBar(), `12345`) + }) +} + +func Test_CustomStyle(t *testing.T) { + gtest.Case(t, func() { + page := gpage.New(5, 1, 2, `/user/list/{.page}`) + page.LinkStyle = "MyPageLink" + page.SpanStyle = "MyPageSpan" + page.SelectStyle = "MyPageSelect" + gtest.Assert(page.NextPage(), `>`) + gtest.Assert(page.PrevPage(), `<`) + gtest.Assert(page.FirstPage(), `|<`) + gtest.Assert(page.LastPage(), `>|`) + gtest.Assert(page.PageBar(), `12345`) + gtest.Assert(page.SelectBar(), ``) + }) +} + +func Test_Ajax(t *testing.T) { + gtest.Case(t, func() { + page := gpage.New(5, 1, 2, `/user/list/{.page}`) + page.AjaxActionName = "LoadPage" + gtest.Assert(page.NextPage(), `>`) + gtest.Assert(page.PrevPage(), `<`) + gtest.Assert(page.FirstPage(), `|<`) + gtest.Assert(page.LastPage(), `>|`) + gtest.Assert(page.PageBar(), `12345`) + }) +} + +func Test_PredefinedContent(t *testing.T) { + gtest.Case(t, func() { + page := gpage.New(5, 1, 2, `/user/list/{.page}`) + page.AjaxActionName = "LoadPage" + gtest.Assert(page.GetContent(1), `上一页 2 下一页`) + gtest.Assert(page.GetContent(2), `首页<<上一页[第2页]下一页>>尾页页`) + gtest.Assert(page.GetContent(3), `首页上一页12345下一页尾页当前页2/5 共5条`) + gtest.Assert(page.GetContent(4), `首页上一页12345下一页尾页`) + gtest.Assert(page.GetContent(5), ``) + }) +} From 93d0760898eb6da6343622305f3e642b1ebb8c32 Mon Sep 17 00:00:00 2001 From: John Date: Fri, 6 Mar 2020 11:01:03 +0800 Subject: [PATCH 06/26] add GetPage function for ghttp.Request --- .example/util/gpage/gpage.go | 3 +- .example/util/gpage/gpage_ajax.go | 15 ++++-- .example/util/gpage/gpage_custom1.go | 2 +- .example/util/gpage/gpage_custom2.go | 4 +- .example/util/gpage/gpage_static1.go | 3 +- .example/util/gpage/gpage_static2.go | 3 +- .example/util/gpage/gpage_template.go | 5 +- container/garray/garray_z_example_test.go | 4 +- container/gqueue/gqueue.go | 9 ++-- database/gdb/gdb_mssql.go | 2 +- database/gdb/gdb_oracle.go | 4 +- net/ghttp/ghttp_request_param_page.go | 64 +++++++++++++++++++++++ net/ghttp/ghttp_unit_param_page_test.go | 48 +++++++++++++++++ os/gtimer/gtimer.go | 22 ++++---- util/gpage/gpage.go | 7 ++- 15 files changed, 159 insertions(+), 36 deletions(-) create mode 100644 net/ghttp/ghttp_request_param_page.go create mode 100644 net/ghttp/ghttp_unit_param_page_test.go diff --git a/.example/util/gpage/gpage.go b/.example/util/gpage/gpage.go index ab75e32dc..9b6071f0d 100644 --- a/.example/util/gpage/gpage.go +++ b/.example/util/gpage/gpage.go @@ -4,13 +4,12 @@ import ( "github.com/gogf/gf/frame/g" "github.com/gogf/gf/net/ghttp" "github.com/gogf/gf/os/gview" - "github.com/gogf/gf/util/gpage" ) func main() { s := ghttp.GetServer() s.BindHandler("/page/demo", func(r *ghttp.Request) { - page := gpage.New(100, 10, r.Get("page"), r.URL.String()) + page := r.GetPage(100, 10) buffer, _ := gview.ParseContent(` diff --git a/.example/util/gpage/gpage_ajax.go b/.example/util/gpage/gpage_ajax.go index 299715ceb..84c4facfa 100644 --- a/.example/util/gpage/gpage_ajax.go +++ b/.example/util/gpage/gpage_ajax.go @@ -4,14 +4,13 @@ import ( "github.com/gogf/gf/frame/g" "github.com/gogf/gf/net/ghttp" "github.com/gogf/gf/os/gview" - "github.com/gogf/gf/util/gpage" ) func main() { s := ghttp.GetServer() s.BindHandler("/page/ajax", func(r *ghttp.Request) { - page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router) - page.EnableAjax("DoAjax") + page := r.GetPage(100, 10) + page.AjaxActionName = "DoAjax" buffer, _ := gview.ParseContent(` @@ -29,11 +28,17 @@ func main() { -
{{.page}}
+
{{.page1}}
+
{{.page2}}
+
{{.page3}}
+
{{.page4}}
`, g.Map{ - "page": page.GetContent(1), + "page1": page.GetContent(1), + "page2": page.GetContent(2), + "page3": page.GetContent(3), + "page4": page.GetContent(4), }) r.Response.Write(buffer) }) diff --git a/.example/util/gpage/gpage_custom1.go b/.example/util/gpage/gpage_custom1.go index 57cebaff1..9571c36aa 100644 --- a/.example/util/gpage/gpage_custom1.go +++ b/.example/util/gpage/gpage_custom1.go @@ -23,7 +23,7 @@ func wrapContent(page *gpage.Page) string { func main() { s := ghttp.GetServer() s.BindHandler("/page/custom1/*page", func(r *ghttp.Request) { - page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router) + page := r.GetPage(100, 10) content := wrapContent(page) buffer, _ := gview.ParseContent(` diff --git a/.example/util/gpage/gpage_custom2.go b/.example/util/gpage/gpage_custom2.go index f50428075..2e84042c2 100644 --- a/.example/util/gpage/gpage_custom2.go +++ b/.example/util/gpage/gpage_custom2.go @@ -15,7 +15,7 @@ func pageContent(page *gpage.Page) string { page.LastPageTag = "LastPage" pageStr := page.FirstPage() pageStr += page.PrevPage() - pageStr += page.PageBar("current-page") + pageStr += page.PageBar() pageStr += page.NextPage() pageStr += page.LastPage() return pageStr @@ -24,7 +24,7 @@ func pageContent(page *gpage.Page) string { func main() { s := ghttp.GetServer() s.BindHandler("/page/custom2/*page", func(r *ghttp.Request) { - page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router) + page := r.GetPage(100, 10) buffer, _ := gview.ParseContent(` diff --git a/.example/util/gpage/gpage_static1.go b/.example/util/gpage/gpage_static1.go index 1d1bbf798..c62a0851a 100644 --- a/.example/util/gpage/gpage_static1.go +++ b/.example/util/gpage/gpage_static1.go @@ -4,13 +4,12 @@ import ( "github.com/gogf/gf/frame/g" "github.com/gogf/gf/net/ghttp" "github.com/gogf/gf/os/gview" - "github.com/gogf/gf/util/gpage" ) func main() { s := g.Server() s.BindHandler("/page/static/*page", func(r *ghttp.Request) { - page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router) + page := r.GetPage(100, 10) buffer, _ := gview.ParseContent(` diff --git a/.example/util/gpage/gpage_static2.go b/.example/util/gpage/gpage_static2.go index d0056fbd0..135736fc0 100644 --- a/.example/util/gpage/gpage_static2.go +++ b/.example/util/gpage/gpage_static2.go @@ -4,13 +4,12 @@ import ( "github.com/gogf/gf/frame/g" "github.com/gogf/gf/net/ghttp" "github.com/gogf/gf/os/gview" - "github.com/gogf/gf/util/gpage" ) func main() { s := g.Server() s.BindHandler("/:obj/*action/{page}.html", func(r *ghttp.Request) { - page := gpage.New(100, 10, r.Get("page"), r.URL.String(), r.Router) + page := r.GetPage(100, 10) buffer, _ := gview.ParseContent(` diff --git a/.example/util/gpage/gpage_template.go b/.example/util/gpage/gpage_template.go index c4e25e3d9..a06c27e4a 100644 --- a/.example/util/gpage/gpage_template.go +++ b/.example/util/gpage/gpage_template.go @@ -4,14 +4,13 @@ import ( "github.com/gogf/gf/frame/g" "github.com/gogf/gf/net/ghttp" "github.com/gogf/gf/os/gview" - "github.com/gogf/gf/util/gpage" ) func main() { s := g.Server() s.BindHandler("/page/template/{page}.html", func(r *ghttp.Request) { - page := gpage.New(100, 10, r.Get("page"), r.URL.String()) - page.SetUrlTemplate("/order/list/{.page}.html") + page := r.GetPage(100, 10) + page.UrlTemplate = "/order/list/{.page}.html" buffer, _ := gview.ParseContent(` diff --git a/container/garray/garray_z_example_test.go b/container/garray/garray_z_example_test.go index 7ab4da733..eace2f821 100644 --- a/container/garray/garray_z_example_test.go +++ b/container/garray/garray_z_example_test.go @@ -13,8 +13,8 @@ import ( ) func Example_basic() { - // 创建普通的数组,默认并发安全(带锁) - a := garray.New(true) + // 创建普通的数组 + a := garray.New() // 添加数据项 for i := 0; i < 10; i++ { diff --git a/container/gqueue/gqueue.go b/container/gqueue/gqueue.go index 449c15adf..1ecfc9231 100644 --- a/container/gqueue/gqueue.go +++ b/container/gqueue/gqueue.go @@ -4,7 +4,7 @@ // If a copy of the MIT was not distributed with this file, // You can obtain one at https://github.com/gogf/gf. -// Package gqueue provides a dynamic/static concurrent-safe queue. +// Package gqueue provides dynamic/static concurrent-safe queue. // // Features: // @@ -25,6 +25,7 @@ import ( "github.com/gogf/gf/container/gtype" ) +// Queue is a concurrent-safe queue built on doubly linked list and channel. type Queue struct { limit int // Limit for queue size. list *glist.List // Underlying list structure for data maintaining. @@ -54,14 +55,14 @@ func New(limit ...int) *Queue { q.list = glist.New(true) q.events = make(chan struct{}, math.MaxInt32) q.C = make(chan interface{}, gDEFAULT_QUEUE_SIZE) - go q.startAsyncLoop() + go q.asyncLoopFromListToChannel() } return q } -// startAsyncLoop starts an asynchronous goroutine, +// asyncLoopFromListToChannel starts an asynchronous goroutine, // which handles the data synchronization from list to channel . -func (q *Queue) startAsyncLoop() { +func (q *Queue) asyncLoopFromListToChannel() { defer func() { if q.closed.Val() { _ = recover() diff --git a/database/gdb/gdb_mssql.go b/database/gdb/gdb_mssql.go index af183389c..f3d27bb25 100644 --- a/database/gdb/gdb_mssql.go +++ b/database/gdb/gdb_mssql.go @@ -176,7 +176,7 @@ func (db *dbMssql) Tables(schema ...string) (tables []string, err error) { } for _, m := range result { for _, v := range m { - tables = append(tables, strings.ToLower(v.String())) + tables = append(tables, v.String()) } } return diff --git a/database/gdb/gdb_oracle.go b/database/gdb/gdb_oracle.go index 751f46f72..686dce1fa 100644 --- a/database/gdb/gdb_oracle.go +++ b/database/gdb/gdb_oracle.go @@ -123,16 +123,16 @@ func (db *dbOracle) parseSql(sql string) string { } // Tables retrieves and returns the tables of current schema. +// Note that it ignores the parameter in oracle database, as it is not necessary. func (db *dbOracle) Tables(schema ...string) (tables []string, err error) { var result Result - result, err = db.doGetAll(nil, "SELECT TABLE_NAME FROM USER_TABLES ORDER BY TABLE_NAME") if err != nil { return } for _, m := range result { for _, v := range m { - tables = append(tables, strings.ToLower(v.String())) + tables = append(tables, v.String()) } } return diff --git a/net/ghttp/ghttp_request_param_page.go b/net/ghttp/ghttp_request_param_page.go new file mode 100644 index 000000000..f6cb94fa7 --- /dev/null +++ b/net/ghttp/ghttp_request_param_page.go @@ -0,0 +1,64 @@ +// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package ghttp + +import ( + "fmt" + "github.com/gogf/gf/text/gregex" + "github.com/gogf/gf/text/gstr" + "github.com/gogf/gf/util/gpage" +) + +// GetPage creates and returns the pagination object for given and . +// NOTE THAT the page parameter name from client is constantly defined as gpage.PAGE_NAME +// for simplification and convenience. +func (r *Request) GetPage(totalSize, pageSize int) *gpage.Page { + // It must has Router object attribute. + if r.Router == nil { + panic("Router object not found") + } + url := *r.URL + urlTemplate := url.Path + uriHasPageName := false + // Check the page variable in the URI. + if len(r.Router.RegNames) > 0 { + for _, name := range r.Router.RegNames { + if name == gpage.PAGE_NAME { + uriHasPageName = true + break + } + } + if uriHasPageName { + if match, err := gregex.MatchString(r.Router.RegRule, url.Path); err == nil && len(match) > 0 { + if len(match) > len(r.Router.RegNames) { + urlTemplate = r.Router.Uri + for i, name := range r.Router.RegNames { + rule := fmt.Sprintf(`[:\*]%s|\{%s\}`, name, name) + if name == gpage.PAGE_NAME { + urlTemplate, _ = gregex.ReplaceString(rule, gpage.PAGE_PLACE_HOLDER, urlTemplate) + } else { + urlTemplate, _ = gregex.ReplaceString(rule, match[i+1], urlTemplate) + } + } + } + } + } + } + // Check the page variable in the query string. + if !uriHasPageName { + values := url.Query() + values.Set(gpage.PAGE_NAME, gpage.PAGE_PLACE_HOLDER) + url.RawQuery = values.Encode() + // Replace the encodes "{.page}" to "{.page}". + url.RawQuery = gstr.Replace(url.RawQuery, "%7B.page%7D", "{.page}") + } + if url.RawQuery != "" { + urlTemplate += "?" + url.RawQuery + } + + return gpage.New(totalSize, pageSize, r.GetInt(gpage.PAGE_NAME), urlTemplate) +} diff --git a/net/ghttp/ghttp_unit_param_page_test.go b/net/ghttp/ghttp_unit_param_page_test.go new file mode 100644 index 000000000..865781b08 --- /dev/null +++ b/net/ghttp/ghttp_unit_param_page_test.go @@ -0,0 +1,48 @@ +// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package ghttp_test + +import ( + "fmt" + "testing" + "time" + + "github.com/gogf/gf/frame/g" + "github.com/gogf/gf/net/ghttp" + "github.com/gogf/gf/test/gtest" +) + +func Test_Params_Page(t *testing.T) { + p := ports.PopRand() + s := g.Server(p) + s.Group("/", func(group *ghttp.RouterGroup) { + group.GET("/list", func(r *ghttp.Request) { + page := r.GetPage(5, 2) + r.Response.Write(page.GetContent(4)) + }) + group.GET("/list/{page}.html", func(r *ghttp.Request) { + page := r.GetPage(5, 2) + r.Response.Write(page.GetContent(4)) + }) + }) + s.SetPort(p) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + + time.Sleep(100 * time.Millisecond) + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + + gtest.Assert(client.GetContent("/list"), `首页上一页123下一页尾页`) + gtest.Assert(client.GetContent("/list?page=3"), `首页上一页123下一页尾页`) + + gtest.Assert(client.GetContent("/list/1.html"), `首页上一页123下一页尾页`) + gtest.Assert(client.GetContent("/list/3.html"), `首页上一页123下一页尾页`) + }) +} diff --git a/os/gtimer/gtimer.go b/os/gtimer/gtimer.go index d7b9c302c..1147ae078 100644 --- a/os/gtimer/gtimer.go +++ b/os/gtimer/gtimer.go @@ -4,16 +4,18 @@ // If a copy of the MIT was not distributed with this file, // You can obtain one at https://github.com/gogf/gf. -// Package gtimer implements Hierarchical Timing Wheel for interval/delayed jobs running and management. +// Package gtimer implements Hierarchical Timing Wheel for interval/delayed jobs +// running and management. // -// This package is designed for management for millions of timing jobs. -// The differences between gtime and gcron are as follows: -// 1. gcron is implemented based on gtimer. +// This package is designed for management for millions of timing jobs. The differences +// between gtimer and gcron are as follows: +// 1. package gcron is implemented based on package gtimer. // 2. gtimer is designed for high performance and for millions of timing jobs. -// 3. gcron supports pattern grammar like linux crontab. -// 4. gtimer's benchmark OP is measured in nanoseconds, and gcron's benchmark OP is measured in microseconds. +// 3. gcron supports configuration pattern grammar like linux crontab, which is more manually readable. +// 4. gtimer's benchmark OP is measured in nanoseconds, and gcron's benchmark OP is measured +// in microseconds. // -// Note the common delay of the timer: https://github.com/golang/go/issues/14410 +// ALSO VERY NOTE the common delay of the timer: https://github.com/golang/go/issues/14410 package gtimer import ( @@ -119,8 +121,10 @@ func DelayAddTimes(delay time.Duration, interval time.Duration, times int, job J defaultTimer.DelayAddTimes(delay, interval, times, job) } -// Exit is used in timing job, which exits and marks it closed from timer. -// The timing job will be removed from timer later. +// Exit is used in timing job internally, which exits and marks it closed from timer. +// The timing job will be automatically removed from timer later. It uses "panic-recover" +// mechanism internally implementing this feature, which is designed for simplification +// and convenience. func Exit() { panic(gPANIC_EXIT) } diff --git a/util/gpage/gpage.go b/util/gpage/gpage.go index e448ce080..4b240247d 100644 --- a/util/gpage/gpage.go +++ b/util/gpage/gpage.go @@ -34,6 +34,11 @@ type Page struct { AjaxActionName string // Ajax function name. Ajax is enabled if this attribute is not empty. } +const ( + PAGE_NAME = "page" // PAGE_NAME defines the default page name. + PAGE_PLACE_HOLDER = "{.page}" // PAGE_PLACE_HOLDER defines the place holder for the url template. +) + // New creates and returns a pagination manager. // Note that the parameter specifies the URL producing template, like: // /user/list/{.page}, /user/list/{.page}.html, /user/list?page={.page}&type=1, etc. @@ -201,7 +206,7 @@ func (p *Page) GetContent(mode int) string { // Note that the UrlTemplate attribute can be either an URL or a URI string with "{.page}" // place holder specifying the page number position. func (p *Page) GetUrl(page int) string { - return gstr.Replace(p.UrlTemplate, "{.page}", gconv.String(page)) + return gstr.Replace(p.UrlTemplate, PAGE_PLACE_HOLDER, gconv.String(page)) } // GetLink returns the HTML link tag content for given page number. From 31f19b0eee8c2c91c8f317cc17d46a452a5143e2 Mon Sep 17 00:00:00 2001 From: John Date: Fri, 6 Mar 2020 15:38:32 +0800 Subject: [PATCH 07/26] improve package gcompress --- .../gcompress/gcompress_z_unit_gzip_test.go | 39 +++++ .../gcompress/gcompress_z_unit_zip_test.go | 153 ++++++++++++++++++ ..._test.go => gcompress_z_unit_zlib_test.go} | 28 +--- encoding/gcompress/gcompress_zip_file.go | 35 +++- encoding/gcompress/testdata/zip/path1/1.txt | 1 + encoding/gcompress/testdata/zip/path2/2.txt | 1 + 6 files changed, 222 insertions(+), 35 deletions(-) create mode 100644 encoding/gcompress/gcompress_z_unit_gzip_test.go create mode 100644 encoding/gcompress/gcompress_z_unit_zip_test.go rename encoding/gcompress/{gcompress_test.go => gcompress_z_unit_zlib_test.go} (61%) create mode 100644 encoding/gcompress/testdata/zip/path1/1.txt create mode 100644 encoding/gcompress/testdata/zip/path2/2.txt diff --git a/encoding/gcompress/gcompress_z_unit_gzip_test.go b/encoding/gcompress/gcompress_z_unit_gzip_test.go new file mode 100644 index 000000000..712261a03 --- /dev/null +++ b/encoding/gcompress/gcompress_z_unit_gzip_test.go @@ -0,0 +1,39 @@ +// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gcompress_test + +import ( + "testing" + + "github.com/gogf/gf/encoding/gcompress" + "github.com/gogf/gf/test/gtest" +) + +func Test_Gzip_UnGzip(t *testing.T) { + src := "Hello World!!" + + gzip := []byte{ + 0x1f, 0x8b, 0x08, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0xff, + 0xf2, 0x48, 0xcd, 0xc9, 0xc9, + 0x57, 0x08, 0xcf, 0x2f, 0xca, + 0x49, 0x51, 0x54, 0x04, 0x04, + 0x00, 0x00, 0xff, 0xff, 0x9d, + 0x24, 0xa8, 0xd1, 0x0d, 0x00, + 0x00, 0x00, + } + + arr := []byte(src) + data, _ := gcompress.Gzip(arr) + gtest.Assert(data, gzip) + + data, _ = gcompress.UnGzip(gzip) + gtest.Assert(data, arr) + + data, _ = gcompress.UnGzip(gzip[1:]) + gtest.Assert(data, nil) +} diff --git a/encoding/gcompress/gcompress_z_unit_zip_test.go b/encoding/gcompress/gcompress_z_unit_zip_test.go new file mode 100644 index 000000000..9fb0f3b1f --- /dev/null +++ b/encoding/gcompress/gcompress_z_unit_zip_test.go @@ -0,0 +1,153 @@ +// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package gcompress_test + +import ( + "bytes" + "github.com/gogf/gf/debug/gdebug" + "github.com/gogf/gf/encoding/gcompress" + "github.com/gogf/gf/os/gfile" + "github.com/gogf/gf/os/gtime" + "testing" + + "github.com/gogf/gf/test/gtest" +) + +func Test_ZipPath(t *testing.T) { + // file + gtest.Case(t, func() { + srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path1", "1.txt") + dstPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "zip.zip") + + gtest.Assert(gfile.Exists(dstPath), false) + err := gcompress.ZipPath(srcPath, dstPath) + gtest.Assert(err, nil) + gtest.Assert(gfile.Exists(dstPath), true) + defer gfile.Remove(dstPath) + + tempDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr()) + err = gfile.Mkdir(tempDirPath) + gtest.Assert(err, nil) + + err = gcompress.UnZipFile(dstPath, tempDirPath) + gtest.Assert(err, nil) + defer gfile.Remove(tempDirPath) + + gtest.Assert( + gfile.GetContents(gfile.Join(tempDirPath, "1.txt")), + gfile.GetContents(gfile.Join(srcPath, "path1", "1.txt")), + ) + }) + // directory + gtest.Case(t, func() { + srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip") + dstPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "zip.zip") + + pwd := gfile.Pwd() + err := gfile.Chdir(srcPath) + defer gfile.Chdir(pwd) + gtest.Assert(err, nil) + + gtest.Assert(gfile.Exists(dstPath), false) + err = gcompress.ZipPath(srcPath, dstPath) + gtest.Assert(err, nil) + gtest.Assert(gfile.Exists(dstPath), true) + defer gfile.Remove(dstPath) + + tempDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr()) + err = gfile.Mkdir(tempDirPath) + gtest.Assert(err, nil) + + err = gcompress.UnZipFile(dstPath, tempDirPath) + gtest.Assert(err, nil) + defer gfile.Remove(tempDirPath) + + gtest.Assert( + gfile.GetContents(gfile.Join(tempDirPath, "zip", "path1", "1.txt")), + gfile.GetContents(gfile.Join(srcPath, "path1", "1.txt")), + ) + gtest.Assert( + gfile.GetContents(gfile.Join(tempDirPath, "zip", "path2", "2.txt")), + gfile.GetContents(gfile.Join(srcPath, "path2", "2.txt")), + ) + }) + // multiple paths joined using char ',' + gtest.Case(t, func() { + srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip") + srcPath1 := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path1") + srcPath2 := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path2") + dstPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "zip.zip") + + pwd := gfile.Pwd() + err := gfile.Chdir(srcPath) + defer gfile.Chdir(pwd) + gtest.Assert(err, nil) + + gtest.Assert(gfile.Exists(dstPath), false) + err = gcompress.ZipPath(srcPath1+", "+srcPath2, dstPath) + gtest.Assert(err, nil) + gtest.Assert(gfile.Exists(dstPath), true) + defer gfile.Remove(dstPath) + + tempDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr()) + err = gfile.Mkdir(tempDirPath) + gtest.Assert(err, nil) + + zipContent := gfile.GetBytes(dstPath) + gtest.AssertGT(len(zipContent), 0) + err = gcompress.UnZipContent(zipContent, tempDirPath) + gtest.Assert(err, nil) + defer gfile.Remove(tempDirPath) + + gtest.Assert( + gfile.GetContents(gfile.Join(tempDirPath, "path1", "1.txt")), + gfile.GetContents(gfile.Join(srcPath, "path1", "1.txt")), + ) + gtest.Assert( + gfile.GetContents(gfile.Join(tempDirPath, "path2", "2.txt")), + gfile.GetContents(gfile.Join(srcPath, "path2", "2.txt")), + ) + }) +} + +func Test_ZipPathWriter(t *testing.T) { + gtest.Case(t, func() { + srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip") + srcPath1 := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path1") + srcPath2 := gfile.Join(gdebug.CallerDirectory(), "testdata", "zip", "path2") + + pwd := gfile.Pwd() + err := gfile.Chdir(srcPath) + defer gfile.Chdir(pwd) + gtest.Assert(err, nil) + + writer := bytes.NewBuffer(nil) + gtest.Assert(writer.Len(), 0) + err = gcompress.ZipPathWriter(srcPath1+", "+srcPath2, writer) + gtest.Assert(err, nil) + gtest.AssertGT(writer.Len(), 0) + + tempDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr()) + err = gfile.Mkdir(tempDirPath) + gtest.Assert(err, nil) + + zipContent := writer.Bytes() + gtest.AssertGT(len(zipContent), 0) + err = gcompress.UnZipContent(zipContent, tempDirPath) + gtest.Assert(err, nil) + defer gfile.Remove(tempDirPath) + + gtest.Assert( + gfile.GetContents(gfile.Join(tempDirPath, "path1", "1.txt")), + gfile.GetContents(gfile.Join(srcPath, "path1", "1.txt")), + ) + gtest.Assert( + gfile.GetContents(gfile.Join(tempDirPath, "path2", "2.txt")), + gfile.GetContents(gfile.Join(srcPath, "path2", "2.txt")), + ) + }) +} diff --git a/encoding/gcompress/gcompress_test.go b/encoding/gcompress/gcompress_z_unit_zlib_test.go similarity index 61% rename from encoding/gcompress/gcompress_test.go rename to encoding/gcompress/gcompress_z_unit_zlib_test.go index 1b563f716..68ea18703 100644 --- a/encoding/gcompress/gcompress_test.go +++ b/encoding/gcompress/gcompress_z_unit_zlib_test.go @@ -13,7 +13,7 @@ import ( "github.com/gogf/gf/test/gtest" ) -func TestZlib(t *testing.T) { +func Test_Zlib_UnZlib(t *testing.T) { gtest.Case(t, func() { src := "hello, world\n" dst := []byte{120, 156, 202, 72, 205, 201, 201, 215, 81, 40, 207, 47, 202, 73, 225, 2, 4, 0, 0, 255, 255, 33, 231, 4, 147} @@ -31,30 +31,4 @@ func TestZlib(t *testing.T) { data, _ = gcompress.UnZlib(dst[1:]) gtest.Assert(data, nil) }) - -} - -func TestGzip(t *testing.T) { - src := "Hello World!!" - - gzip := []byte{ - 0x1f, 0x8b, 0x08, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0xff, - 0xf2, 0x48, 0xcd, 0xc9, 0xc9, - 0x57, 0x08, 0xcf, 0x2f, 0xca, - 0x49, 0x51, 0x54, 0x04, 0x04, - 0x00, 0x00, 0xff, 0xff, 0x9d, - 0x24, 0xa8, 0xd1, 0x0d, 0x00, - 0x00, 0x00, - } - - arr := []byte(src) - data, _ := gcompress.Gzip(arr) - gtest.Assert(data, gzip) - - data, _ = gcompress.UnGzip(gzip) - gtest.Assert(data, arr) - - data, _ = gcompress.UnGzip(gzip[1:]) - gtest.Assert(data, nil) } diff --git a/encoding/gcompress/gcompress_zip_file.go b/encoding/gcompress/gcompress_zip_file.go index 595138d2d..c36839fed 100644 --- a/encoding/gcompress/gcompress_zip_file.go +++ b/encoding/gcompress/gcompress_zip_file.go @@ -9,6 +9,7 @@ package gcompress import ( "archive/zip" "bytes" + "github.com/gogf/gf/internal/intlog" "io" "os" "path/filepath" @@ -32,7 +33,15 @@ func ZipPath(paths, dest string, prefix ...string) error { return err } defer writer.Close() - return ZipPathWriter(paths, writer, prefix...) + zipWriter := zip.NewWriter(writer) + defer zipWriter.Close() + for _, path := range strings.Split(paths, ",") { + path = strings.TrimSpace(path) + if err := doZipPathWriter(path, gfile.RealPath(dest), zipWriter, prefix...); err != nil { + return err + } + } + return nil } // ZipPathWriter compresses to using zip compressing algorithm. @@ -45,17 +54,21 @@ func ZipPathWriter(paths string, writer io.Writer, prefix ...string) error { defer zipWriter.Close() for _, path := range strings.Split(paths, ",") { path = strings.TrimSpace(path) - if err := doZipPathWriter(path, zipWriter, prefix...); err != nil { + if err := doZipPathWriter(path, "", zipWriter, prefix...); err != nil { return err } } return nil } -func doZipPathWriter(path string, zipWriter *zip.Writer, prefix ...string) error { +// doZipPathWriter compresses the file of given and writes the content to . +// The parameter specifies the exclusive file path that is not compressed to , +// commonly the destination zip file path. +// The unnecessary parameter indicates the path prefix for zip file. +func doZipPathWriter(path string, exclude string, zipWriter *zip.Writer, prefix ...string) error { var err error var files []string - realPath, err := gfile.Search(path) + path, err = gfile.Search(path) if err != nil { return err } @@ -80,7 +93,11 @@ func doZipPathWriter(path string, zipWriter *zip.Writer, prefix ...string) error } headerPrefix = strings.Replace(headerPrefix, "//", "/", -1) for _, file := range files { - err := zipFile(file, headerPrefix+gfile.Dir(file[len(realPath):]), zipWriter) + if exclude == file { + intlog.Printf(`exclude file path: %s`, file) + continue + } + err := zipFile(file, headerPrefix+gfile.Dir(file[len(path):]), zipWriter) if err != nil { return err } @@ -101,10 +118,10 @@ func doZipPathWriter(path string, zipWriter *zip.Writer, prefix ...string) error } // UnZipFile decompresses to using zip compressing algorithm. -// The parameter specifies the unzipped path of , +// The optional parameter specifies the unzipped path of , // which can be used to specify part of the archive file to unzip. // -// Note thate the parameter should be a directory. +// Note that the parameter should be a directory. func UnZipFile(archive, dest string, path ...string) error { readerCloser, err := zip.OpenReader(archive) if err != nil { @@ -118,7 +135,7 @@ func UnZipFile(archive, dest string, path ...string) error { // The parameter specifies the unzipped path of , // which can be used to specify part of the archive file to unzip. // -// Note thate the parameter should be a directory. +// Note that the parameter should be a directory. func UnZipContent(data []byte, dest string, path ...string) error { reader, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) if err != nil { @@ -178,6 +195,8 @@ func unZipFileWithReader(reader *zip.Reader, dest string, path ...string) error return nil } +// zipFile compresses the file of given and writes the content to . +// The parameter indicates the path prefix for zip file. func zipFile(path string, prefix string, zw *zip.Writer) error { file, err := os.Open(path) if err != nil { diff --git a/encoding/gcompress/testdata/zip/path1/1.txt b/encoding/gcompress/testdata/zip/path1/1.txt new file mode 100644 index 000000000..8529b38b3 --- /dev/null +++ b/encoding/gcompress/testdata/zip/path1/1.txt @@ -0,0 +1 @@ +This is a test file for zip compression purpose. \ No newline at end of file diff --git a/encoding/gcompress/testdata/zip/path2/2.txt b/encoding/gcompress/testdata/zip/path2/2.txt new file mode 100644 index 000000000..f51884841 --- /dev/null +++ b/encoding/gcompress/testdata/zip/path2/2.txt @@ -0,0 +1 @@ +This is an another test file for zip compression purpose. \ No newline at end of file From 7f0163d958db79c2174ea9eced7a2950c5150844 Mon Sep 17 00:00:00 2001 From: John Date: Fri, 6 Mar 2020 23:22:08 +0800 Subject: [PATCH 08/26] improve gconv.Struct* functions for custom types conversion --- .example/other/test.go | 21 +++++++++++-- database/gdb/gdb_unit_z_mysql_model_test.go | 22 ++++++++++++++ util/gconv/gconv.go | 4 ++- util/gconv/gconv_struct.go | 33 ++++++++++++++------- util/gconv/gconv_z_unit_struct_test.go | 30 +++++++++++++++++++ 5 files changed, 96 insertions(+), 14 deletions(-) diff --git a/.example/other/test.go b/.example/other/test.go index 92e1cfb76..4e9c48ba9 100644 --- a/.example/other/test.go +++ b/.example/other/test.go @@ -1,10 +1,25 @@ package main import ( - "github.com/gogf/gf/container/garray" + "fmt" + "github.com/gogf/gf/frame/g" + "github.com/gogf/gf/util/gconv" ) +type MyInt int + +//func (i *MyInt) UnmarshalValue(interface{}) error { +// *i = 10 +// return nil +//} func main() { - arr := garray.NewStrArray(false) - arr.Unique() + type User struct { + Id MyInt + } + user := new(User) + err := gconv.Struct(g.Map{ + "id": 1, + }, user) + fmt.Println(err) + fmt.Println(user) } diff --git a/database/gdb/gdb_unit_z_mysql_model_test.go b/database/gdb/gdb_unit_z_mysql_model_test.go index 4f523a66b..f25fd1d39 100644 --- a/database/gdb/gdb_unit_z_mysql_model_test.go +++ b/database/gdb/gdb_unit_z_mysql_model_test.go @@ -681,6 +681,28 @@ func Test_Model_Struct(t *testing.T) { }) } +func Test_Model_Struct_CustomType(t *testing.T) { + table := createInitTable() + defer dropTable(table) + + type MyInt int + + gtest.Case(t, func() { + type User struct { + Id MyInt + Passport string + Password string + NickName string + CreateTime gtime.Time + } + user := new(User) + err := db.Table(table).Where("id=1").Struct(user) + gtest.Assert(err, nil) + gtest.Assert(user.NickName, "name_1") + gtest.Assert(user.CreateTime.String(), "2018-10-24 10:00:00") + }) +} + func Test_Model_Structs(t *testing.T) { table := createInitTable() defer dropTable(table) diff --git a/util/gconv/gconv.go b/util/gconv/gconv.go index 54a99ddc3..c34d0a569 100644 --- a/util/gconv/gconv.go +++ b/util/gconv/gconv.go @@ -48,7 +48,8 @@ var ( ) // Convert converts the variable to the type , the type is specified by string. -// The optional parameter is used for additional parameter passing. +// The optional parameter is used for additional necessary parameter for this conversion. +// It supports common types conversion as its conversion based on type name string. func Convert(i interface{}, t string, params ...interface{}) interface{} { switch t { case "int": @@ -121,6 +122,7 @@ func Convert(i interface{}, t string, params ...interface{}) interface{} { case "Duration", "time.Duration": return Duration(i) default: + return i } } diff --git a/util/gconv/gconv_struct.go b/util/gconv/gconv_struct.go index eade34f9d..46fdeea40 100644 --- a/util/gconv/gconv_struct.go +++ b/util/gconv/gconv_struct.go @@ -242,6 +242,7 @@ func bindVarToStructByIndex(elem reflect.Value, index int, value interface{}) (e if !structFieldValue.CanSet() { return nil } + // If any panic, it secondly uses reflect conversion and assignment. defer func() { if recover() != nil { err = bindVarToReflectValue(structFieldValue, value) @@ -250,6 +251,7 @@ func bindVarToStructByIndex(elem reflect.Value, index int, value interface{}) (e if empty.IsNil(value) { structFieldValue.Set(reflect.Zero(structFieldValue.Type())) } else { + // It firstly simply assigns the value to the attribute. structFieldValue.Set(reflect.ValueOf(Convert(value, structFieldValue.Type().String()))) } return nil @@ -260,7 +262,8 @@ func bindVarToReflectValue(structFieldValue reflect.Value, value interface{}) (e switch structFieldValue.Kind() { case reflect.Struct: if err := Struct(value, structFieldValue); err != nil { - structFieldValue.Set(reflect.ValueOf(value)) + // Note there's reflect conversion mechanism here. + structFieldValue.Set(reflect.ValueOf(value).Convert(structFieldValue.Type())) } // Note that the slice element might be type of struct, // so it uses Struct function doing the converting internally. @@ -275,13 +278,15 @@ func bindVarToReflectValue(structFieldValue reflect.Value, value interface{}) (e if t.Kind() == reflect.Ptr { e := reflect.New(t.Elem()).Elem() if err := Struct(v.Index(i).Interface(), e); err != nil { - e.Set(reflect.ValueOf(v.Index(i).Interface())) + // Note there's reflect conversion mechanism here. + e.Set(reflect.ValueOf(v.Index(i).Interface()).Convert(t)) } a.Index(i).Set(e.Addr()) } else { e := reflect.New(t).Elem() if err := Struct(v.Index(i).Interface(), e); err != nil { - e.Set(reflect.ValueOf(v.Index(i).Interface())) + // Note there's reflect conversion mechanism here. + e.Set(reflect.ValueOf(v.Index(i).Interface()).Convert(t)) } a.Index(i).Set(e) } @@ -293,13 +298,15 @@ func bindVarToReflectValue(structFieldValue reflect.Value, value interface{}) (e if t.Kind() == reflect.Ptr { e := reflect.New(t.Elem()).Elem() if err := Struct(value, e); err != nil { - e.Set(reflect.ValueOf(value)) + // Note there's reflect conversion mechanism here. + e.Set(reflect.ValueOf(value).Convert(t)) } a.Index(0).Set(e.Addr()) } else { e := reflect.New(t).Elem() if err := Struct(value, e); err != nil { - e.Set(reflect.ValueOf(value)) + // Note there's reflect conversion mechanism here. + e.Set(reflect.ValueOf(value).Convert(t)) } a.Index(0).Set(e) } @@ -311,34 +318,40 @@ func bindVarToReflectValue(structFieldValue reflect.Value, value interface{}) (e // Assign value with interface Set. // Note that only pointer can implement interface Set. if v, ok := item.Interface().(apiUnmarshalValue); ok { - v.UnmarshalValue(value) + err = v.UnmarshalValue(value) structFieldValue.Set(item) - return nil + return err } elem := item.Elem() if err = bindVarToReflectValue(elem, value); err == nil { structFieldValue.Set(elem.Addr()) } + // It mainly and specially handles the interface of nil value. case reflect.Interface: if value == nil { + // Specially. structFieldValue.Set(reflect.ValueOf((*interface{})(nil))) } else { - structFieldValue.Set(reflect.ValueOf(value)) + // Note there's reflect conversion mechanism here. + structFieldValue.Set(reflect.ValueOf(value).Convert(structFieldValue.Type())) } default: defer func() { if e := recover(); e != nil { err = errors.New( - fmt.Sprintf(`cannot convert "%d" to type "%s"`, + fmt.Sprintf(`cannot convert value "%d" to type "%s"`, value, structFieldValue.Type().String(), ), ) } }() - structFieldValue.Set(reflect.ValueOf(value)) + // It here uses reflect converting to type of the attribute and assigns + // the result value to the attribute. It might fail and panic if the usual Go + // conversion rules do not allow conversion. + structFieldValue.Set(reflect.ValueOf(value).Convert(structFieldValue.Type())) } return nil } diff --git a/util/gconv/gconv_z_unit_struct_test.go b/util/gconv/gconv_z_unit_struct_test.go index 9ba5fdf9f..767117b87 100644 --- a/util/gconv/gconv_z_unit_struct_test.go +++ b/util/gconv/gconv_z_unit_struct_test.go @@ -324,6 +324,36 @@ func Test_Struct_Attr_Struct_Slice_Ptr(t *testing.T) { }) } +func Test_Struct_Attr_CustomType1(t *testing.T) { + type MyInt int + type User struct { + Id MyInt + Name string + } + gtest.Case(t, func() { + user := new(User) + err := gconv.Struct(g.Map{"id": 1, "name": "john"}, user) + gtest.Assert(err, nil) + gtest.Assert(user.Id, 1) + gtest.Assert(user.Name, "john") + }) +} + +func Test_Struct_Attr_CustomType2(t *testing.T) { + type MyInt int + type User struct { + Id []MyInt + Name string + } + gtest.Case(t, func() { + user := new(User) + err := gconv.Struct(g.Map{"id": g.Slice{1, 2}, "name": "john"}, user) + gtest.Assert(err, nil) + gtest.Assert(user.Id, g.Slice{1, 2}) + gtest.Assert(user.Name, "john") + }) +} + func Test_Struct_PrivateAttribute(t *testing.T) { type User struct { Id int From a34ca0ff4bfa759488c72cdfc590675306bfe6cb Mon Sep 17 00:00:00 2001 From: John Date: Sat, 7 Mar 2020 19:31:33 +0800 Subject: [PATCH 09/26] improve uploading file feature for ghttp.Server; improve package gfile/gstr/gdebug --- debug/gdebug/gdebug.go | 18 ++- net/ghttp/ghttp_request_param.go | 15 +++ net/ghttp/ghttp_request_param_file.go | 45 ++++---- net/ghttp/ghttp_unit_param_file_test.go | 146 ++++++++++++++++++++++++ net/ghttp/testdata/upload/file1.txt | 1 + net/ghttp/testdata/upload/file2.txt | 1 + os/gfile/gfile.go | 43 ++++--- text/gstr/gstr.go | 10 +- 8 files changed, 235 insertions(+), 44 deletions(-) create mode 100644 net/ghttp/ghttp_unit_param_file_test.go create mode 100644 net/ghttp/testdata/upload/file1.txt create mode 100644 net/ghttp/testdata/upload/file2.txt diff --git a/debug/gdebug/gdebug.go b/debug/gdebug/gdebug.go index 6809ccfa2..e4d897a7f 100644 --- a/debug/gdebug/gdebug.go +++ b/debug/gdebug/gdebug.go @@ -10,6 +10,9 @@ package gdebug import ( "bytes" "fmt" + "io/ioutil" + "os" + "os/exec" "path/filepath" "reflect" "runtime" @@ -19,7 +22,6 @@ import ( "github.com/gogf/gf/encoding/ghash" "github.com/gogf/gf/crypto/gmd5" - "github.com/gogf/gf/os/gfile" ) const ( @@ -31,20 +33,30 @@ var ( goRootForFilter = runtime.GOROOT() // goRootForFilter is used for stack filtering purpose. binaryVersion = "" // The version of current running binary(uint64 hex). binaryVersionMd5 = "" // The version of current running binary(MD5). + selfPath = "" // Current running binary absolute path. ) func init() { if goRootForFilter != "" { goRootForFilter = strings.Replace(goRootForFilter, "\\", "/", -1) } + // Initialize internal package variable: selfPath. + selfPath, _ := exec.LookPath(os.Args[0]) + if selfPath != "" { + selfPath, _ = filepath.Abs(selfPath) + } + if selfPath == "" { + selfPath, _ = filepath.Abs(os.Args[0]) + } } // BinVersion returns the version of current running binary. // It uses ghash.BKDRHash+BASE36 algorithm to calculate the unique version of the binary. func BinVersion() string { if binaryVersion == "" { + binaryContent, _ := ioutil.ReadFile(selfPath) binaryVersion = strconv.FormatInt( - int64(ghash.BKDRHash(gfile.GetBytes(gfile.SelfPath()))), + int64(ghash.BKDRHash(binaryContent)), 36, ) } @@ -55,7 +67,7 @@ func BinVersion() string { // It uses MD5 algorithm to calculate the unique version of the binary. func BinVersionMd5() string { if binaryVersionMd5 == "" { - binaryVersionMd5, _ = gmd5.EncryptFile(gfile.SelfPath()) + binaryVersionMd5, _ = gmd5.EncryptFile(selfPath) } return binaryVersionMd5 } diff --git a/net/ghttp/ghttp_request_param.go b/net/ghttp/ghttp_request_param.go index 8b5698350..6b66a820b 100644 --- a/net/ghttp/ghttp_request_param.go +++ b/net/ghttp/ghttp_request_param.go @@ -9,6 +9,7 @@ package ghttp import ( "bytes" "encoding/json" + "fmt" "github.com/gogf/gf/container/gvar" "github.com/gogf/gf/encoding/gjson" "github.com/gogf/gf/encoding/gurl" @@ -302,5 +303,19 @@ func (r *Request) GetMultipartFiles(name string) []*multipart.FileHeader { if v := form.File[name+"[]"]; len(v) > 0 { return v } + // Support "name[0]","name[1]","name[2]", etc. as array parameter. + key := "" + files := make([]*multipart.FileHeader, 0) + for i := 0; ; i++ { + key = fmt.Sprintf(`%s[%d]`, name, i) + if v := form.File[key]; len(v) > 0 { + files = append(files, v[0]) + } else { + break + } + } + if len(files) > 0 { + return files + } return nil } diff --git a/net/ghttp/ghttp_request_param_file.go b/net/ghttp/ghttp_request_param_file.go index e91f5b3e7..6b0de4f70 100644 --- a/net/ghttp/ghttp_request_param_file.go +++ b/net/ghttp/ghttp_request_param_file.go @@ -26,7 +26,7 @@ type UploadFile struct { // UploadFiles is array type for *UploadFile. type UploadFiles []*UploadFile -// Save saves the single uploading file to specified path. +// Save saves the single uploading file to specified path and returns the saved file name. // The parameter path can be either a directory or a file path. If is a directory, // it saves the uploading file to the directory using its original name. If is a // file path, it saves the uploading file to the file path. @@ -35,56 +35,61 @@ type UploadFiles []*UploadFile // make sense if the is a directory. // // Note that it will overwrite the target file if there's already a same name file exist. -func (f *UploadFile) Save(path string, randomlyRename ...bool) error { +func (f *UploadFile) Save(path string, randomlyRename ...bool) (filename string, err error) { if f == nil { - return nil + return } file, err := f.Open() if err != nil { - return err + return "", err } defer file.Close() filePath := path if gfile.IsDir(path) { - filename := gfile.Basename(f.Filename) + name := gfile.Basename(f.Filename) if len(randomlyRename) > 0 && randomlyRename[0] { - filename = strings.ToLower(strconv.FormatInt(gtime.TimestampNano(), 36) + grand.S(6)) - filename = filename + gfile.Ext(f.Filename) + name = strings.ToLower(strconv.FormatInt(gtime.TimestampNano(), 36) + grand.S(6)) + name = name + gfile.Ext(f.Filename) } - filePath = gfile.Join(path, filename) + filePath = gfile.Join(path, name) } newFile, err := gfile.Create(filePath) if err != nil { - return err + return "", err } defer newFile.Close() intlog.Printf(`save upload file: %s`, filePath) if _, err := io.Copy(newFile, file); err != nil { - return err + return "", err } - return nil + return gfile.Basename(filePath), nil } -// Save saves all uploading files to specified directory path. +// Save saves all uploading files to specified directory path and returns the saved file names. // // The parameter should be a directory path or it returns error. // // The parameter specifies whether randomly renames all the file names. -func (fs UploadFiles) Save(dirPath string, randomlyRename ...bool) error { +func (fs UploadFiles) Save(dirPath string, randomlyRename ...bool) (filenames []string, err error) { if len(fs) == 0 { - return nil + return nil, nil } - if !gfile.IsDir(dirPath) { - return errors.New(`parameter "dirPath" should be a directory path`) + if !gfile.Exists(dirPath) { + if err = gfile.Mkdir(dirPath); err != nil { + return + } + } else if !gfile.IsDir(dirPath) { + return nil, errors.New(`parameter "dirPath" should be a directory path`) } - var err error for _, f := range fs { - if err = f.Save(dirPath, randomlyRename...); err != nil { - return err + if filename, err := f.Save(dirPath, randomlyRename...); err != nil { + return filenames, err + } else { + filenames = append(filenames, filename) } } - return nil + return } // GetUploadFile retrieves and returns the uploading file with specified form name. diff --git a/net/ghttp/ghttp_unit_param_file_test.go b/net/ghttp/ghttp_unit_param_file_test.go new file mode 100644 index 000000000..0a5384644 --- /dev/null +++ b/net/ghttp/ghttp_unit_param_file_test.go @@ -0,0 +1,146 @@ +// Copyright 2018 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package ghttp_test + +import ( + "fmt" + "github.com/gogf/gf/debug/gdebug" + "github.com/gogf/gf/os/gfile" + "github.com/gogf/gf/os/gtime" + "github.com/gogf/gf/text/gstr" + "testing" + "time" + + "github.com/gogf/gf/frame/g" + "github.com/gogf/gf/net/ghttp" + "github.com/gogf/gf/test/gtest" +) + +func Test_Params_File_Single(t *testing.T) { + dstDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr()) + err := gfile.Mkdir(dstDirPath) + gtest.Assert(err, nil) + + p := ports.PopRand() + s := g.Server(p) + s.BindHandler("/upload/single", func(r *ghttp.Request) { + file := r.GetUploadFile("file") + if file == nil { + r.Response.WriteExit("upload file cannot be empty") + } + + if name, err := file.Save(dstDirPath, r.GetBool("randomlyRename")); err == nil { + r.Response.WriteExit(name) + } + r.Response.WriteExit("upload failed") + }) + s.SetPort(p) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + time.Sleep(100 * time.Millisecond) + // normal name + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + + srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file1.txt") + dstPath := gfile.Join(dstDirPath, "file1.txt") + content := client.PostContent("/upload/single", g.Map{ + "file": "@file:" + srcPath, + }) + gtest.AssertNE(content, "") + gtest.AssertNE(content, "upload file cannot be empty") + gtest.AssertNE(content, "upload failed") + gtest.Assert(content, "file1.txt") + gtest.Assert(gfile.GetContents(dstPath), gfile.GetContents(srcPath)) + }) + // randomly rename. + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + + srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file2.txt") + content := client.PostContent("/upload/single", g.Map{ + "file": "@file:" + srcPath, + "randomlyRename": true, + }) + dstPath := gfile.Join(dstDirPath, content) + gtest.AssertNE(content, "") + gtest.AssertNE(content, "upload file cannot be empty") + gtest.AssertNE(content, "upload failed") + gtest.Assert(gfile.GetContents(dstPath), gfile.GetContents(srcPath)) + }) +} + +func Test_Params_File_Batch(t *testing.T) { + dstDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr()) + err := gfile.Mkdir(dstDirPath) + gtest.Assert(err, nil) + + p := ports.PopRand() + s := g.Server(p) + s.BindHandler("/upload/batch", func(r *ghttp.Request) { + files := r.GetUploadFiles("file") + if files == nil { + r.Response.WriteExit("upload file cannot be empty") + } + + if names, err := files.Save(dstDirPath, r.GetBool("randomlyRename")); err == nil { + r.Response.WriteExit(gstr.Join(names, ",")) + } + r.Response.WriteExit("upload failed") + }) + s.SetPort(p) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + time.Sleep(100 * time.Millisecond) + // normal name + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + + srcPath1 := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file1.txt") + srcPath2 := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file2.txt") + dstPath1 := gfile.Join(dstDirPath, "file1.txt") + dstPath2 := gfile.Join(dstDirPath, "file2.txt") + content := client.PostContent("/upload/batch", g.Map{ + "file[0]": "@file:" + srcPath1, + "file[1]": "@file:" + srcPath2, + }) + gtest.AssertNE(content, "") + gtest.AssertNE(content, "upload file cannot be empty") + gtest.AssertNE(content, "upload failed") + gtest.Assert(content, "file1.txt,file2.txt") + gtest.Assert(gfile.GetContents(dstPath1), gfile.GetContents(srcPath1)) + gtest.Assert(gfile.GetContents(dstPath2), gfile.GetContents(srcPath2)) + }) + // randomly rename. + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + + srcPath1 := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file1.txt") + srcPath2 := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file2.txt") + content := client.PostContent("/upload/batch", g.Map{ + "file[0]": "@file:" + srcPath1, + "file[1]": "@file:" + srcPath2, + "randomlyRename": true, + }) + gtest.AssertNE(content, "") + gtest.AssertNE(content, "upload file cannot be empty") + gtest.AssertNE(content, "upload failed") + + array := gstr.SplitAndTrim(content, ",") + gtest.Assert(len(array), 2) + dstPath1 := gfile.Join(dstDirPath, array[0]) + dstPath2 := gfile.Join(dstDirPath, array[1]) + gtest.Assert(gfile.GetContents(dstPath1), gfile.GetContents(srcPath1)) + gtest.Assert(gfile.GetContents(dstPath2), gfile.GetContents(srcPath2)) + }) +} diff --git a/net/ghttp/testdata/upload/file1.txt b/net/ghttp/testdata/upload/file1.txt new file mode 100644 index 000000000..1885a2771 --- /dev/null +++ b/net/ghttp/testdata/upload/file1.txt @@ -0,0 +1 @@ +file1.txt: This file is for uploading unit test case. \ No newline at end of file diff --git a/net/ghttp/testdata/upload/file2.txt b/net/ghttp/testdata/upload/file2.txt new file mode 100644 index 000000000..e3167d360 --- /dev/null +++ b/net/ghttp/testdata/upload/file2.txt @@ -0,0 +1 @@ +file2.txt: This file is for uploading unit test case. \ No newline at end of file diff --git a/os/gfile/gfile.go b/os/gfile/gfile.go index 5c101e632..bc99687dc 100644 --- a/os/gfile/gfile.go +++ b/os/gfile/gfile.go @@ -31,17 +31,32 @@ const ( var ( // Default perm for file opening. DefaultPerm = os.FileMode(0666) + // The absolute file path for main package. // It can be only checked and set once. mainPkgPath = gtype.NewString() + + // selfPath is the current running binary path. + // As it is most commonly used, it is so defined as an internal package variable. + selfPath = "" + // Temporary directory of system. tempDir = "/tmp" ) func init() { + // Initialize internal package variable: tempDir. if !Exists(tempDir) { tempDir = os.TempDir() } + // Initialize internal package variable: selfPath. + selfPath, _ := exec.LookPath(os.Args[0]) + if selfPath != "" { + selfPath, _ = filepath.Abs(selfPath) + } + if selfPath == "" { + selfPath, _ = filepath.Abs(os.Args[0]) + } } // Mkdir creates directories recursively with given . @@ -66,17 +81,20 @@ func Create(path string) (*os.File, error) { return os.Create(path) } -// Open opens file/directory readonly. +// Open opens file/directory READONLY. func Open(path string) (*os.File, error) { return os.Open(path) } -// OpenFile opens file/directory with given and . +// OpenFile opens file/directory with custom and . +// The parameter is like: O_RDONLY, O_RDWR, O_RDWR|O_CREATE|O_TRUNC, etc. func OpenFile(path string, flag int, perm os.FileMode) (*os.File, error) { return os.OpenFile(path, flag, perm) } -// OpenWithFlag opens file/directory with default perm and given . +// OpenWithFlag opens file/directory with default perm and custom . +// The default is 0666. +// The parameter is like: O_RDONLY, O_RDWR, O_RDWR|O_CREATE|O_TRUNC, etc. func OpenWithFlag(path string, flag int) (*os.File, error) { f, err := os.OpenFile(path, flag, DefaultPerm) if err != nil { @@ -85,9 +103,11 @@ func OpenWithFlag(path string, flag int) (*os.File, error) { return f, nil } -// OpenWithFlagPerm opens file/directory with given and . +// OpenWithFlagPerm opens file/directory with custom and . +// The parameter is like: O_RDONLY, O_RDWR, O_RDWR|O_CREATE|O_TRUNC, etc. +// The parameter is like: 0600, 0666, 0777, etc. func OpenWithFlagPerm(path string, flag int, perm os.FileMode) (*os.File, error) { - f, err := os.OpenFile(path, flag, os.FileMode(perm)) + f, err := os.OpenFile(path, flag, perm) if err != nil { return nil, err } @@ -208,6 +228,7 @@ func Glob(pattern string, onlyNames ...bool) ([]string, error) { // Remove deletes all file/directory with parameter. // If parameter is directory, it deletes it recursively. func Remove(path string) error { + //intlog.Print(`Remove:`, path) return os.RemoveAll(path) } @@ -278,17 +299,7 @@ func RealPath(path string) string { // SelfPath returns absolute file path of current running process(binary). func SelfPath() string { - path, _ := exec.LookPath(os.Args[0]) - if path != "" { - path, _ = filepath.Abs(path) - if path != "" { - return path - } - } - if path == "" { - path, _ = filepath.Abs(os.Args[0]) - } - return path + return selfPath } // SelfName returns file name of current running process(binary). diff --git a/text/gstr/gstr.go b/text/gstr/gstr.go index 1f504c15c..27b61f6d5 100644 --- a/text/gstr/gstr.go +++ b/text/gstr/gstr.go @@ -465,16 +465,16 @@ func SplitAndTrimSpace(str, delimiter string) []string { return array } -// Join concatenates the elements of a to create a single string. The separator string -// sep is placed between elements in the resulting string. +// Join concatenates the elements of to create a single string. The separator string +// is placed between elements in the resulting string. func Join(array []string, sep string) string { return strings.Join(array, sep) } -// JoinAny concatenates the elements of a to create a single string. The separator string -// sep is placed between elements in the resulting string. +// JoinAny concatenates the elements of to create a single string. The separator string +// is placed between elements in the resulting string. // -// The parameter can be any type of slice. +// The parameter can be any type of slice, which be converted to string array. func JoinAny(array interface{}, sep string) string { return strings.Join(gconv.Strings(array), sep) } From 5bdf1a71b8aec2acbd10b65242f0346dbc3a6989 Mon Sep 17 00:00:00 2001 From: John Date: Sat, 7 Mar 2020 20:20:52 +0800 Subject: [PATCH 10/26] improve uploading file feature for ghttp.Server --- net/ghttp/ghttp_request_param_file.go | 37 ++++++++++----------- net/ghttp/ghttp_unit_param_file_test.go | 44 +++++++++++++++++++++---- 2 files changed, 54 insertions(+), 27 deletions(-) diff --git a/net/ghttp/ghttp_request_param_file.go b/net/ghttp/ghttp_request_param_file.go index 6b0de4f70..65f3e8f81 100644 --- a/net/ghttp/ghttp_request_param_file.go +++ b/net/ghttp/ghttp_request_param_file.go @@ -26,34 +26,38 @@ type UploadFile struct { // UploadFiles is array type for *UploadFile. type UploadFiles []*UploadFile -// Save saves the single uploading file to specified path and returns the saved file name. -// The parameter path can be either a directory or a file path. If is a directory, -// it saves the uploading file to the directory using its original name. If is a -// file path, it saves the uploading file to the file path. +// Save saves the single uploading file to directory path and returns the saved file name. +// +// The parameter should be a directory path or it returns error. // // The parameter specifies whether randomly renames the file name, which // make sense if the is a directory. // // Note that it will overwrite the target file if there's already a same name file exist. -func (f *UploadFile) Save(path string, randomlyRename ...bool) (filename string, err error) { +func (f *UploadFile) Save(dirPath string, randomlyRename ...bool) (filename string, err error) { if f == nil { return } + if !gfile.Exists(dirPath) { + if err = gfile.Mkdir(dirPath); err != nil { + return + } + } else if !gfile.IsDir(dirPath) { + return "", errors.New(`parameter "dirPath" should be a directory path`) + } + file, err := f.Open() if err != nil { return "", err } defer file.Close() - filePath := path - if gfile.IsDir(path) { - name := gfile.Basename(f.Filename) - if len(randomlyRename) > 0 && randomlyRename[0] { - name = strings.ToLower(strconv.FormatInt(gtime.TimestampNano(), 36) + grand.S(6)) - name = name + gfile.Ext(f.Filename) - } - filePath = gfile.Join(path, name) + name := gfile.Basename(f.Filename) + if len(randomlyRename) > 0 && randomlyRename[0] { + name = strings.ToLower(strconv.FormatInt(gtime.TimestampNano(), 36) + grand.S(6)) + name = name + gfile.Ext(f.Filename) } + filePath := gfile.Join(dirPath, name) newFile, err := gfile.Create(filePath) if err != nil { return "", err @@ -75,13 +79,6 @@ func (fs UploadFiles) Save(dirPath string, randomlyRename ...bool) (filenames [] if len(fs) == 0 { return nil, nil } - if !gfile.Exists(dirPath) { - if err = gfile.Mkdir(dirPath); err != nil { - return - } - } else if !gfile.IsDir(dirPath) { - return nil, errors.New(`parameter "dirPath" should be a directory path`) - } for _, f := range fs { if filename, err := f.Save(dirPath, randomlyRename...); err != nil { return filenames, err diff --git a/net/ghttp/ghttp_unit_param_file_test.go b/net/ghttp/ghttp_unit_param_file_test.go index 0a5384644..357eb3e4c 100644 --- a/net/ghttp/ghttp_unit_param_file_test.go +++ b/net/ghttp/ghttp_unit_param_file_test.go @@ -22,9 +22,6 @@ import ( func Test_Params_File_Single(t *testing.T) { dstDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr()) - err := gfile.Mkdir(dstDirPath) - gtest.Assert(err, nil) - p := ports.PopRand() s := g.Server(p) s.BindHandler("/upload/single", func(r *ghttp.Request) { @@ -77,11 +74,45 @@ func Test_Params_File_Single(t *testing.T) { }) } +func Test_Params_File_CustomName(t *testing.T) { + dstDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr()) + p := ports.PopRand() + s := g.Server(p) + s.BindHandler("/upload/single", func(r *ghttp.Request) { + file := r.GetUploadFile("file") + if file == nil { + r.Response.WriteExit("upload file cannot be empty") + } + file.Filename = "my.txt" + if name, err := file.Save(dstDirPath, r.GetBool("randomlyRename")); err == nil { + r.Response.WriteExit(name) + } + r.Response.WriteExit("upload failed") + }) + s.SetPort(p) + s.SetDumpRouterMap(false) + s.Start() + defer s.Shutdown() + time.Sleep(100 * time.Millisecond) + gtest.Case(t, func() { + client := ghttp.NewClient() + client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p)) + + srcPath := gfile.Join(gdebug.CallerDirectory(), "testdata", "upload", "file1.txt") + dstPath := gfile.Join(dstDirPath, "my.txt") + content := client.PostContent("/upload/single", g.Map{ + "file": "@file:" + srcPath, + }) + gtest.AssertNE(content, "") + gtest.AssertNE(content, "upload file cannot be empty") + gtest.AssertNE(content, "upload failed") + gtest.Assert(content, "my.txt") + gtest.Assert(gfile.GetContents(dstPath), gfile.GetContents(srcPath)) + }) +} + func Test_Params_File_Batch(t *testing.T) { dstDirPath := gfile.Join(gfile.TempDir(), gtime.TimestampNanoStr()) - err := gfile.Mkdir(dstDirPath) - gtest.Assert(err, nil) - p := ports.PopRand() s := g.Server(p) s.BindHandler("/upload/batch", func(r *ghttp.Request) { @@ -89,7 +120,6 @@ func Test_Params_File_Batch(t *testing.T) { if files == nil { r.Response.WriteExit("upload file cannot be empty") } - if names, err := files.Save(dstDirPath, r.GetBool("randomlyRename")); err == nil { r.Response.WriteExit(gstr.Join(names, ",")) } From 6665d62e7e814f831edb6881f78f448a8206a9d4 Mon Sep 17 00:00:00 2001 From: John Date: Sat, 7 Mar 2020 20:28:00 +0800 Subject: [PATCH 11/26] improve package gfile --- os/gfile/gfile.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/os/gfile/gfile.go b/os/gfile/gfile.go index bc99687dc..d905f09d4 100644 --- a/os/gfile/gfile.go +++ b/os/gfile/gfile.go @@ -50,7 +50,7 @@ func init() { tempDir = os.TempDir() } // Initialize internal package variable: selfPath. - selfPath, _ := exec.LookPath(os.Args[0]) + selfPath, _ = exec.LookPath(os.Args[0]) if selfPath != "" { selfPath, _ = filepath.Abs(selfPath) } From 0e52d467d37ee8c78669aed7dcc6c599bc5d7179 Mon Sep 17 00:00:00 2001 From: John Date: Sun, 8 Mar 2020 00:17:42 +0800 Subject: [PATCH 12/26] improving package gdb --- .example/database/gdb/driver/driver.go | 121 ++++++++++ .example/database/gdb/mysql/gdb_all.go | 6 +- .example/database/gdb/mysql/gdb_value.go | 1 - database/gdb/gdb.go | 177 ++++++-------- database/gdb/{gdb_base.go => gdb_core.go} | 227 +++++++++--------- .../gdb/{gdb_config.go => gdb_core_config.go} | 55 +++-- database/gdb/gdb_core_utility.go | 50 ++++ .../gdb/{gdb_mssql.go => gdb_driver_mssql.go} | 44 ++-- .../gdb/{gdb_mysql.go => gdb_driver_mysql.go} | 44 ++-- .../{gdb_oracle.go => gdb_driver_oracle.go} | 73 +++--- .../gdb/{gdb_pgsql.go => gdb_driver_pgsql.go} | 38 +-- .../{gdb_sqlite.go => gdb_driver_sqlite.go} | 26 +- database/gdb/gdb_func.go | 2 +- database/gdb/gdb_model.go | 66 ++--- database/gdb/gdb_schema.go | 8 +- database/gdb/gdb_structure.go | 6 +- database/gdb/gdb_transaction.go | 30 +-- database/gredis/gredis.go | 5 +- go.mod | 3 +- 19 files changed, 594 insertions(+), 388 deletions(-) create mode 100644 .example/database/gdb/driver/driver.go rename database/gdb/{gdb_base.go => gdb_core.go} (70%) rename database/gdb/{gdb_config.go => gdb_core_config.go} (80%) create mode 100644 database/gdb/gdb_core_utility.go rename database/gdb/{gdb_mssql.go => gdb_driver_mssql.go} (82%) rename database/gdb/{gdb_mysql.go => gdb_driver_mysql.go} (67%) rename database/gdb/{gdb_oracle.go => gdb_driver_oracle.go} (82%) rename database/gdb/{gdb_pgsql.go => gdb_driver_pgsql.go} (74%) rename database/gdb/{gdb_sqlite.go => gdb_driver_sqlite.go} (64%) diff --git a/.example/database/gdb/driver/driver.go b/.example/database/gdb/driver/driver.go new file mode 100644 index 000000000..0d0599c9c --- /dev/null +++ b/.example/database/gdb/driver/driver.go @@ -0,0 +1,121 @@ +// Copyright 2017 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. + +package driver + +import ( + "database/sql" + "fmt" + "github.com/gogf/gf/database/gdb" + "github.com/gogf/gf/internal/intlog" + "github.com/gogf/gf/text/gstr" + + _ "github.com/gf-third/mysql" +) + +type MyDriver struct { + *gdb.Core +} + +// Open creates and returns a underlying sql.DB object for mysql. +func (d *MyDriver) Open(config *gdb.ConfigNode) (*sql.DB, error) { + var source string + if config.LinkInfo != "" { + source = config.LinkInfo + } else { + source = fmt.Sprintf( + "%s:%s@tcp(%s:%s)/%s?charset=%s&multiStatements=true&parseTime=true&loc=Local", + config.User, config.Pass, config.Host, config.Port, config.Name, config.Charset, + ) + } + intlog.Printf("Open: %s", source) + if db, err := sql.Open("gf-mysql", source); err == nil { + return db, nil + } else { + return nil, err + } +} + +// getChars returns the security char for this type of database. +func (d *MyDriver) GetChars() (charLeft string, charRight string) { + return "`", "`" +} + +// handleSqlBeforeExec handles the sql before posts it to database. +func (d *MyDriver) HandleSqlBeforeExec(sql string) string { + return sql +} + +// Tables retrieves and returns the tables of current schema. +func (d *MyDriver) Tables(schema ...string) (tables []string, err error) { + var result gdb.Result + link, err := d.DB.GetSlave(schema...) + if err != nil { + return nil, err + } + result, err = d.DB.DoGetAll(link, `SHOW TABLES`) + if err != nil { + return + } + for _, m := range result { + for _, v := range m { + tables = append(tables, v.String()) + } + } + return +} + +// gdb.TableFields retrieves and returns the fields information of specified table of current schema. +// +// Note that it returns a map containing the field name and its corresponding fields. +// As a map is unsorted, the gdb.TableField struct has a "Index" field marks its sequence in the fields. +// +// It's using cache feature to enhance the performance, which is never expired util the process restarts. +func (d *MyDriver) TableFields(table string, schema ...string) (fields map[string]*gdb.TableField, err error) { + table = gstr.Trim(table) + if gstr.Contains(table, " ") { + panic("function gdb.TableFields supports only single table operations") + } + checkSchema := d.DB.GetSchema() + if len(schema) > 0 && schema[0] != "" { + checkSchema = schema[0] + } + v := d.DB.GetCache().GetOrSetFunc( + fmt.Sprintf(`mysql_table_fields_%s_%s`, table, checkSchema), + func() interface{} { + var result gdb.Result + var link *sql.DB + link, err = d.DB.GetSlave(checkSchema) + if err != nil { + return nil + } + result, err = d.DB.DoGetAll( + link, + fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.DB.QuoteWord(table)), + ) + if err != nil { + return nil + } + fields = make(map[string]*gdb.TableField) + for i, m := range result { + fields[m["Field"].String()] = &gdb.TableField{ + Index: i, + Name: m["Field"].String(), + Type: m["Type"].String(), + Null: m["Null"].Bool(), + Key: m["Key"].String(), + Default: m["Default"].Val(), + Extra: m["Extra"].String(), + Comment: m["Comment"].String(), + } + } + return fields + }, 0) + if err == nil { + fields = v.(map[string]*gdb.TableField) + } + return +} diff --git a/.example/database/gdb/mysql/gdb_all.go b/.example/database/gdb/mysql/gdb_all.go index 681277607..30161552f 100644 --- a/.example/database/gdb/mysql/gdb_all.go +++ b/.example/database/gdb/mysql/gdb_all.go @@ -11,11 +11,11 @@ func main() { // 开启调试模式,以便于记录所有执行的SQL db.SetDebug(true) - r, e := db.Table("test").OrderBy("id asc").All() + r, e := db.Table("test").Order("id asc").All() if e != nil { - panic(e) + fmt.Println(e) } if r != nil { - fmt.Println(r.ToList()) + fmt.Println(r.List()) } } diff --git a/.example/database/gdb/mysql/gdb_value.go b/.example/database/gdb/mysql/gdb_value.go index 400323a16..c0e325189 100644 --- a/.example/database/gdb/mysql/gdb_value.go +++ b/.example/database/gdb/mysql/gdb_value.go @@ -9,5 +9,4 @@ func main() { db.SetDebug(true) db.Table("user").Fields("DISTINCT id,nickname").Filter().All() - } diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index b804ad059..cc3e1b429 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -34,14 +34,14 @@ type DB interface { Prepare(sql string, execOnMaster ...bool) (*sql.Stmt, error) // Internal APIs for CURD, which can be overwrote for custom CURD implements. - doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) - doGetAll(link dbLink, query string, args ...interface{}) (result Result, err error) - doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) - doPrepare(link dbLink, query string) (*sql.Stmt, error) - doInsert(link dbLink, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) - doBatchInsert(link dbLink, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) - doUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) - doDelete(link dbLink, table string, condition string, args ...interface{}) (result sql.Result, err error) + DoQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) + DoGetAll(link dbLink, query string, args ...interface{}) (result Result, err error) + DoExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) + DoPrepare(link dbLink, query string) (*sql.Stmt, error) + DoInsert(link dbLink, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) + DoBatchInsert(link dbLink, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) + DoUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) + DoDelete(link dbLink, table string, condition string, args ...interface{}) (result sql.Result, err error) // Query APIs for convenience purpose. GetAll(query string, args ...interface{}) (Result, error) @@ -52,11 +52,11 @@ type DB interface { GetStructs(objPointerSlice interface{}, query string, args ...interface{}) error GetScan(objPointer interface{}, query string, args ...interface{}) error - // Master/Slave support. + // Master/Slave specification support. Master() (*sql.DB, error) Slave() (*sql.DB, error) - // Ping. + // Ping-Pong. PingMaster() error PingSlave() error @@ -75,48 +75,44 @@ type DB interface { Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) Delete(table string, condition interface{}, args ...interface{}) (sql.Result, error) - // Create model. + // Model creation. From(tables string) *Model Table(tables string) *Model Schema(schema string) *Schema // Configuration methods. + GetCache() *gcache.Cache SetDebug(debug bool) + GetDebug() bool SetSchema(schema string) + GetSchema() string + GetPrefix() string SetLogger(logger *glog.Logger) GetLogger() *glog.Logger SetMaxIdleConnCount(n int) SetMaxOpenConnCount(n int) SetMaxConnLifetime(d time.Duration) + + // Utility methods. + GetChars() (charLeft string, charRight string) + GetMaster(schema ...string) (*sql.DB, error) + GetSlave(schema ...string) (*sql.DB, error) + QuoteWord(s string) string + QuoteString(s string) string + HandleSqlBeforeExec(sql string) string Tables(schema ...string) (tables []string, err error) TableFields(table string, schema ...string) (map[string]*TableField, error) // Internal methods. - getCache() *gcache.Cache - getChars() (charLeft string, charRight string) - getDebug() bool - getPrefix() string - getMaster(schema ...string) (*sql.DB, error) - getSlave(schema ...string) (*sql.DB, error) - quoteWord(s string) string - quoteString(s string) string handleTableName(table string) string filterFields(schema, table string, data map[string]interface{}) map[string]interface{} convertValue(fieldValue []byte, fieldType string) interface{} rowsToResult(rows *sql.Rows) (Result, error) - handleSqlBeforeExec(sql string) string } -// dbLink is a common database function wrapper interface for internal usage. -type dbLink interface { - Query(query string, args ...interface{}) (*sql.Rows, error) - Exec(sql string, args ...interface{}) (sql.Result, error) - Prepare(sql string) (*sql.Stmt, error) -} - -// dbBase is the base struct for database management. -type dbBase struct { - db DB // DB interface object. +// Core is the base struct for database management. +type Core struct { + DB DB // DB interface object. group string // Configuration group name. debug *gtype.Bool // Enable debug mode for the database. cache *gcache.Cache // Cache manager. @@ -128,6 +124,11 @@ type dbBase struct { maxConnLifetime time.Duration // Max TTL for a connection. } +// Driver is the interface for integrating sql drivers into package gdb. +type Driver interface { + New(core *Core, node *ConfigNode) (DB, error) +} + // Sql is the sql recording struct. type Sql struct { Sql string // SQL string(may contain reserved char '?'). @@ -150,6 +151,13 @@ type TableField struct { Comment string // Comment. } +// dbLink is a common database function wrapper interface for internal usage. +type dbLink interface { + Query(query string, args ...interface{}) (*sql.Rows, error) + Exec(sql string, args ...interface{}) (sql.Result, error) + Prepare(sql string) (*sql.Stmt, error) +} + // Value is the field value type. type Value = *gvar.Var @@ -176,10 +184,23 @@ const ( ) var ( - // Instance map. + // instances is the management map for instances. instances = gmap.NewStrAnyMap(true) + // driverMap manages all custom registered driver. + driverMap = map[string]Driver{ + "mysql": &DriverMysql{}, + "mssql": &DriverMssql{}, + "oracle": &DriverOracle{}, + "sqlite": &DriverSqlite{}, + } ) +// Register registers custom database driver to gdb. +func Register(name string, driver Driver) error { + driverMap[name] = driver + return nil +} + // New creates and returns an ORM object with global configurations. // The parameter specifies the configuration group name, // which is DEFAULT_GROUP_NAME in default. @@ -196,31 +217,24 @@ func New(name ...string) (db DB, err error) { } if _, ok := configs.config[group]; ok { if node, err := getConfigNodeByGroup(group, true); err == nil { - base := &dbBase{ - group: group, - debug: gtype.NewBool(), - cache: gcache.New(), - schema: gtype.NewString(), - logger: glog.New(), - prefix: node.Prefix, - // Default max connection life time if user does not configure. - maxConnLifetime: gDEFAULT_CONN_MAX_LIFE_TIME, + c := &Core{ + group: group, + debug: gtype.NewBool(), + cache: gcache.New(), + schema: gtype.NewString(), + logger: glog.New(), + prefix: node.Prefix, + maxConnLifetime: gDEFAULT_CONN_MAX_LIFE_TIME, // Default max connection life time if user does not configure. } - switch node.Type { - case "mysql": - base.db = &dbMysql{dbBase: base} - case "pgsql": - base.db = &dbPgsql{dbBase: base} - case "mssql": - base.db = &dbMssql{dbBase: base} - case "sqlite": - base.db = &dbSqlite{dbBase: base} - case "oracle": - base.db = &dbOracle{dbBase: base} - default: + if v, ok := driverMap[node.Type]; ok { + c.DB, err = v.New(c, node) + if err != nil { + return nil, err + } + return c.DB, nil + } else { return nil, errors.New(fmt.Sprintf(`unsupported database type "%s"`, node.Type)) } - return base.db, nil } else { return nil, err } @@ -321,9 +335,9 @@ func getConfigNodeByWeight(cg ConfigGroup) *ConfigNode { // getSqlDb retrieves and returns a underlying database connection object. // The parameter specifies whether retrieves master node connection if // master-slave nodes are configured. -func (bs *dbBase) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error) { +func (c *Core) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err error) { // Load balance. - node, err := getConfigNodeByGroup(bs.group, master) + node, err := getConfigNodeByGroup(c.group, master) if err != nil { return nil, err } @@ -332,7 +346,7 @@ func (bs *dbBase) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err er node.Charset = "utf8" } // Changes the schema. - nodeSchema := bs.schema.Val() + nodeSchema := c.schema.Val() if len(schema) > 0 && schema[0] != "" { nodeSchema = schema[0] } @@ -343,25 +357,25 @@ func (bs *dbBase) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err er node = &n } // Cache the underlying connection object by node. - v := bs.cache.GetOrSetFuncLock(node.String(), func() interface{} { - sqlDb, err = bs.db.Open(node) + v := c.cache.GetOrSetFuncLock(node.String(), func() interface{} { + sqlDb, err = c.DB.Open(node) if err != nil { return nil } - if bs.maxIdleConnCount > 0 { - sqlDb.SetMaxIdleConns(bs.maxIdleConnCount) + if c.maxIdleConnCount > 0 { + sqlDb.SetMaxIdleConns(c.maxIdleConnCount) } else if node.MaxIdleConnCount > 0 { sqlDb.SetMaxIdleConns(node.MaxIdleConnCount) } - if bs.maxOpenConnCount > 0 { - sqlDb.SetMaxOpenConns(bs.maxOpenConnCount) + if c.maxOpenConnCount > 0 { + sqlDb.SetMaxOpenConns(c.maxOpenConnCount) } else if node.MaxOpenConnCount > 0 { sqlDb.SetMaxOpenConns(node.MaxOpenConnCount) } - if bs.maxConnLifetime > 0 { - sqlDb.SetConnMaxLifetime(bs.maxConnLifetime * time.Second) + if c.maxConnLifetime > 0 { + sqlDb.SetConnMaxLifetime(c.maxConnLifetime * time.Second) } else if node.MaxConnLifetime > 0 { sqlDb.SetConnMaxLifetime(node.MaxConnLifetime * time.Second) } @@ -371,40 +385,7 @@ func (bs *dbBase) getSqlDb(master bool, schema ...string) (sqlDb *sql.DB, err er sqlDb = v.(*sql.DB) } if node.Debug { - bs.db.SetDebug(node.Debug) + c.DB.SetDebug(node.Debug) } return } - -// SetSchema changes the schema for this database connection object. -// Importantly note that when schema configuration changed for the database, -// it affects all operations on the database object in the future. -func (bs *dbBase) SetSchema(schema string) { - bs.schema.Set(schema) -} - -// Master creates and returns a connection from master node if master-slave configured. -// It returns the default connection if master-slave not configured. -func (bs *dbBase) Master() (*sql.DB, error) { - return bs.getSqlDb(true, bs.schema.Val()) -} - -// Slave creates and returns a connection from slave node if master-slave configured. -// It returns the default connection if master-slave not configured. -func (bs *dbBase) Slave() (*sql.DB, error) { - return bs.getSqlDb(false, bs.schema.Val()) -} - -// getMaster acts like function Master but with additional parameter specifying -// the schema for the connection. It is defined for internal usage. -// Also see Master. -func (bs *dbBase) getMaster(schema ...string) (*sql.DB, error) { - return bs.getSqlDb(true, schema...) -} - -// getSlave acts like function Slave but with additional parameter specifying -// the schema for the connection. It is defined for internal usage. -// Also see Slave. -func (bs *dbBase) getSlave(schema ...string) (*sql.DB, error) { - return bs.getSqlDb(false, schema...) -} diff --git a/database/gdb/gdb_base.go b/database/gdb/gdb_core.go similarity index 70% rename from database/gdb/gdb_base.go rename to database/gdb/gdb_core.go index cbefda532..599d12051 100644 --- a/database/gdb/gdb_base.go +++ b/database/gdb/gdb_core.go @@ -16,7 +16,6 @@ import ( "strings" "github.com/gogf/gf/container/gvar" - "github.com/gogf/gf/os/gcache" "github.com/gogf/gf/os/gtime" "github.com/gogf/gf/text/gregex" "github.com/gogf/gf/util/gconv" @@ -32,22 +31,34 @@ var ( lastOperatorReg = regexp.MustCompile(`[<>=]+\s*$`) ) +// Master creates and returns a connection from master node if master-slave configured. +// It returns the default connection if master-slave not configured. +func (c *Core) Master() (*sql.DB, error) { + return c.getSqlDb(true, c.schema.Val()) +} + +// Slave creates and returns a connection from slave node if master-slave configured. +// It returns the default connection if master-slave not configured. +func (c *Core) Slave() (*sql.DB, error) { + return c.getSqlDb(false, c.schema.Val()) +} + // Query commits one query SQL to underlying driver and returns the execution result. // It is most commonly used for data querying. -func (bs *dbBase) Query(query string, args ...interface{}) (rows *sql.Rows, err error) { - link, err := bs.db.Slave() +func (c *Core) Query(query string, args ...interface{}) (rows *sql.Rows, err error) { + link, err := c.DB.Slave() if err != nil { return nil, err } - return bs.db.doQuery(link, query, args...) + return c.DB.DoQuery(link, query, args...) } // doQuery commits the query string and its arguments to underlying driver // through given link object and returns the execution result. -func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) { +func (c *Core) DoQuery(link dbLink, query string, args ...interface{}) (rows *sql.Rows, err error) { query, args = formatQuery(query, args) - query = bs.db.handleSqlBeforeExec(query) - if bs.db.getDebug() { + query = c.DB.HandleSqlBeforeExec(query) + if c.DB.GetDebug() { mTime1 := gtime.TimestampMilli() rows, err = link.Query(query, args...) mTime2 := gtime.TimestampMilli() @@ -59,7 +70,7 @@ func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows Start: mTime1, End: mTime2, } - bs.printSql(s) + c.printSql(s) } else { rows, err = link.Query(query, args...) } @@ -73,20 +84,20 @@ func (bs *dbBase) doQuery(link dbLink, query string, args ...interface{}) (rows // Exec commits one query SQL to underlying driver and returns the execution result. // It is most commonly used for data inserting and updating. -func (bs *dbBase) Exec(query string, args ...interface{}) (result sql.Result, err error) { - link, err := bs.db.Master() +func (c *Core) Exec(query string, args ...interface{}) (result sql.Result, err error) { + link, err := c.DB.Master() if err != nil { return nil, err } - return bs.db.doExec(link, query, args...) + return c.DB.DoExec(link, query, args...) } // doExec commits the query string and its arguments to underlying driver // through given link object and returns the execution result. -func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) { +func (c *Core) DoExec(link dbLink, query string, args ...interface{}) (result sql.Result, err error) { query, args = formatQuery(query, args) - query = bs.db.handleSqlBeforeExec(query) - if bs.db.getDebug() { + query = c.DB.HandleSqlBeforeExec(query) + if c.DB.GetDebug() { mTime1 := gtime.TimestampMilli() result, err = link.Exec(query, args...) mTime2 := gtime.TimestampMilli() @@ -98,7 +109,7 @@ func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result Start: mTime1, End: mTime2, } - bs.printSql(s) + c.printSql(s) } else { result, err = link.Exec(query, args...) } @@ -113,50 +124,50 @@ func (bs *dbBase) doExec(link dbLink, query string, args ...interface{}) (result // // The parameter specifies whether executing the sql on master node, // or else it executes the sql on slave node if master-slave configured. -func (bs *dbBase) Prepare(query string, execOnMaster ...bool) (*sql.Stmt, error) { +func (c *Core) Prepare(query string, execOnMaster ...bool) (*sql.Stmt, error) { err := (error)(nil) link := (dbLink)(nil) if len(execOnMaster) > 0 && execOnMaster[0] { - if link, err = bs.db.Master(); err != nil { + if link, err = c.DB.Master(); err != nil { return nil, err } } else { - if link, err = bs.db.Slave(); err != nil { + if link, err = c.DB.Slave(); err != nil { return nil, err } } - return bs.db.doPrepare(link, query) + return c.DB.DoPrepare(link, query) } // doPrepare calls prepare function on given link object and returns the statement object. -func (bs *dbBase) doPrepare(link dbLink, query string) (*sql.Stmt, error) { +func (c *Core) DoPrepare(link dbLink, query string) (*sql.Stmt, error) { return link.Prepare(query) } // GetAll queries and returns data records from database. -func (bs *dbBase) GetAll(query string, args ...interface{}) (Result, error) { - return bs.db.doGetAll(nil, query, args...) +func (c *Core) GetAll(query string, args ...interface{}) (Result, error) { + return c.DB.DoGetAll(nil, query, args...) } // doGetAll queries and returns data records from database. -func (bs *dbBase) doGetAll(link dbLink, query string, args ...interface{}) (result Result, err error) { +func (c *Core) DoGetAll(link dbLink, query string, args ...interface{}) (result Result, err error) { if link == nil { - link, err = bs.db.Slave() + link, err = c.DB.Slave() if err != nil { return nil, err } } - rows, err := bs.doQuery(link, query, args...) + rows, err := c.DB.DoQuery(link, query, args...) if err != nil || rows == nil { return nil, err } defer rows.Close() - return bs.db.rowsToResult(rows) + return c.DB.rowsToResult(rows) } // GetOne queries and returns one record from database. -func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) { - list, err := bs.GetAll(query, args...) +func (c *Core) GetOne(query string, args ...interface{}) (Record, error) { + list, err := c.DB.GetAll(query, args...) if err != nil { return nil, err } @@ -168,8 +179,8 @@ func (bs *dbBase) GetOne(query string, args ...interface{}) (Record, error) { // GetStruct queries one record from database and converts it to given struct. // The parameter should be a pointer to struct. -func (bs *dbBase) GetStruct(pointer interface{}, query string, args ...interface{}) error { - one, err := bs.GetOne(query, args...) +func (c *Core) GetStruct(pointer interface{}, query string, args ...interface{}) error { + one, err := c.DB.GetOne(query, args...) if err != nil { return err } @@ -181,8 +192,8 @@ func (bs *dbBase) GetStruct(pointer interface{}, query string, args ...interface // GetStructs queries records from database and converts them to given struct. // The parameter should be type of struct slice: []struct/[]*struct. -func (bs *dbBase) GetStructs(pointer interface{}, query string, args ...interface{}) error { - all, err := bs.GetAll(query, args...) +func (c *Core) GetStructs(pointer interface{}, query string, args ...interface{}) error { + all, err := c.DB.GetAll(query, args...) if err != nil { return err } @@ -198,7 +209,7 @@ func (bs *dbBase) GetStructs(pointer interface{}, query string, args ...interfac // If parameter is type of struct pointer, it calls GetStruct internally for // the conversion. If parameter is type of slice, it calls GetStructs internally // for conversion. -func (bs *dbBase) GetScan(pointer interface{}, query string, args ...interface{}) error { +func (c *Core) GetScan(pointer interface{}, query string, args ...interface{}) error { t := reflect.TypeOf(pointer) k := t.Kind() if k != reflect.Ptr { @@ -207,9 +218,9 @@ func (bs *dbBase) GetScan(pointer interface{}, query string, args ...interface{} k = t.Elem().Kind() switch k { case reflect.Array, reflect.Slice: - return bs.db.GetStructs(pointer, query, args...) + return c.DB.GetStructs(pointer, query, args...) case reflect.Struct: - return bs.db.GetStruct(pointer, query, args...) + return c.DB.GetStruct(pointer, query, args...) } return fmt.Errorf("element type should be type of struct/slice, unsupported: %v", k) } @@ -217,8 +228,8 @@ func (bs *dbBase) GetScan(pointer interface{}, query string, args ...interface{} // GetValue queries and returns the field value from database. // The sql should queries only one field from database, or else it returns only one // field of the result. -func (bs *dbBase) GetValue(query string, args ...interface{}) (Value, error) { - one, err := bs.GetOne(query, args...) +func (c *Core) GetValue(query string, args ...interface{}) (Value, error) { + one, err := c.DB.GetOne(query, args...) if err != nil { return nil, err } @@ -229,13 +240,13 @@ func (bs *dbBase) GetValue(query string, args ...interface{}) (Value, error) { } // GetCount queries and returns the count from database. -func (bs *dbBase) GetCount(query string, args ...interface{}) (int, error) { +func (c *Core) GetCount(query string, args ...interface{}) (int, error) { // If the query fields do not contains function "COUNT", // it replaces the query string and adds the "COUNT" function to the fields. if !gregex.IsMatchString(`(?i)SELECT\s+COUNT\(.+\)\s+FROM`, query) { query, _ = gregex.ReplaceString(`(?i)(SELECT)\s+(.+)\s+(FROM)`, `$1 COUNT($2) $3`, query) } - value, err := bs.GetValue(query, args...) + value, err := c.DB.GetValue(query, args...) if err != nil { return 0, err } @@ -243,8 +254,8 @@ func (bs *dbBase) GetCount(query string, args ...interface{}) (int, error) { } // PingMaster pings the master node to check authentication or keeps the connection alive. -func (bs *dbBase) PingMaster() error { - if master, err := bs.db.Master(); err != nil { +func (c *Core) PingMaster() error { + if master, err := c.DB.Master(); err != nil { return err } else { return master.Ping() @@ -252,8 +263,8 @@ func (bs *dbBase) PingMaster() error { } // PingSlave pings the slave node to check authentication or keeps the connection alive. -func (bs *dbBase) PingSlave() error { - if slave, err := bs.db.Slave(); err != nil { +func (c *Core) PingSlave() error { + if slave, err := c.DB.Slave(); err != nil { return err } else { return slave.Ping() @@ -264,13 +275,13 @@ func (bs *dbBase) PingSlave() error { // You should call Commit or Rollback functions of the transaction object // if you no longer use the transaction. Commit or Rollback functions will also // close the transaction automatically. -func (bs *dbBase) Begin() (*TX, error) { - if master, err := bs.db.Master(); err != nil { +func (c *Core) Begin() (*TX, error) { + if master, err := c.DB.Master(); err != nil { return nil, err } else { if tx, err := master.Begin(); err == nil { return &TX{ - db: bs.db, + db: c.DB, tx: tx, master: master, }, nil @@ -289,8 +300,8 @@ func (bs *dbBase) Begin() (*TX, error) { // Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}) // // The parameter specifies the batch operation count when given data is slice. -func (bs *dbBase) Insert(table string, data interface{}, batch ...int) (sql.Result, error) { - return bs.db.doInsert(nil, table, data, gINSERT_OPTION_DEFAULT, batch...) +func (c *Core) Insert(table string, data interface{}, batch ...int) (sql.Result, error) { + return c.DB.DoInsert(nil, table, data, gINSERT_OPTION_DEFAULT, batch...) } // InsertIgnore does "INSERT IGNORE INTO ..." statement for the table. @@ -302,8 +313,8 @@ func (bs *dbBase) Insert(table string, data interface{}, batch ...int) (sql.Resu // Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}) // // The parameter specifies the batch operation count when given data is slice. -func (bs *dbBase) InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) { - return bs.db.doInsert(nil, table, data, gINSERT_OPTION_IGNORE, batch...) +func (c *Core) InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) { + return c.DB.DoInsert(nil, table, data, gINSERT_OPTION_IGNORE, batch...) } // Replace does "REPLACE INTO ..." statement for the table. @@ -318,8 +329,8 @@ func (bs *dbBase) InsertIgnore(table string, data interface{}, batch ...int) (sq // The parameter can be type of map/gmap/struct/*struct/[]map/[]struct, etc. // If given data is type of slice, it then does batch replacing, and the optional parameter // specifies the batch operation count. -func (bs *dbBase) Replace(table string, data interface{}, batch ...int) (sql.Result, error) { - return bs.db.doInsert(nil, table, data, gINSERT_OPTION_REPLACE, batch...) +func (c *Core) Replace(table string, data interface{}, batch ...int) (sql.Result, error) { + return c.DB.DoInsert(nil, table, data, gINSERT_OPTION_REPLACE, batch...) } // Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table. @@ -333,8 +344,8 @@ func (bs *dbBase) Replace(table string, data interface{}, batch ...int) (sql.Res // // If given data is type of slice, it then does batch saving, and the optional parameter // specifies the batch operation count. -func (bs *dbBase) Save(table string, data interface{}, batch ...int) (sql.Result, error) { - return bs.db.doInsert(nil, table, data, gINSERT_OPTION_SAVE, batch...) +func (c *Core) Save(table string, data interface{}, batch ...int) (sql.Result, error) { + return c.DB.DoInsert(nil, table, data, gINSERT_OPTION_SAVE, batch...) } // doInsert inserts or updates data for given table. @@ -344,12 +355,12 @@ func (bs *dbBase) Save(table string, data interface{}, batch ...int) (sql.Result // 1: replace: if there's unique/primary key in the data, it deletes it from table and inserts a new one; // 2: save: if there's unique/primary key in the data, it updates it or else inserts a new one; // 3: ignore: if there's unique/primary key in the data, it ignores the inserting; -func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) { +func (c *Core) DoInsert(link dbLink, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) { var fields []string var values []string var params []interface{} var dataMap Map - table = bs.db.handleTableName(table) + table = c.DB.handleTableName(table) rv := reflect.ValueOf(data) kind := rv.Kind() if kind == reflect.Ptr { @@ -358,7 +369,7 @@ func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option i } switch kind { case reflect.Slice, reflect.Array: - return bs.db.doBatchInsert(link, table, data, option, batch...) + return c.DB.DoBatchInsert(link, table, data, option, batch...) case reflect.Map, reflect.Struct: dataMap = varToMapDeep(data) default: @@ -367,7 +378,7 @@ func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option i if len(dataMap) == 0 { return nil, errors.New("data cannot be empty") } - charL, charR := bs.db.getChars() + charL, charR := c.DB.GetChars() for k, v := range dataMap { fields = append(fields, charL+k+charR) values = append(values, "?") @@ -388,11 +399,11 @@ func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option i updateStr = fmt.Sprintf("ON DUPLICATE KEY UPDATE %s", updateStr) } if link == nil { - if link, err = bs.db.Master(); err != nil { + if link, err = c.DB.Master(); err != nil { return nil, err } } - return bs.db.doExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s", + return c.DB.DoExec(link, fmt.Sprintf("%s INTO %s(%s) VALUES(%s) %s", operation, table, strings.Join(fields, ","), strings.Join(values, ","), updateStr), params...) @@ -400,33 +411,33 @@ func (bs *dbBase) doInsert(link dbLink, table string, data interface{}, option i // BatchInsert batch inserts data. // The parameter must be type of slice of map or struct. -func (bs *dbBase) BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) { - return bs.db.doBatchInsert(nil, table, list, gINSERT_OPTION_DEFAULT, batch...) +func (c *Core) BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) { + return c.DB.DoBatchInsert(nil, table, list, gINSERT_OPTION_DEFAULT, batch...) } // BatchInsert batch inserts data with ignore option. // The parameter must be type of slice of map or struct. -func (bs *dbBase) BatchInsertIgnore(table string, list interface{}, batch ...int) (sql.Result, error) { - return bs.db.doBatchInsert(nil, table, list, gINSERT_OPTION_IGNORE, batch...) +func (c *Core) BatchInsertIgnore(table string, list interface{}, batch ...int) (sql.Result, error) { + return c.DB.DoBatchInsert(nil, table, list, gINSERT_OPTION_IGNORE, batch...) } // BatchReplace batch replaces data. // The parameter must be type of slice of map or struct. -func (bs *dbBase) BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) { - return bs.db.doBatchInsert(nil, table, list, gINSERT_OPTION_REPLACE, batch...) +func (c *Core) BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) { + return c.DB.DoBatchInsert(nil, table, list, gINSERT_OPTION_REPLACE, batch...) } // BatchSave batch replaces data. // The parameter must be type of slice of map or struct. -func (bs *dbBase) BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) { - return bs.db.doBatchInsert(nil, table, list, gINSERT_OPTION_SAVE, batch...) +func (c *Core) BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) { + return c.DB.DoBatchInsert(nil, table, list, gINSERT_OPTION_SAVE, batch...) } // doBatchInsert batch inserts/replaces/saves data. -func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) { +func (c *Core) DoBatchInsert(link dbLink, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) { var keys, values []string var params []interface{} - table = bs.db.handleTableName(table) + table = c.DB.handleTableName(table) listMap := (List)(nil) switch v := list.(type) { case Result: @@ -461,7 +472,7 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt return result, errors.New("data list cannot be empty") } if link == nil { - if link, err = bs.db.Master(); err != nil { + if link, err = c.DB.Master(); err != nil { return } } @@ -473,7 +484,7 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt } // Prepare the result pointer. batchResult := new(batchSqlResult) - charL, charR := bs.db.getChars() + charL, charR := c.DB.GetChars() keysStr := charL + strings.Join(keys, charR+","+charL) + charR valueHolderStr := "(" + strings.Join(holders, ",") + ")" @@ -504,7 +515,7 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt } values = append(values, valueHolderStr) if len(values) == batchNum || (i == listMapLen-1 && len(values) > 0) { - r, err := bs.db.doExec( + r, err := c.DB.DoExec( link, fmt.Sprintf( "%s INTO %s(%s) VALUES%s %s", @@ -546,18 +557,18 @@ func (bs *dbBase) doBatchInsert(link dbLink, table string, list interface{}, opt // "status IN (?)", g.Slice{1,2,3} // "age IN(?,?)", 18, 50 // User{ Id : 1, UserName : "john"} -func (bs *dbBase) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { - newWhere, newArgs := formatWhere(bs.db, condition, args, false) +func (c *Core) Update(table string, data interface{}, condition interface{}, args ...interface{}) (sql.Result, error) { + newWhere, newArgs := formatWhere(c.DB, condition, args, false) if newWhere != "" { newWhere = " WHERE " + newWhere } - return bs.db.doUpdate(nil, table, data, newWhere, newArgs...) + return c.DB.DoUpdate(nil, table, data, newWhere, newArgs...) } // doUpdate does "UPDATE ... " statement for the table. // Also see Update. -func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) { - table = bs.db.handleTableName(table) +func (c *Core) DoUpdate(link dbLink, table string, data interface{}, condition string, args ...interface{}) (result sql.Result, err error) { + table = c.DB.handleTableName(table) updates := "" rv := reflect.ValueOf(data) kind := rv.Kind() @@ -570,7 +581,7 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio case reflect.Map, reflect.Struct: var fields []string for k, v := range varToMapDeep(data) { - fields = append(fields, bs.db.quoteWord(k)+"=?") + fields = append(fields, c.DB.QuoteWord(k)+"=?") params = append(params, v) } updates = strings.Join(fields, ",") @@ -585,11 +596,11 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio } // If no link passed, it then uses the master link. if link == nil { - if link, err = bs.db.Master(); err != nil { + if link, err = c.DB.Master(); err != nil { return nil, err } } - return bs.db.doExec(link, fmt.Sprintf("UPDATE %s SET %s%s", table, updates, condition), args...) + return c.DB.DoExec(link, fmt.Sprintf("UPDATE %s SET %s%s", table, updates, condition), args...) } // Delete does "DELETE FROM ... " statement for the table. @@ -603,38 +614,28 @@ func (bs *dbBase) doUpdate(link dbLink, table string, data interface{}, conditio // "status IN (?)", g.Slice{1,2,3} // "age IN(?,?)", 18, 50 // User{ Id : 1, UserName : "john"} -func (bs *dbBase) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) { - newWhere, newArgs := formatWhere(bs.db, condition, args, false) +func (c *Core) Delete(table string, condition interface{}, args ...interface{}) (result sql.Result, err error) { + newWhere, newArgs := formatWhere(c.DB, condition, args, false) if newWhere != "" { newWhere = " WHERE " + newWhere } - return bs.db.doDelete(nil, table, newWhere, newArgs...) + return c.DB.DoDelete(nil, table, newWhere, newArgs...) } // doDelete does "DELETE FROM ... " statement for the table. // Also see Delete. -func (bs *dbBase) doDelete(link dbLink, table string, condition string, args ...interface{}) (result sql.Result, err error) { +func (c *Core) DoDelete(link dbLink, table string, condition string, args ...interface{}) (result sql.Result, err error) { if link == nil { - if link, err = bs.db.Master(); err != nil { + if link, err = c.DB.Master(); err != nil { return nil, err } } - table = bs.db.handleTableName(table) - return bs.db.doExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...) -} - -// getCache returns the internal cache object. -func (bs *dbBase) getCache() *gcache.Cache { - return bs.cache -} - -// getPrefix returns the table prefix string configured. -func (bs *dbBase) getPrefix() string { - return bs.prefix + table = c.DB.handleTableName(table) + return c.DB.DoExec(link, fmt.Sprintf("DELETE FROM %s%s", table, condition), args...) } // rowsToResult converts underlying data record type sql.Rows to Result type. -func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) { +func (c *Core) rowsToResult(rows *sql.Rows) (Result, error) { if !rows.Next() { return nil, nil } @@ -671,7 +672,7 @@ func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) { // it should do a copy of it. v := make([]byte, len(value)) copy(v, value) - row[columnNames[i]] = gvar.New(bs.db.convertValue(v, columnTypes[i])) + row[columnNames[i]] = gvar.New(c.DB.convertValue(v, columnTypes[i])) } } records = append(records, row) @@ -687,34 +688,20 @@ func (bs *dbBase) rowsToResult(rows *sql.Rows) (Result, error) { // // Note that, this will automatically checks the table prefix whether already added, if true it does // nothing to the table name, or else adds the prefix to the table name. -func (bs *dbBase) handleTableName(table string) string { - charLeft, charRight := bs.db.getChars() - prefix := bs.db.getPrefix() +func (c *Core) handleTableName(table string) string { + charLeft, charRight := c.DB.GetChars() + prefix := c.DB.GetPrefix() return doHandleTableName(table, prefix, charLeft, charRight) } -// quoteWord checks given string a word, if true quotes it with security chars of the database -// and returns the quoted string; or else return without any change. -func (bs *dbBase) quoteWord(s string) string { - charLeft, charRight := bs.db.getChars() - return doQuoteWord(s, charLeft, charRight) -} - -// quoteString quotes string with quote chars. Strings like: -// "user", "user u", "user,user_detail", "user u, user_detail ut", "u.id asc". -func (bs *dbBase) quoteString(s string) string { - charLeft, charRight := bs.db.getChars() - return doQuoteString(s, charLeft, charRight) -} - // printSql outputs the sql object to logger. // It is enabled when configuration "debug" is true. -func (bs *dbBase) printSql(v *Sql) { +func (c *Core) printSql(v *Sql) { s := fmt.Sprintf("[%d ms] %s", v.End-v.Start, v.Format) if v.Error != nil { s += "\nError: " + v.Error.Error() - bs.logger.StackWithFilter(gPATH_FILTER_KEY).Error(s) + c.logger.StackWithFilter(gPATH_FILTER_KEY).Error(s) } else { - bs.logger.StackWithFilter(gPATH_FILTER_KEY).Debug(s) + c.logger.StackWithFilter(gPATH_FILTER_KEY).Debug(s) } } diff --git a/database/gdb/gdb_config.go b/database/gdb/gdb_core_config.go similarity index 80% rename from database/gdb/gdb_config.go rename to database/gdb/gdb_core_config.go index 74523b220..46e52ef54 100644 --- a/database/gdb/gdb_config.go +++ b/database/gdb/gdb_core_config.go @@ -8,6 +8,7 @@ package gdb import ( "fmt" + "github.com/gogf/gf/os/gcache" "sync" "time" @@ -114,29 +115,29 @@ func GetDefaultGroup() string { } // SetLogger sets the logger for orm. -func (bs *dbBase) SetLogger(logger *glog.Logger) { - bs.logger = logger +func (c *Core) SetLogger(logger *glog.Logger) { + c.logger = logger } // GetLogger returns the logger of the orm. -func (bs *dbBase) GetLogger() *glog.Logger { - return bs.logger +func (c *Core) GetLogger() *glog.Logger { + return c.logger } // SetMaxIdleConnCount sets the max idle connection count for underlying connection pool. -func (bs *dbBase) SetMaxIdleConnCount(n int) { - bs.maxIdleConnCount = n +func (c *Core) SetMaxIdleConnCount(n int) { + c.maxIdleConnCount = n } // SetMaxOpenConnCount sets the max open connection count for underlying connection pool. -func (bs *dbBase) SetMaxOpenConnCount(n int) { - bs.maxOpenConnCount = n +func (c *Core) SetMaxOpenConnCount(n int) { + c.maxOpenConnCount = n } // SetMaxConnLifetime sets the connection TTL for underlying connection pool. // If parameter <= 0, it means the connection never expires. -func (bs *dbBase) SetMaxConnLifetime(d time.Duration) { - bs.maxConnLifetime = d +func (c *Core) SetMaxConnLifetime(d time.Duration) { + c.maxConnLifetime = d } // String returns the node as string. @@ -155,14 +156,36 @@ func (node *ConfigNode) String() string { } // SetDebug enables/disables the debug mode. -func (bs *dbBase) SetDebug(debug bool) { - if bs.debug.Val() == debug { +func (c *Core) SetDebug(debug bool) { + if c.debug.Val() == debug { return } - bs.debug.Set(debug) + c.debug.Set(debug) } -// getDebug returns the debug value. -func (bs *dbBase) getDebug() bool { - return bs.debug.Val() +// GetDebug returns the debug value. +func (c *Core) GetDebug() bool { + return c.debug.Val() +} + +// GetCache returns the internal cache object. +func (c *Core) GetCache() *gcache.Cache { + return c.cache +} + +// GetPrefix returns the table prefix string configured. +func (c *Core) GetPrefix() string { + return c.prefix +} + +// SetSchema changes the schema for this database connection object. +// Importantly note that when schema configuration changed for the database, +// it affects all operations on the database object in the future. +func (c *Core) SetSchema(schema string) { + c.schema.Set(schema) +} + +// GetSchema returns the schema configured. +func (c *Core) GetSchema() string { + return c.schema.Val() } diff --git a/database/gdb/gdb_core_utility.go b/database/gdb/gdb_core_utility.go new file mode 100644 index 000000000..b024fe28d --- /dev/null +++ b/database/gdb/gdb_core_utility.go @@ -0,0 +1,50 @@ +// Copyright 2019 gf Author(https://github.com/gogf/gf). All Rights Reserved. +// +// This Source Code Form is subject to the terms of the MIT License. +// If a copy of the MIT was not distributed with this file, +// You can obtain one at https://github.com/gogf/gf. +// + +package gdb + +import "database/sql" + +// GetMaster acts like function Master but with additional parameter specifying +// the schema for the connection. It is defined for internal usage. +// Also see Master. +func (c *Core) GetMaster(schema ...string) (*sql.DB, error) { + return c.getSqlDb(true, schema...) +} + +// GetSlave acts like function Slave but with additional parameter specifying +// the schema for the connection. It is defined for internal usage. +// Also see Slave. +func (c *Core) GetSlave(schema ...string) (*sql.DB, error) { + return c.getSqlDb(false, schema...) +} + +// QuoteWord checks given string a word, if true quotes it with security chars of the database +// and returns the quoted string; or else return without any change. +func (c *Core) QuoteWord(s string) string { + charLeft, charRight := c.DB.GetChars() + return doQuoteWord(s, charLeft, charRight) +} + +// QuoteString quotes string with quote chars. Strings like: +// "user", "user u", "user,user_detail", "user u, user_detail ut", "u.id asc". +func (c *Core) QuoteString(s string) string { + charLeft, charRight := c.DB.GetChars() + return doQuoteString(s, charLeft, charRight) +} + +// GetChars returns the security char for current database. +// It does nothing in default. +func (c *Core) GetChars() (charLeft string, charRight string) { + return "", "" +} + +// HandleSqlBeforeExec handles the sql before posts it to database. +// It does nothing in default. +func (c *Core) HandleSqlBeforeExec(sql string) string { + return sql +} diff --git a/database/gdb/gdb_mssql.go b/database/gdb/gdb_driver_mssql.go similarity index 82% rename from database/gdb/gdb_mssql.go rename to database/gdb/gdb_driver_mssql.go index f3d27bb25..441f65da8 100644 --- a/database/gdb/gdb_mssql.go +++ b/database/gdb/gdb_driver_mssql.go @@ -5,7 +5,7 @@ // You can obtain one at https://github.com/gogf/gf. // // Note: -// 1. It needs manually import: _ "github.com/lib/pq" +// 1. It needs manually import: _ "github.com/denisenkom/go-mssqldb" // 2. It does not support Save/Replace features. // 3. It does not support LastInsertId. @@ -22,12 +22,20 @@ import ( "github.com/gogf/gf/text/gregex" ) -type dbMssql struct { - *dbBase +// DriverMssql is the driver for SQL server database. +type DriverMssql struct { + *Core +} + +// New creates and returns a database object for SQL server. +func (d *DriverMssql) New(core *Core, node *ConfigNode) (DB, error) { + return &DriverMssql{ + Core: core, + }, nil } // Open creates and returns a underlying sql.DB object for mssql. -func (db *dbMssql) Open(config *ConfigNode) (*sql.DB, error) { +func (d *DriverMssql) Open(config *ConfigNode) (*sql.DB, error) { source := "" if config.LinkInfo != "" { source = config.LinkInfo @@ -45,13 +53,13 @@ func (db *dbMssql) Open(config *ConfigNode) (*sql.DB, error) { } } -// getChars returns the security char for this type of database. -func (db *dbMssql) getChars() (charLeft string, charRight string) { +// GetChars returns the security char for this type of database. +func (d *DriverMssql) GetChars() (charLeft string, charRight string) { return "\"", "\"" } -// handleSqlBeforeExec deals with the sql string before commits it to underlying sql driver. -func (db *dbMssql) handleSqlBeforeExec(query string) string { +// HandleSqlBeforeExec deals with the sql string before commits it to underlying sql driver. +func (d *DriverMssql) HandleSqlBeforeExec(query string) string { var index int // Convert place holder char '?' to string "@px". str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string { @@ -59,10 +67,10 @@ func (db *dbMssql) handleSqlBeforeExec(query string) string { return fmt.Sprintf("@p%d", index) }) str, _ = gregex.ReplaceString("\"", "", str) - return db.parseSql(str) + return d.parseSql(str) } -func (db *dbMssql) parseSql(sql string) string { +func (d *DriverMssql) parseSql(sql string) string { // SELECT * FROM USER WHERE ID=1 LIMIT 1 if m, _ := gregex.MatchString(`^SELECT(.+)LIMIT 1$`, sql); len(m) > 1 { return fmt.Sprintf(`SELECT TOP 1 %s`, m[1]) @@ -163,14 +171,14 @@ func (db *dbMssql) parseSql(sql string) string { } // Tables retrieves and returns the tables of current schema. -func (db *dbMssql) Tables(schema ...string) (tables []string, err error) { +func (d *DriverMssql) Tables(schema ...string) (tables []string, err error) { var result Result - link, err := db.getSlave(schema...) + link, err := d.DB.GetSlave(schema...) if err != nil { return nil, err } - result, err = db.doGetAll(link, `SELECT NAME FROM SYSOBJECTS WHERE XTYPE='U' AND STATUS >= 0 ORDER BY NAME`) + result, err = d.DB.DoGetAll(link, `SELECT NAME FROM SYSOBJECTS WHERE XTYPE='U' AND STATUS >= 0 ORDER BY NAME`) if err != nil { return } @@ -183,24 +191,24 @@ func (db *dbMssql) Tables(schema ...string) (tables []string, err error) { } // TableFields retrieves and returns the fields information of specified table of current schema. -func (db *dbMssql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverMssql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { table = gstr.Trim(table) if gstr.Contains(table, " ") { panic("function TableFields supports only single table operations") } - checkSchema := db.schema.Val() + checkSchema := d.DB.GetSchema() if len(schema) > 0 && schema[0] != "" { checkSchema = schema[0] } - v := db.cache.GetOrSetFunc( + v := d.DB.GetCache().GetOrSetFunc( fmt.Sprintf(`mssql_table_fields_%s_%s`, table, checkSchema), func() interface{} { var result Result var link *sql.DB - link, err = db.getSlave(checkSchema) + link, err = d.DB.GetSlave(checkSchema) if err != nil { return nil } - result, err = db.doGetAll(link, fmt.Sprintf(` + result, err = d.DB.DoGetAll(link, fmt.Sprintf(` SELECT c.name as FIELD, CASE t.name WHEN 'numeric' THEN t.name + '(' + convert(varchar(20),c.xprec) + ',' + convert(varchar(20),c.xscale) + ')' WHEN 'char' THEN t.name + '(' + convert(varchar(20),c.length)+ ')' diff --git a/database/gdb/gdb_mysql.go b/database/gdb/gdb_driver_mysql.go similarity index 67% rename from database/gdb/gdb_mysql.go rename to database/gdb/gdb_driver_mysql.go index f4cfd45db..de5f0b44c 100644 --- a/database/gdb/gdb_mysql.go +++ b/database/gdb/gdb_driver_mysql.go @@ -12,15 +12,23 @@ import ( "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/text/gstr" - _ "github.com/gf-third/mysql" + _ "github.com/go-sql-driver/mysql" ) -type dbMysql struct { - *dbBase +// DriverMysql is the driver for mysql database. +type DriverMysql struct { + *Core +} + +// New creates and returns a database object for mysql. +func (d *DriverMysql) New(core *Core, node *ConfigNode) (DB, error) { + return &DriverMysql{ + Core: core, + }, nil } // Open creates and returns a underlying sql.DB object for mysql. -func (db *dbMysql) Open(config *ConfigNode) (*sql.DB, error) { +func (d *DriverMysql) Open(config *ConfigNode) (*sql.DB, error) { var source string if config.LinkInfo != "" { source = config.LinkInfo @@ -31,31 +39,31 @@ func (db *dbMysql) Open(config *ConfigNode) (*sql.DB, error) { ) } intlog.Printf("Open: %s", source) - if db, err := sql.Open("gf-mysql", source); err == nil { + if db, err := sql.Open("mysql", source); err == nil { return db, nil } else { return nil, err } } -// getChars returns the security char for this type of database. -func (db *dbMysql) getChars() (charLeft string, charRight string) { +// GetChars returns the security char for this type of database. +func (d *DriverMysql) GetChars() (charLeft string, charRight string) { return "`", "`" } -// handleSqlBeforeExec handles the sql before posts it to database. -func (db *dbMysql) handleSqlBeforeExec(sql string) string { +// HandleSqlBeforeExec handles the sql before posts it to database. +func (d *DriverMysql) HandleSqlBeforeExec(sql string) string { return sql } // Tables retrieves and returns the tables of current schema. -func (bs *dbBase) Tables(schema ...string) (tables []string, err error) { +func (d *DriverMysql) Tables(schema ...string) (tables []string, err error) { var result Result - link, err := bs.db.getSlave(schema...) + link, err := d.DB.GetSlave(schema...) if err != nil { return nil, err } - result, err = bs.db.doGetAll(link, `SHOW TABLES`) + result, err = d.DB.DoGetAll(link, `SHOW TABLES`) if err != nil { return } @@ -73,27 +81,27 @@ func (bs *dbBase) Tables(schema ...string) (tables []string, err error) { // As a map is unsorted, the TableField struct has a "Index" field marks its sequence in the fields. // // It's using cache feature to enhance the performance, which is never expired util the process restarts. -func (bs *dbBase) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverMysql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { table = gstr.Trim(table) if gstr.Contains(table, " ") { panic("function TableFields supports only single table operations") } - checkSchema := bs.schema.Val() + checkSchema := d.schema.Val() if len(schema) > 0 && schema[0] != "" { checkSchema = schema[0] } - v := bs.cache.GetOrSetFunc( + v := d.cache.GetOrSetFunc( fmt.Sprintf(`mysql_table_fields_%s_%s`, table, checkSchema), func() interface{} { var result Result var link *sql.DB - link, err = bs.db.getSlave(checkSchema) + link, err = d.DB.GetSlave(checkSchema) if err != nil { return nil } - result, err = bs.doGetAll( + result, err = d.DB.DoGetAll( link, - fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, bs.db.quoteWord(table)), + fmt.Sprintf(`SHOW FULL COLUMNS FROM %s`, d.DB.QuoteWord(table)), ) if err != nil { return nil diff --git a/database/gdb/gdb_oracle.go b/database/gdb/gdb_driver_oracle.go similarity index 82% rename from database/gdb/gdb_oracle.go rename to database/gdb/gdb_driver_oracle.go index 686dce1fa..de8936322 100644 --- a/database/gdb/gdb_oracle.go +++ b/database/gdb/gdb_driver_oracle.go @@ -24,8 +24,9 @@ import ( "github.com/gogf/gf/text/gregex" ) -type dbOracle struct { - *dbBase +// DriverOracle is the driver for oracle database. +type DriverOracle struct { + *Core } const ( @@ -33,8 +34,15 @@ const ( tableAlias2 = "GFORM2" ) +// New creates and returns a database object for oracle. +func (d *DriverOracle) New(core *Core, node *ConfigNode) (DB, error) { + return &DriverOracle{ + Core: core, + }, nil +} + // Open creates and returns a underlying sql.DB object for oracle. -func (db *dbOracle) Open(config *ConfigNode) (*sql.DB, error) { +func (d *DriverOracle) Open(config *ConfigNode) (*sql.DB, error) { var source string if config.LinkInfo != "" { source = config.LinkInfo @@ -49,13 +57,13 @@ func (db *dbOracle) Open(config *ConfigNode) (*sql.DB, error) { } } -// getChars returns the security char for this type of database. -func (db *dbOracle) getChars() (charLeft string, charRight string) { +// GetChars returns the security char for this type of database. +func (d *DriverOracle) GetChars() (charLeft string, charRight string) { return "\"", "\"" } -// handleSqlBeforeExec deals with the sql string before commits it to underlying sql driver. -func (db *dbOracle) handleSqlBeforeExec(query string) string { +// HandleSqlBeforeExec deals with the sql string before commits it to underlying sql driver. +func (d *DriverOracle) HandleSqlBeforeExec(query string) string { var index int // Convert place holder char '?' to string ":x". str, _ := gregex.ReplaceStringFunc("\\?", query, func(s string) string { @@ -63,10 +71,10 @@ func (db *dbOracle) handleSqlBeforeExec(query string) string { return fmt.Sprintf(":%d", index) }) str, _ = gregex.ReplaceString("\"", "", str) - return db.parseSql(str) + return d.parseSql(str) } -func (db *dbOracle) parseSql(sql string) string { +func (d *DriverOracle) parseSql(sql string) string { patten := `^\s*(?i)(SELECT)|(LIMIT\s*(\d+)\s*,\s*(\d+))` if gregex.IsMatchString(patten, sql) == false { return sql @@ -124,9 +132,9 @@ func (db *dbOracle) parseSql(sql string) string { // Tables retrieves and returns the tables of current schema. // Note that it ignores the parameter in oracle database, as it is not necessary. -func (db *dbOracle) Tables(schema ...string) (tables []string, err error) { +func (d *DriverOracle) Tables(schema ...string) (tables []string, err error) { var result Result - result, err = db.doGetAll(nil, "SELECT TABLE_NAME FROM USER_TABLES ORDER BY TABLE_NAME") + result, err = d.DB.DoGetAll(nil, "SELECT TABLE_NAME FROM USER_TABLES ORDER BY TABLE_NAME") if err != nil { return } @@ -139,20 +147,20 @@ func (db *dbOracle) Tables(schema ...string) (tables []string, err error) { } // TableFields retrieves and returns the fields information of specified table of current schema. -func (db *dbOracle) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverOracle) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { table = gstr.Trim(table) if gstr.Contains(table, " ") { panic("function TableFields supports only single table operations") } - checkSchema := db.schema.Val() + checkSchema := d.DB.GetSchema() if len(schema) > 0 && schema[0] != "" { checkSchema = schema[0] } - v := db.cache.GetOrSetFunc( + v := d.DB.GetCache().GetOrSetFunc( fmt.Sprintf(`oracle_table_fields_%s_%s`, table, checkSchema), func() interface{} { result := (Result)(nil) - result, err = db.GetAll(fmt.Sprintf(` + result, err = d.DB.GetAll(fmt.Sprintf(` SELECT COLUMN_NAME AS FIELD, CASE DATA_TYPE WHEN 'NUMBER' THEN DATA_TYPE||'('||DATA_PRECISION||','||DATA_SCALE||')' WHEN 'FLOAT' THEN DATA_TYPE||'('||DATA_PRECISION||','||DATA_SCALE||')' @@ -177,11 +185,11 @@ func (db *dbOracle) TableFields(table string, schema ...string) (fields map[stri return } -func (db *dbOracle) getTableUniqueIndex(table string) (fields map[string]map[string]string, err error) { +func (d *DriverOracle) getTableUniqueIndex(table string) (fields map[string]map[string]string, err error) { table = strings.ToUpper(table) - v := db.cache.GetOrSetFunc("table_unique_index_"+table, func() interface{} { + v := d.DB.GetCache().GetOrSetFunc("table_unique_index_"+table, func() interface{} { res := (Result)(nil) - res, err = db.GetAll(fmt.Sprintf(` + res, err = d.DB.GetAll(fmt.Sprintf(` SELECT INDEX_NAME,COLUMN_NAME,CHAR_LENGTH FROM USER_IND_COLUMNS WHERE TABLE_NAME = '%s' AND INDEX_NAME IN(SELECT INDEX_NAME FROM USER_INDEXES WHERE TABLE_NAME='%s' AND UNIQUENESS='UNIQUE') @@ -203,7 +211,7 @@ func (db *dbOracle) getTableUniqueIndex(table string) (fields map[string]map[str return } -func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) { +func (d *DriverOracle) DoInsert(link dbLink, table string, data interface{}, option int, batch ...int) (result sql.Result, err error) { var fields []string var values []string var params []interface{} @@ -218,7 +226,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option case reflect.Slice: fallthrough case reflect.Array: - return db.db.doBatchInsert(link, table, data, option, batch...) + return d.DB.DoBatchInsert(link, table, data, option, batch...) case reflect.Map: fallthrough case reflect.Struct: @@ -231,7 +239,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option indexMap := make(map[string]string) indexExists := false if option != gINSERT_OPTION_DEFAULT { - index, err := db.getTableUniqueIndex(table) + index, err := d.getTableUniqueIndex(table) if err != nil { return nil, err } @@ -253,7 +261,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option onStr := make([]string, 0) updateStr := make([]string, 0) - charL, charR := db.db.getChars() + charL, charR := d.DB.GetChars() for k, v := range dataMap { k = strings.ToUpper(k) @@ -279,7 +287,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option } if link == nil { - if link, err = db.db.Master(); err != nil { + if link, err = d.DB.Master(); err != nil { return nil, err } } @@ -294,9 +302,9 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option table, tableAlias1, strings.Join(subSqlStr, ","), tableAlias2, strings.Join(onStr, "AND"), strings.Join(updateStr, ","), strings.Join(fields, ","), strings.Join(values, ","), ) - return db.db.doExec(link, tmp, params...) + return d.DB.DoExec(link, tmp, params...) case gINSERT_OPTION_IGNORE: - return db.db.doExec(link, + return d.DB.DoExec(link, fmt.Sprintf( "INSERT /*+ IGNORE_ROW_ON_DUPKEY_INDEX(%s(%s)) */ INTO %s(%s) VALUES(%s)", table, strings.Join(indexs, ","), table, strings.Join(fields, ","), strings.Join(values, ","), @@ -305,7 +313,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option } } - return db.db.doExec( + return d.DB.DoExec( link, fmt.Sprintf( "INSERT INTO %s(%s) VALUES(%s)", @@ -314,7 +322,7 @@ func (db *dbOracle) doInsert(link dbLink, table string, data interface{}, option params...) } -func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) { +func (d *DriverOracle) DoBatchInsert(link dbLink, table string, list interface{}, option int, batch ...int) (result sql.Result, err error) { var keys []string var values []string var params []interface{} @@ -357,7 +365,7 @@ func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, o return result, errors.New("empty data list") } if link == nil { - if link, err = db.db.Master(); err != nil { + if link, err = d.DB.Master(); err != nil { return } } @@ -368,14 +376,14 @@ func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, o holders = append(holders, "?") } batchResult := new(batchSqlResult) - charL, charR := db.db.getChars() + charL, charR := d.DB.GetChars() keyStr := charL + strings.Join(keys, charL+","+charR) + charR valueHolderStr := strings.Join(holders, ",") // 当操作类型非insert时调用单笔的insert功能 if option != gINSERT_OPTION_DEFAULT { for _, v := range listMap { - r, err := db.doInsert(link, table, v, option, 1) + r, err := d.DB.DoInsert(link, table, v, option, 1) if err != nil { return r, err } @@ -402,10 +410,9 @@ func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, o params = append(params, listMap[i][k]) } values = append(values, valueHolderStr) - intoStr = append(intoStr, fmt.Sprintf(" INTO %s(%s) VALUES(%s) ", table, keyStr, valueHolderStr)) if len(intoStr) == batchNum { - r, err := db.db.doExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...) + r, err := d.DB.DoExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...) if err != nil { return r, err } @@ -421,7 +428,7 @@ func (db *dbOracle) doBatchInsert(link dbLink, table string, list interface{}, o } // 处理最后不构成指定批量的数据 if len(intoStr) > 0 { - r, err := db.db.doExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...) + r, err := d.DB.DoExec(link, fmt.Sprintf("INSERT ALL %s SELECT * FROM DUAL", strings.Join(intoStr, " ")), params...) if err != nil { return r, err } diff --git a/database/gdb/gdb_pgsql.go b/database/gdb/gdb_driver_pgsql.go similarity index 74% rename from database/gdb/gdb_pgsql.go rename to database/gdb/gdb_driver_pgsql.go index 24b9b3e38..61e15094f 100644 --- a/database/gdb/gdb_pgsql.go +++ b/database/gdb/gdb_driver_pgsql.go @@ -21,12 +21,20 @@ import ( "github.com/gogf/gf/text/gregex" ) -type dbPgsql struct { - *dbBase +// DriverPgsql is the driver for postgresql database. +type DriverPgsql struct { + *Core +} + +// New creates and returns a database object for postgresql. +func (d *DriverPgsql) New(core *Core, node *ConfigNode) (DB, error) { + return &DriverPgsql{ + Core: core, + }, nil } // Open creates and returns a underlying sql.DB object for pgsql. -func (db *dbPgsql) Open(config *ConfigNode) (*sql.DB, error) { +func (d *DriverPgsql) Open(config *ConfigNode) (*sql.DB, error) { var source string if config.LinkInfo != "" { source = config.LinkInfo @@ -44,13 +52,13 @@ func (db *dbPgsql) Open(config *ConfigNode) (*sql.DB, error) { } } -// getChars returns the security char for this type of database. -func (db *dbPgsql) getChars() (charLeft string, charRight string) { +// GetChars returns the security char for this type of database. +func (d *DriverPgsql) GetChars() (charLeft string, charRight string) { return "\"", "\"" } -// handleSqlBeforeExec deals with the sql string before commits it to underlying sql driver. -func (db *dbPgsql) handleSqlBeforeExec(sql string) string { +// HandleSqlBeforeExec deals with the sql string before commits it to underlying sql driver. +func (d *DriverPgsql) HandleSqlBeforeExec(sql string) string { var index int // Convert place holder char '?' to string "$x". sql, _ = gregex.ReplaceStringFunc("\\?", sql, func(s string) string { @@ -62,9 +70,9 @@ func (db *dbPgsql) handleSqlBeforeExec(sql string) string { } // Tables retrieves and returns the tables of current schema. -func (db *dbPgsql) Tables(schema ...string) (tables []string, err error) { +func (d *DriverPgsql) Tables(schema ...string) (tables []string, err error) { var result Result - link, err := db.getSlave(schema...) + link, err := d.DB.GetSlave(schema...) if err != nil { return nil, err } @@ -73,7 +81,7 @@ func (db *dbPgsql) Tables(schema ...string) (tables []string, err error) { if len(schema) > 0 && schema[0] != "" { query = fmt.Sprintf("SELECT TABLENAME FROM PG_TABLES WHERE SCHEMANAME = '%s' ORDER BY TABLENAME", schema[0]) } - result, err = db.doGetAll(link, query) + result, err = d.DB.DoGetAll(link, query) if err != nil { return } @@ -86,25 +94,25 @@ func (db *dbPgsql) Tables(schema ...string) (tables []string, err error) { } // TableFields retrieves and returns the fields information of specified table of current schema. -func (db *dbPgsql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverPgsql) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { table = gstr.Trim(table) if gstr.Contains(table, " ") { panic("function TableFields supports only single table operations") } table, _ = gregex.ReplaceString("\"", "", table) - checkSchema := db.schema.Val() + checkSchema := d.DB.GetSchema() if len(schema) > 0 && schema[0] != "" { checkSchema = schema[0] } - v := db.cache.GetOrSetFunc( + v := d.DB.GetCache().GetOrSetFunc( fmt.Sprintf(`pgsql_table_fields_%s_%s`, table, checkSchema), func() interface{} { var result Result var link *sql.DB - link, err = db.getSlave(checkSchema) + link, err = d.DB.GetSlave(checkSchema) if err != nil { return nil } - result, err = db.doGetAll(link, fmt.Sprintf(` + result, err = d.DB.DoGetAll(link, fmt.Sprintf(` SELECT a.attname AS field, t.typname AS type FROM pg_class c, pg_attribute a LEFT OUTER JOIN pg_description b ON a.attrelid=b.objoid AND a.attnum = b.objsubid,pg_type t WHERE c.relname = '%s' and a.attnum > 0 and a.attrelid = c.oid and a.atttypid = t.oid diff --git a/database/gdb/gdb_sqlite.go b/database/gdb/gdb_driver_sqlite.go similarity index 64% rename from database/gdb/gdb_sqlite.go rename to database/gdb/gdb_driver_sqlite.go index 62dd3a4a2..659ae47d3 100644 --- a/database/gdb/gdb_sqlite.go +++ b/database/gdb/gdb_driver_sqlite.go @@ -16,12 +16,20 @@ import ( "github.com/gogf/gf/text/gstr" ) -type dbSqlite struct { - *dbBase +// DriverSqlite is the driver for sqlite database. +type DriverSqlite struct { + *Core +} + +// New creates and returns a database object for sqlite. +func (d *DriverSqlite) New(core *Core, node *ConfigNode) (DB, error) { + return &DriverSqlite{ + Core: core, + }, nil } // Open creates and returns a underlying sql.DB object for sqlite. -func (db *dbSqlite) Open(config *ConfigNode) (*sql.DB, error) { +func (d *DriverSqlite) Open(config *ConfigNode) (*sql.DB, error) { var source string if config.LinkInfo != "" { source = config.LinkInfo @@ -36,20 +44,20 @@ func (db *dbSqlite) Open(config *ConfigNode) (*sql.DB, error) { } } -// getChars returns the security char for this type of database. -func (db *dbSqlite) getChars() (charLeft string, charRight string) { +// GetChars returns the security char for this type of database. +func (d *DriverSqlite) GetChars() (charLeft string, charRight string) { return "`", "`" } // Tables retrieves and returns the tables of current schema. // TODO -func (db *dbSqlite) Tables(schema ...string) (tables []string, err error) { +func (d *DriverSqlite) Tables(schema ...string) (tables []string, err error) { return } // TableFields retrieves and returns the fields information of specified table of current schema. // TODO -func (db *dbSqlite) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { +func (d *DriverSqlite) TableFields(table string, schema ...string) (fields map[string]*TableField, err error) { table = gstr.Trim(table) if gstr.Contains(table, " ") { panic("function TableFields supports only single table operations") @@ -57,9 +65,9 @@ func (db *dbSqlite) TableFields(table string, schema ...string) (fields map[stri return } -// handleSqlBeforeExec deals with the sql string before commits it to underlying sql driver. +// HandleSqlBeforeExec deals with the sql string before commits it to underlying sql driver. // @todo 需要增加对Save方法的支持,可使用正则来实现替换, // @todo 将ON DUPLICATE KEY UPDATE触发器修改为两条SQL语句(INSERT OR IGNORE & UPDATE) -func (db *dbSqlite) handleSqlBeforeExec(sql string) string { +func (d *DriverSqlite) HandleSqlBeforeExec(sql string) string { return sql } diff --git a/database/gdb/gdb_func.go b/database/gdb/gdb_func.go index fa9ad177d..deb94c8be 100644 --- a/database/gdb/gdb_func.go +++ b/database/gdb/gdb_func.go @@ -269,7 +269,7 @@ func formatWhereInterfaces(db DB, where []interface{}, buffer *bytes.Buffer, new // formatWhereKeyValue handles each key-value pair of the parameter map. func formatWhereKeyValue(db DB, buffer *bytes.Buffer, newArgs []interface{}, key string, value interface{}) []interface{} { - key = db.quoteWord(key) + key = db.QuoteWord(key) if buffer.Len() > 0 { buffer.WriteString(" AND ") } diff --git a/database/gdb/gdb_model.go b/database/gdb/gdb_model.go index aa7bf739c..c5612f57f 100644 --- a/database/gdb/gdb_model.go +++ b/database/gdb/gdb_model.go @@ -68,10 +68,10 @@ const ( // Table creates and returns a new ORM model from given schema. // The parameter can be more than one table names, like : // "user", "user u", "user, user_detail", "user u, user_detail ud" -func (bs *dbBase) Table(table string) *Model { - table = bs.db.handleTableName(table) +func (c *Core) Table(table string) *Model { + table = c.DB.handleTableName(table) return &Model{ - db: bs.db, + db: c.DB, tablesInit: table, tables: table, fields: "*", @@ -82,21 +82,21 @@ func (bs *dbBase) Table(table string) *Model { } } -// Model is alias of dbBase.Table. -// See dbBase.Table. -func (bs *dbBase) Model(table string) *Model { - return bs.db.Table(table) +// Model is alias of Core.Table. +// See Core.Table. +func (c *Core) Model(table string) *Model { + return c.DB.Table(table) } -// From is alias of dbBase.Table. -// See dbBase.Table. +// From is alias of Core.Table. +// See Core.Table. // Deprecated. -func (bs *dbBase) From(table string) *Model { - return bs.db.Table(table) +func (c *Core) From(table string) *Model { + return c.DB.Table(table) } -// Table acts like dbBase.Table except it operates on transaction. -// See dbBase.Table. +// Table acts like Core.Table except it operates on transaction. +// See Core.Table. func (tx *TX) Table(table string) *Model { table = tx.db.handleTableName(table) return &Model{ @@ -403,7 +403,7 @@ func (m *Model) Or(where interface{}, args ...interface{}) *Model { // Group sets the "GROUP BY" statement for the model. func (m *Model) Group(groupBy string) *Model { model := m.getModel() - model.groupBy = m.db.quoteString(groupBy) + model.groupBy = m.db.QuoteString(groupBy) return model } @@ -417,7 +417,7 @@ func (m *Model) GroupBy(groupBy string) *Model { // Order sets the "ORDER BY" statement for the model. func (m *Model) Order(orderBy string) *Model { model := m.getModel() - model.orderBy = m.db.quoteString(orderBy) + model.orderBy = m.db.QuoteString(orderBy) return model } @@ -586,7 +586,7 @@ func (m *Model) doInsertWithOption(option int, data ...interface{}) (result sql. if m.batch > 0 { batch = m.batch } - return m.db.doBatchInsert( + return m.db.DoBatchInsert( m.getLink(true), m.tables, m.filterDataForInsertOrUpdate(list), @@ -595,7 +595,7 @@ func (m *Model) doInsertWithOption(option int, data ...interface{}) (result sql. ) } else if data, ok := m.data.(Map); ok { // Single insert. - return m.db.doInsert( + return m.db.DoInsert( m.getLink(true), m.tables, m.filterDataForInsertOrUpdate(data), @@ -626,7 +626,7 @@ func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) { if m.batch > 0 { batch = m.batch } - return m.db.doBatchInsert( + return m.db.DoBatchInsert( m.getLink(true), m.tables, m.filterDataForInsertOrUpdate(list), @@ -635,7 +635,7 @@ func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) { ) } else if data, ok := m.data.(Map); ok { // Single insert. - return m.db.doInsert( + return m.db.DoInsert( m.getLink(true), m.tables, m.filterDataForInsertOrUpdate(data), @@ -669,7 +669,7 @@ func (m *Model) Save(data ...interface{}) (result sql.Result, err error) { if m.batch > 0 { batch = m.batch } - return m.db.doBatchInsert( + return m.db.DoBatchInsert( m.getLink(true), m.tables, m.filterDataForInsertOrUpdate(list), @@ -678,7 +678,7 @@ func (m *Model) Save(data ...interface{}) (result sql.Result, err error) { ) } else if data, ok := m.data.(Map); ok { // Single save. - return m.db.doInsert( + return m.db.DoInsert( m.getLink(true), m.tables, m.filterDataForInsertOrUpdate(data), @@ -712,7 +712,7 @@ func (m *Model) Update(dataAndWhere ...interface{}) (result sql.Result, err erro return nil, errors.New("updating table with empty data") } condition, conditionArgs := m.formatCondition(false) - return m.db.doUpdate( + return m.db.DoUpdate( m.getLink(true), m.tables, m.filterDataForInsertOrUpdate(m.data), @@ -734,7 +734,7 @@ func (m *Model) Delete(where ...interface{}) (result sql.Result, err error) { } }() condition, conditionArgs := m.formatCondition(false) - return m.db.doDelete(m.getLink(true), m.tables, condition, conditionArgs...) + return m.db.DoDelete(m.getLink(true), m.tables, condition, conditionArgs...) } // Select is alias of Model.All. @@ -1059,10 +1059,16 @@ func (m *Model) getLink(master bool) dbLink { } switch linkType { case gLINK_TYPE_MASTER: - link, _ := m.db.getMaster(m.schema) + link, err := m.db.GetMaster(m.schema) + if err != nil { + panic(err) + } return link case gLINK_TYPE_SLAVE: - link, _ := m.db.getSlave(m.schema) + link, err := m.db.GetSlave(m.schema) + if err != nil { + panic(err) + } return link } return nil @@ -1077,17 +1083,17 @@ func (m *Model) getAll(query string, args ...interface{}) (result Result, err er if len(cacheKey) == 0 { cacheKey = query + "/" + gconv.String(args) } - if v := m.db.getCache().Get(cacheKey); v != nil { + if v := m.db.GetCache().Get(cacheKey); v != nil { return v.(Result), nil } } - result, err = m.db.doGetAll(m.getLink(false), query, args...) + result, err = m.db.DoGetAll(m.getLink(false), query, args...) // Cache the result. if len(cacheKey) > 0 && err == nil { if m.cacheDuration < 0 { - m.db.getCache().Remove(cacheKey) + m.db.GetCache().Remove(cacheKey) } else { - m.db.getCache().Set(cacheKey, result, m.cacheDuration) + m.db.GetCache().Set(cacheKey, result, m.cacheDuration) } } return result, err @@ -1113,7 +1119,7 @@ func (m *Model) getPrimaryKey() string { // checkAndRemoveCache checks and remove the cache if necessary. func (m *Model) checkAndRemoveCache() { if m.cacheEnabled && m.cacheDuration < 0 && len(m.cacheName) > 0 { - m.db.getCache().Remove(m.cacheName) + m.db.GetCache().Remove(m.cacheName) } } diff --git a/database/gdb/gdb_schema.go b/database/gdb/gdb_schema.go index 9d4631168..875869a55 100644 --- a/database/gdb/gdb_schema.go +++ b/database/gdb/gdb_schema.go @@ -14,9 +14,9 @@ type Schema struct { } // Schema creates and returns a schema. -func (bs *dbBase) Schema(schema string) *Schema { +func (c *Core) Schema(schema string) *Schema { return &Schema{ - db: bs.db, + db: c.DB, schema: schema, } } @@ -44,8 +44,8 @@ func (s *Schema) Table(table string) *Model { return m } -// Model is alias of dbBase.Table. -// See dbBase.Table. +// Model is alias of Core.Table. +// See Core.Table. func (s *Schema) Model(table string) *Model { return s.Table(table) } diff --git a/database/gdb/gdb_structure.go b/database/gdb/gdb_structure.go index 3a251fa97..dc7c23154 100644 --- a/database/gdb/gdb_structure.go +++ b/database/gdb/gdb_structure.go @@ -21,7 +21,7 @@ import ( // convertValue automatically checks and converts field value from database type // to golang variable type. -func (bs *dbBase) convertValue(fieldValue []byte, fieldType string) interface{} { +func (c *Core) convertValue(fieldValue []byte, fieldType string) interface{} { t, _ := gregex.ReplaceString(`\(.+\)`, "", fieldType) t = strings.ToLower(t) switch t { @@ -106,10 +106,10 @@ func (bs *dbBase) convertValue(fieldValue []byte, fieldType string) interface{} } // filterFields removes all key-value pairs which are not the field of given table. -func (bs *dbBase) filterFields(schema, table string, data map[string]interface{}) map[string]interface{} { +func (c *Core) filterFields(schema, table string, data map[string]interface{}) map[string]interface{} { // It must use data copy here to avoid its changing the origin data map. newDataMap := make(map[string]interface{}, len(data)) - if fields, err := bs.db.TableFields(table, schema); err == nil { + if fields, err := c.DB.TableFields(table, schema); err == nil { for k, v := range data { if _, ok := fields[k]; ok { newDataMap[k] = v diff --git a/database/gdb/gdb_transaction.go b/database/gdb/gdb_transaction.go index d270ab108..c717d2d91 100644 --- a/database/gdb/gdb_transaction.go +++ b/database/gdb/gdb_transaction.go @@ -32,15 +32,15 @@ func (tx *TX) Rollback() error { } // Query does query operation on transaction. -// See dbBase.Query. +// See Core.Query. func (tx *TX) Query(query string, args ...interface{}) (rows *sql.Rows, err error) { - return tx.db.doQuery(tx.tx, query, args...) + return tx.db.DoQuery(tx.tx, query, args...) } // Exec does none query operation on transaction. -// See dbBase.Exec. +// See Core.Exec. func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) { - return tx.db.doExec(tx.tx, query, args...) + return tx.db.DoExec(tx.tx, query, args...) } // Prepare creates a prepared statement for later queries or executions. @@ -49,7 +49,7 @@ func (tx *TX) Exec(query string, args ...interface{}) (sql.Result, error) { // The caller must call the statement's Close method // when the statement is no longer needed. func (tx *TX) Prepare(query string) (*sql.Stmt, error) { - return tx.db.doPrepare(tx.tx, query) + return tx.db.DoPrepare(tx.tx, query) } // GetAll queries and returns data records from database. @@ -154,7 +154,7 @@ func (tx *TX) GetCount(query string, args ...interface{}) (int, error) { // // The parameter specifies the batch operation count when given data is slice. func (tx *TX) Insert(table string, data interface{}, batch ...int) (sql.Result, error) { - return tx.db.doInsert(tx.tx, table, data, gINSERT_OPTION_DEFAULT, batch...) + return tx.db.DoInsert(tx.tx, table, data, gINSERT_OPTION_DEFAULT, batch...) } // InsertIgnore does "INSERT IGNORE INTO ..." statement for the table. @@ -167,7 +167,7 @@ func (tx *TX) Insert(table string, data interface{}, batch ...int) (sql.Result, // // The parameter specifies the batch operation count when given data is slice. func (tx *TX) InsertIgnore(table string, data interface{}, batch ...int) (sql.Result, error) { - return tx.db.doInsert(tx.tx, table, data, gINSERT_OPTION_IGNORE, batch...) + return tx.db.DoInsert(tx.tx, table, data, gINSERT_OPTION_IGNORE, batch...) } // Replace does "REPLACE INTO ..." statement for the table. @@ -183,7 +183,7 @@ func (tx *TX) InsertIgnore(table string, data interface{}, batch ...int) (sql.Re // If given data is type of slice, it then does batch replacing, and the optional parameter // specifies the batch operation count. func (tx *TX) Replace(table string, data interface{}, batch ...int) (sql.Result, error) { - return tx.db.doInsert(tx.tx, table, data, gINSERT_OPTION_REPLACE, batch...) + return tx.db.DoInsert(tx.tx, table, data, gINSERT_OPTION_REPLACE, batch...) } // Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the table. @@ -198,31 +198,31 @@ func (tx *TX) Replace(table string, data interface{}, batch ...int) (sql.Result, // If given data is type of slice, it then does batch saving, and the optional parameter // specifies the batch operation count. func (tx *TX) Save(table string, data interface{}, batch ...int) (sql.Result, error) { - return tx.db.doInsert(tx.tx, table, data, gINSERT_OPTION_SAVE, batch...) + return tx.db.DoInsert(tx.tx, table, data, gINSERT_OPTION_SAVE, batch...) } // BatchInsert batch inserts data. // The parameter must be type of slice of map or struct. func (tx *TX) BatchInsert(table string, list interface{}, batch ...int) (sql.Result, error) { - return tx.db.doBatchInsert(tx.tx, table, list, gINSERT_OPTION_DEFAULT, batch...) + return tx.db.DoBatchInsert(tx.tx, table, list, gINSERT_OPTION_DEFAULT, batch...) } // BatchInsert batch inserts data with ignore option. // The parameter must be type of slice of map or struct. func (tx *TX) BatchInsertIgnore(table string, list interface{}, batch ...int) (sql.Result, error) { - return tx.db.doBatchInsert(tx.tx, table, list, gINSERT_OPTION_IGNORE, batch...) + return tx.db.DoBatchInsert(tx.tx, table, list, gINSERT_OPTION_IGNORE, batch...) } // BatchReplace batch replaces data. // The parameter must be type of slice of map or struct. func (tx *TX) BatchReplace(table string, list interface{}, batch ...int) (sql.Result, error) { - return tx.db.doBatchInsert(tx.tx, table, list, gINSERT_OPTION_REPLACE, batch...) + return tx.db.DoBatchInsert(tx.tx, table, list, gINSERT_OPTION_REPLACE, batch...) } // BatchSave batch replaces data. // The parameter must be type of slice of map or struct. func (tx *TX) BatchSave(table string, list interface{}, batch ...int) (sql.Result, error) { - return tx.db.doBatchInsert(tx.tx, table, list, gINSERT_OPTION_SAVE, batch...) + return tx.db.DoBatchInsert(tx.tx, table, list, gINSERT_OPTION_SAVE, batch...) } // Update does "UPDATE ... " statement for the table. @@ -244,7 +244,7 @@ func (tx *TX) Update(table string, data interface{}, condition interface{}, args if newWhere != "" { newWhere = " WHERE " + newWhere } - return tx.db.doUpdate(tx.tx, table, data, newWhere, newArgs...) + return tx.db.DoUpdate(tx.tx, table, data, newWhere, newArgs...) } // Delete does "DELETE FROM ... " statement for the table. @@ -263,5 +263,5 @@ func (tx *TX) Delete(table string, condition interface{}, args ...interface{}) ( if newWhere != "" { newWhere = " WHERE " + newWhere } - return tx.db.doDelete(tx.tx, table, newWhere, newArgs...) + return tx.db.DoDelete(tx.tx, table, newWhere, newArgs...) } diff --git a/database/gredis/gredis.go b/database/gredis/gredis.go index 02c75cf4d..6bd31378a 100644 --- a/database/gredis/gredis.go +++ b/database/gredis/gredis.go @@ -54,9 +54,9 @@ type PoolStats struct { } const ( - gDEFAULT_POOL_IDLE_TIMEOUT = 60 * time.Second + gDEFAULT_POOL_IDLE_TIMEOUT = 30 * time.Second gDEFAULT_POOL_CONN_TIMEOUT = 10 * time.Second - gDEFAULT_POOL_MAX_LIFE_TIME = 60 * time.Second + gDEFAULT_POOL_MAX_LIFE_TIME = 30 * time.Second ) var ( @@ -80,6 +80,7 @@ func New(config Config) *Redis { config: config, pool: pools.GetOrSetFuncLock(fmt.Sprintf("%v", config), func() interface{} { return &redis.Pool{ + Wait: true, IdleTimeout: config.IdleTimeout, MaxActive: config.MaxActive, MaxIdle: config.MaxIdle, diff --git a/go.mod b/go.mod index 4c37a2c14..d82b1ab16 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,8 @@ require ( github.com/clbanning/mxj v1.8.4 github.com/fatih/structs v1.1.0 github.com/fsnotify/fsnotify v1.4.7 - github.com/gf-third/mysql v1.4.2 github.com/gf-third/yaml v1.0.1 + github.com/go-sql-driver/mysql v1.5.0 github.com/gomodule/redigo v2.0.0+incompatible github.com/google/uuid v1.1.1 github.com/gorilla/websocket v1.4.1 @@ -17,5 +17,4 @@ require ( github.com/olekukonko/tablewriter v0.0.1 golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e // indirect golang.org/x/text v0.3.2 - google.golang.org/appengine v1.6.5 // indirect ) From 8e40cded42efac18c71b20d501cf46778323be5e Mon Sep 17 00:00:00 2001 From: John Date: Sun, 8 Mar 2020 11:03:18 +0800 Subject: [PATCH 13/26] add custom driver feature for package gdb --- .../gdb/driver/{ => driver}/driver.go | 4 +- .example/database/gdb/driver/main.go | 1 + database/gdb/gdb.go | 3 +- database/gdb/gdb_core.go | 85 ++++++++------- database/gdb/gdb_core_utility.go | 38 ++++++- database/gdb/gdb_driver_mssql.go | 2 + database/gdb/gdb_driver_mysql.go | 2 + database/gdb/gdb_driver_oracle.go | 8 +- database/gdb/gdb_driver_pgsql.go | 2 + database/gdb/gdb_driver_sqlite.go | 2 + database/gdb/gdb_func.go | 100 +++++++++--------- database/gdb/gdb_model.go | 14 +-- database/gdb/gdb_unit_z_driver_test.go | 88 +++++++++++++++ 13 files changed, 244 insertions(+), 105 deletions(-) rename .example/database/gdb/driver/{ => driver}/driver.go (97%) create mode 100644 .example/database/gdb/driver/main.go create mode 100644 database/gdb/gdb_unit_z_driver_test.go diff --git a/.example/database/gdb/driver/driver.go b/.example/database/gdb/driver/driver/driver.go similarity index 97% rename from .example/database/gdb/driver/driver.go rename to .example/database/gdb/driver/driver/driver.go index 0d0599c9c..e253f2bff 100644 --- a/.example/database/gdb/driver/driver.go +++ b/.example/database/gdb/driver/driver/driver.go @@ -12,8 +12,6 @@ import ( "github.com/gogf/gf/database/gdb" "github.com/gogf/gf/internal/intlog" "github.com/gogf/gf/text/gstr" - - _ "github.com/gf-third/mysql" ) type MyDriver struct { @@ -32,7 +30,7 @@ func (d *MyDriver) Open(config *gdb.ConfigNode) (*sql.DB, error) { ) } intlog.Printf("Open: %s", source) - if db, err := sql.Open("gf-mysql", source); err == nil { + if db, err := sql.Open("mysql", source); err == nil { return db, nil } else { return nil, err diff --git a/.example/database/gdb/driver/main.go b/.example/database/gdb/driver/main.go new file mode 100644 index 000000000..06ab7d0f9 --- /dev/null +++ b/.example/database/gdb/driver/main.go @@ -0,0 +1 @@ +package main diff --git a/database/gdb/gdb.go b/database/gdb/gdb.go index cc3e1b429..6496c7e10 100644 --- a/database/gdb/gdb.go +++ b/database/gdb/gdb.go @@ -104,7 +104,7 @@ type DB interface { TableFields(table string, schema ...string) (map[string]*TableField, error) // Internal methods. - handleTableName(table string) string + QuotePrefixTableName(table string) string filterFields(schema, table string, data map[string]interface{}) map[string]interface{} convertValue(fieldValue []byte, fieldType string) interface{} rowsToResult(rows *sql.Rows) (Result, error) @@ -126,6 +126,7 @@ type Core struct { // Driver is the interface for integrating sql drivers into package gdb. type Driver interface { + // New creates and returns a database object for specified database server. New(core *Core, node *ConfigNode) (DB, error) } diff --git a/database/gdb/gdb_core.go b/database/gdb/gdb_core.go index 599d12051..7ac95a739 100644 --- a/database/gdb/gdb_core.go +++ b/database/gdb/gdb_core.go @@ -70,7 +70,7 @@ func (c *Core) DoQuery(link dbLink, query string, args ...interface{}) (rows *sq Start: mTime1, End: mTime2, } - c.printSql(s) + c.writeSqlToLogger(s) } else { rows, err = link.Query(query, args...) } @@ -109,7 +109,7 @@ func (c *Core) DoExec(link dbLink, query string, args ...interface{}) (result sq Start: mTime1, End: mTime2, } - c.printSql(s) + c.writeSqlToLogger(s) } else { result, err = link.Exec(query, args...) } @@ -350,6 +350,11 @@ func (c *Core) Save(table string, data interface{}, batch ...int) (sql.Result, e // doInsert inserts or updates data for given table. // +// The parameter can be type of map/gmap/struct/*struct/[]map/[]struct, etc. +// Eg: +// Data(g.Map{"uid": 10000, "name":"john"}) +// Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}) +// // The parameter