improve middleware feature for ghttp.Server

This commit is contained in:
John
2019-12-04 10:03:03 +08:00
parent 890865251b
commit a06ca31530
14 changed files with 300 additions and 185 deletions

View File

@ -11,6 +11,7 @@ import (
"bytes"
"fmt"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
@ -244,3 +245,21 @@ func CallerFileLineShort() string {
_, path, line := Caller()
return fmt.Sprintf(`%s:%d`, filepath.Base(path), line)
}
// FuncPath returns the complete function path of given <f>.
func FuncPath(f interface{}) string {
return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name()
}
// FuncName returns the function name of given <f>.
func FuncName(f interface{}) string {
path := FuncPath(f)
if path == "" {
return ""
}
index := strings.LastIndexByte(path, '/')
if index < 0 {
index = strings.LastIndexByte(path, '\\')
}
return path[index+1:]
}

View File

@ -30,7 +30,6 @@ type Request struct {
LeaveTime int64 // Request ending time in microseconds.
Middleware *Middleware // The middleware manager.
handlers []*handlerParsedItem // All matched handlers containing handler, hook and middleware for this request .
handlerIndex int // Index number for executing sequence purpose of handlers.
hasHookHandler bool // A bool marking whether there's hook handler in the handlers for performance purpose.
hasServeHandler bool // A bool marking whether there's serving handler in the handlers for performance purpose.
parsedQuery bool // A bool marking whether the GET parameters parsed.
@ -125,7 +124,7 @@ func (r *Request) IsAjaxRequest() bool {
return strings.EqualFold(r.Header.Get("X-Requested-With"), "XMLHttpRequest")
}
// GetClientIp returns the client ip of this request.
// GetClientIp returns the client ip of this request without port.
func (r *Request) GetClientIp() string {
if len(r.clientIp) == 0 {
if r.clientIp = r.Header.Get("X-Real-IP"); r.clientIp == "" {

View File

@ -17,23 +17,25 @@ import (
// Middleware is the plugin for request workflow management.
type Middleware struct {
served bool // Is the request served, which is used for checking response status 404.
request *Request // The request object pointer.
served bool // Is the request served, which is used for checking response status 404.
request *Request // The request object pointer.
handlerIndex int // Index number for executing sequence purpose for handler items.
handlerMDIndex int // Index number for executing sequence purpose for bound middleware of handler item.
}
// Next calls the next workflow handler.
func (m *Middleware) Next() {
item := (*handlerParsedItem)(nil)
loop := true
var item *handlerParsedItem
var loop = true
for loop {
// Check whether the request is exited.
if m.request.IsExited() || m.request.handlerIndex >= len(m.request.handlers) {
if m.request.IsExited() || m.handlerIndex >= len(m.request.handlers) {
break
}
item = m.request.handlers[m.request.handlerIndex]
m.request.handlerIndex++
item = m.request.handlers[m.handlerIndex]
// Filter the HOOK handlers, which are designed to be called in another standalone procedure.
if item.handler.itemType == gHANDLER_TYPE_HOOK {
m.handlerIndex++
continue
}
// Router values switching.
@ -42,7 +44,20 @@ func (m *Middleware) Next() {
m.request.Router = item.handler.router
gutil.TryCatch(func() {
// Execute bound middleware array of the item if it's not empty.
if m.handlerMDIndex < len(item.handler.middleware) {
md := item.handler.middleware[m.handlerMDIndex]
m.handlerMDIndex++
niceCallFunc(func() {
md(m.request)
})
loop = false
return
}
m.handlerIndex++
switch item.handler.itemType {
// Service controller.
case gHANDLER_TYPE_CONTROLLER:
m.served = true
if m.request.IsExited() {
@ -63,6 +78,7 @@ func (m *Middleware) Next() {
})
}
// Service object.
case gHANDLER_TYPE_OBJECT:
m.served = true
if m.request.IsExited() {
@ -84,6 +100,7 @@ func (m *Middleware) Next() {
})
}
// Service handler.
case gHANDLER_TYPE_HANDLER:
m.served = true
if m.request.IsExited() {
@ -93,6 +110,7 @@ func (m *Middleware) Next() {
item.handler.itemFunc(m.request)
})
// Global middleware array.
case gHANDLER_TYPE_MIDDLEWARE:
niceCallFunc(func() {
item.handler.itemFunc(m.request)
@ -107,11 +125,13 @@ func (m *Middleware) Next() {
})
}
// Check the http status code after all handler and middleware done.
if m.request.Response.Status == 0 {
if m.request.Middleware.served || m.request.Response.buffer.Len() > 0 {
m.request.Response.WriteHeader(http.StatusOK)
} else {
m.request.Response.WriteHeader(http.StatusNotFound)
if m.request.IsExited() || m.handlerIndex >= len(m.request.handlers) {
if m.request.Response.Status == 0 {
if m.request.Middleware.served {
m.request.Response.WriteHeader(http.StatusOK)
} else {
m.request.Response.WriteHeader(http.StatusNotFound)
}
}
}
}

View File

@ -10,6 +10,7 @@ import (
"bytes"
"errors"
"fmt"
"github.com/gogf/gf/debug/gdebug"
"net/http"
"os"
"reflect"
@ -62,15 +63,16 @@ type (
// 路由函数注册信息
handlerItem struct {
itemId int // 用于标识该注册函数的唯一性ID
itemName string // 注册的函数名称信息(用于路由信息打印)
itemType int // 注册函数类型(对象/函数/控制器/中间件/钩子函数)
itemFunc HandlerFunc // 函数内存地址(与以上两个参数二选一)
initFunc HandlerFunc // 初始化请求回调函数(对象注册方式下有效)
shutFunc HandlerFunc // 完成请求回调函数(对象注册方式下有效)
ctrlInfo *handlerController // 控制器服务函数反射信息
hookName string // 钩子类型名称(注册函数类型为钩子函数下有效)
router *Router // 注册时绑定的路由对象
itemId int // 用于标识该注册函数的唯一性ID
itemName string // 注册的函数名称信息(用于路由信息打印)
itemType int // 注册函数类型(对象/函数/控制器/中间件/钩子函数)
itemFunc HandlerFunc // 函数内存地址(与以上两个参数二选一)
initFunc HandlerFunc // 初始化请求回调函数(对象注册方式下有效)
shutFunc HandlerFunc // 完成请求回调函数(对象注册方式下有效)
middleware []HandlerFunc // 绑定的中间件列表
ctrlInfo *handlerController // 控制器服务函数反射信息
hookName string // 钩子类型名称(注册函数类型为钩子函数下有效)
router *Router // 注册时绑定的路由对象
}
// 根据特定URL.Path解析后的路由检索结果项
@ -317,12 +319,13 @@ func (s *Server) Start() error {
// 打印展示路由表
func (s *Server) DumpRoutesMap() {
if s.config.DumpRouteMap && len(s.routesMap) > 0 {
glog.Header(false).Println(fmt.Sprintf("\n%s", s.GetRouteMap()))
glog.Header(false).Println(fmt.Sprintf("\n%s", s.getRouteMapString()))
}
}
// 获得路由表(格式化字符串)
func (s *Server) GetRouteMap() string {
func (s *Server) getRouteMapString() string {
// Route table for dumping.
type tableItem struct {
middleware string
domain string
@ -342,12 +345,11 @@ func (s *Server) GetRouteMap() string {
tablewriter.ALIGN_CENTER,
tablewriter.ALIGN_CENTER,
tablewriter.ALIGN_CENTER,
tablewriter.ALIGN_LEFT,
tablewriter.ALIGN_CENTER,
tablewriter.ALIGN_CENTER,
tablewriter.ALIGN_LEFT,
tablewriter.ALIGN_LEFT,
tablewriter.ALIGN_LEFT,
tablewriter.ALIGN_CENTER,
})
m := make(map[string]*garray.SortedArray)
@ -363,10 +365,20 @@ func (s *Server) GetRouteMap() string {
priority: len(registeredItems) - index - 1,
}
if item.handler.itemType == gHANDLER_TYPE_MIDDLEWARE {
item.middleware = "MIDDLEWARE"
item.middleware = "GLOBAL MIDDLEWARE"
}
if len(item.handler.middleware) > 0 {
for _, v := range item.handler.middleware {
if item.middleware != "" {
item.middleware += ","
}
item.middleware += gdebug.FuncName(v)
}
}
// If the domain does not exist in the dump map, it create the map.
// The value of the map is a custom sorted array.
if _, ok := m[item.domain]; !ok {
// 注意排序函数的逻辑,从小到达排序
// Sort in ASC order.
m[item.domain] = garray.NewSortedArraySize(100, func(v1, v2 interface{}) int {
item1 := v1.(*tableItem)
item2 := v2.(*tableItem)

View File

@ -35,6 +35,12 @@ func (d *Domain) BindHandler(pattern string, handler HandlerFunc) {
}
}
func (d *Domain) doBindHandler(pattern string, handler HandlerFunc, middleware []HandlerFunc) {
for domain, _ := range d.m {
d.s.doBindHandler(pattern+"@"+domain, handler, middleware)
}
}
// 执行对象方法
func (d *Domain) BindObject(pattern string, obj interface{}, methods ...string) {
for domain, _ := range d.m {
@ -42,6 +48,12 @@ func (d *Domain) BindObject(pattern string, obj interface{}, methods ...string)
}
}
func (d *Domain) doBindObject(pattern string, obj interface{}, methods string, middleware []HandlerFunc) {
for domain, _ := range d.m {
d.s.doBindObject(pattern+"@"+domain, obj, methods, middleware)
}
}
// 执行对象方法注册methods参数不区分大小写
func (d *Domain) BindObjectMethod(pattern string, obj interface{}, method string) {
for domain, _ := range d.m {
@ -49,6 +61,12 @@ func (d *Domain) BindObjectMethod(pattern string, obj interface{}, method string
}
}
func (d *Domain) doBindObjectMethod(pattern string, obj interface{}, method string, middleware []HandlerFunc) {
for domain, _ := range d.m {
d.s.doBindObjectMethod(pattern+"@"+domain, obj, method, middleware)
}
}
// RESTful执行对象注册
func (d *Domain) BindObjectRest(pattern string, obj interface{}) {
for domain, _ := range d.m {
@ -56,6 +74,12 @@ func (d *Domain) BindObjectRest(pattern string, obj interface{}) {
}
}
func (d *Domain) doBindObjectRest(pattern string, obj interface{}, middleware []HandlerFunc) {
for domain, _ := range d.m {
d.s.doBindObjectRest(pattern+"@"+domain, obj, middleware)
}
}
// 控制器注册
func (d *Domain) BindController(pattern string, c Controller, methods ...string) {
for domain, _ := range d.m {
@ -63,6 +87,12 @@ func (d *Domain) BindController(pattern string, c Controller, methods ...string)
}
}
func (d *Domain) doBindController(pattern string, c Controller, methods string, middleware []HandlerFunc) {
for domain, _ := range d.m {
d.s.doBindController(pattern+"@"+domain, c, methods, middleware)
}
}
// 控制器方法注册methods参数区分大小写
func (d *Domain) BindControllerMethod(pattern string, c Controller, method string) {
for domain, _ := range d.m {
@ -70,6 +100,12 @@ func (d *Domain) BindControllerMethod(pattern string, c Controller, method strin
}
}
func (d *Domain) doBindControllerMethod(pattern string, c Controller, method string, middleware []HandlerFunc) {
for domain, _ := range d.m {
d.s.doBindControllerMethod(pattern+"@"+domain, c, method, middleware)
}
}
// RESTful控制器注册
func (d *Domain) BindControllerRest(pattern string, c Controller) {
for domain, _ := range d.m {
@ -77,6 +113,12 @@ func (d *Domain) BindControllerRest(pattern string, c Controller) {
}
}
func (d *Domain) doBindControllerRest(pattern string, c Controller, middleware []HandlerFunc) {
for domain, _ := range d.m {
d.s.doBindControllerRest(pattern+"@"+domain, c, middleware)
}
}
// 绑定指定的hook回调函数, hook参数的值由ghttp server设定参数不区分大小写
// 目前hook支持Init/Shut
func (d *Domain) BindHookHandler(pattern string, hook string, handler HandlerFunc) {

View File

@ -18,10 +18,11 @@ import (
// 分组路由对象
type RouterGroup struct {
parent *RouterGroup // 父级分组路由
server *Server // Server
domain *Domain // Domain
prefix string // URI前缀
parent *RouterGroup // 父级分组路由
server *Server // Server
domain *Domain // Domain
prefix string // URI前缀
middleware []HandlerFunc // 分组路由绑定的中间件
}
// 分组路由批量绑定项
@ -44,6 +45,7 @@ var (
// 处理预绑定路由项
func (s *Server) handlePreBindItems() {
for _, item := range preBindItems {
// Handle the items of current server.
if item.group.server != nil && item.group.server != s {
continue
}
@ -62,16 +64,16 @@ func (s *Server) Group(prefix string, groups ...func(group *RouterGroup)) *Route
if prefix == "/" {
prefix = ""
}
rg := &RouterGroup{
group := &RouterGroup{
server: s,
prefix: prefix,
}
if len(groups) > 0 {
for _, v := range groups {
v(rg)
v(group)
}
}
return rg
return group
}
// 获取分组路由对象(绑定域名)
@ -82,16 +84,16 @@ func (d *Domain) Group(prefix string, groups ...func(group *RouterGroup)) *Route
if prefix == "/" {
prefix = ""
}
rg := &RouterGroup{
group := &RouterGroup{
domain: d,
prefix: prefix,
}
if len(groups) > 0 {
for _, v := range groups {
v(rg)
v(group)
}
}
return rg
return group
}
// 层级递归创建分组路由注册项
@ -99,27 +101,34 @@ func (g *RouterGroup) Group(prefix string, groups ...func(group *RouterGroup)) *
if prefix == "/" {
prefix = ""
}
rg := &RouterGroup{
group := &RouterGroup{
parent: g,
server: g.server,
domain: g.domain,
prefix: prefix,
}
if len(g.middleware) > 0 {
group.middleware = make([]HandlerFunc, len(g.middleware))
copy(group.middleware, g.middleware)
}
if len(groups) > 0 {
for _, v := range groups {
v(rg)
v(group)
}
}
return rg
return group
}
func (g *RouterGroup) Clone() *RouterGroup {
return &RouterGroup{
parent: g.parent,
server: g.server,
domain: g.domain,
prefix: g.prefix,
newGroup := &RouterGroup{
parent: g.parent,
server: g.server,
domain: g.domain,
prefix: g.prefix,
middleware: make([]HandlerFunc, len(g.middleware)),
}
copy(newGroup.middleware, g.middleware)
return newGroup
}
// 执行分组路由批量绑定
@ -211,23 +220,8 @@ func (g *RouterGroup) Hook(pattern string, hook string, handler HandlerFunc) *Ro
}
func (g *RouterGroup) Middleware(handlers ...HandlerFunc) *RouterGroup {
group := g.Clone()
for _, handler := range handlers {
if gstr.Contains(g.prefix, "*") {
group.preBind("MIDDLEWARE", "/", handler)
} else {
group.preBind("MIDDLEWARE", "/*", handler)
}
}
return group
}
func (g *RouterGroup) MiddlewarePattern(pattern string, handlers ...HandlerFunc) *RouterGroup {
group := g.Clone()
for _, handler := range handlers {
group.preBind("MIDDLEWARE", pattern, handler)
}
return group
g.middleware = append(g.middleware, handlers...)
return g
}
func (g *RouterGroup) preBind(bindType string, pattern string, object interface{}, params ...interface{}) *RouterGroup {
@ -279,64 +273,54 @@ func (g *RouterGroup) doBind(bindType string, pattern string, object interface{}
bindType = "HOOK"
}
switch bindType {
case "MIDDLEWARE":
if h, ok := object.(HandlerFunc); ok {
if g.server != nil {
g.server.BindMiddleware(pattern, h)
} else {
g.domain.BindMiddleware(pattern, h)
}
} else {
glog.Fatalf("invalid middleware handler for pattern:%s", pattern)
}
case "HANDLER":
if h, ok := object.(HandlerFunc); ok {
if g.server != nil {
g.server.BindHandler(pattern, h)
g.server.doBindHandler(pattern, h, g.middleware)
} else {
g.domain.BindHandler(pattern, h)
g.domain.doBindHandler(pattern, h, g.middleware)
}
} else if g.isController(object) {
if len(extras) > 0 {
if g.server != nil {
g.server.BindControllerMethod(pattern, object.(Controller), extras[0])
g.server.doBindControllerMethod(pattern, object.(Controller), extras[0], g.middleware)
} else {
g.domain.BindControllerMethod(pattern, object.(Controller), extras[0])
g.domain.doBindControllerMethod(pattern, object.(Controller), extras[0], g.middleware)
}
} else {
if g.server != nil {
g.server.BindController(pattern, object.(Controller))
g.server.doBindController(pattern, object.(Controller), "", g.middleware)
} else {
g.domain.BindController(pattern, object.(Controller))
g.domain.doBindController(pattern, object.(Controller), "", g.middleware)
}
}
} else {
if len(extras) > 0 {
if g.server != nil {
g.server.BindObjectMethod(pattern, object, extras[0])
g.server.doBindObjectMethod(pattern, object, extras[0], g.middleware)
} else {
g.domain.BindObjectMethod(pattern, object, extras[0])
g.domain.doBindObjectMethod(pattern, object, extras[0], g.middleware)
}
} else {
if g.server != nil {
g.server.BindObject(pattern, object)
g.server.doBindObject(pattern, object, "", g.middleware)
} else {
g.domain.BindObject(pattern, object)
g.domain.doBindObject(pattern, object, "", g.middleware)
}
}
}
case "REST":
if g.isController(object) {
if g.server != nil {
g.server.BindControllerRest(pattern, object.(Controller))
g.server.doBindControllerRest(pattern, object.(Controller), g.middleware)
} else {
g.domain.BindControllerRest(pattern, object.(Controller))
g.domain.doBindControllerRest(pattern, object.(Controller), g.middleware)
}
} else {
if g.server != nil {
g.server.BindObjectRest(pattern, object)
g.server.doBindObjectRest(pattern, object, g.middleware)
} else {
g.domain.BindObjectRest(pattern, object)
g.domain.doBindObjectRest(pattern, object, g.middleware)
}
}
case "HOOK":
@ -365,9 +349,12 @@ func (g *RouterGroup) isController(value interface{}) bool {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.FieldByName("Request").IsValid() && v.FieldByName("Response").IsValid() &&
v.FieldByName("Server").IsValid() && v.FieldByName("Cookie").IsValid() &&
v.FieldByName("Session").IsValid() && v.FieldByName("View").IsValid() {
if v.FieldByName("Request").IsValid() &&
v.FieldByName("Response").IsValid() &&
v.FieldByName("Server").IsValid() &&
v.FieldByName("Cookie").IsValid() &&
v.FieldByName("Session").IsValid() &&
v.FieldByName("View").IsValid() {
return true
}
return false

View File

@ -7,16 +7,15 @@
package ghttp
import (
"github.com/gogf/gf/debug/gdebug"
"net/http"
"reflect"
"runtime"
)
// 绑定指定的hook回调函数, pattern参数同BindHandler支持命名路由hook参数的值由ghttp server设定参数不区分大小写
func (s *Server) BindHookHandler(pattern string, hook string, handler HandlerFunc) {
s.setHandler(pattern, &handlerItem{
itemType: gHANDLER_TYPE_HOOK,
itemName: runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name(),
itemName: gdebug.FuncPath(handler),
itemFunc: handler,
hookName: hook,
})

View File

@ -7,8 +7,7 @@
package ghttp
import (
"reflect"
"runtime"
"github.com/gogf/gf/debug/gdebug"
)
const (
@ -20,7 +19,7 @@ func (s *Server) BindMiddleware(pattern string, handlers ...HandlerFunc) {
for _, handler := range handlers {
s.setHandler(pattern, &handlerItem{
itemType: gHANDLER_TYPE_MIDDLEWARE,
itemName: runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name(),
itemName: gdebug.FuncPath(handler),
itemFunc: handler,
})
}
@ -31,7 +30,7 @@ func (s *Server) BindMiddlewareDefault(handlers ...HandlerFunc) {
for _, handler := range handlers {
s.setHandler(gDEFAULT_MIDDLEWARE_PATTERN, &handlerItem{
itemType: gHANDLER_TYPE_MIDDLEWARE,
itemName: runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name(),
itemName: gdebug.FuncPath(handler),
itemFunc: handler,
})
}

View File

@ -117,7 +117,7 @@ func (s *Server) searchHandlers(method, path, domain string) (parsedItems []*han
// 注意当不带任何动态路由规则时len(match) == 1
if match, err := gregex.MatchString(item.router.RegRule, path); err == nil && len(match) > 0 {
parsedItem := &handlerParsedItem{item, nil}
// 如果需要query匹配那么需要重新正则解析URL
// 如果需要路由规则中带有URI名称匹配那么需要重新正则解析URL
if len(item.router.RegNames) > 0 {
if len(match) > len(item.router.RegNames) {
parsedItem.values = make(map[string]string)

View File

@ -20,7 +20,36 @@ import (
// 绑定控制器,控制器需要实现 gmvc.Controller 接口,
// 这种方式绑定的控制器每一次请求都会初始化一个新的控制器对象进行处理,对应不同的请求会话,
// 第三个参数methods用以指定需要注册的方法支持多个方法名称多个方法以英文“,”号分隔,区分大小写.
func (s *Server) BindController(pattern string, c Controller, methods ...string) {
func (s *Server) BindController(pattern string, controller Controller, method ...string) {
bindMethod := ""
if len(method) > 0 {
bindMethod = method[0]
}
s.doBindController(pattern, controller, bindMethod, nil)
}
// 绑定路由到指定的方法执行, 第三个参数method仅支持一个方法注册不支持多个并且区分大小写。
func (s *Server) BindControllerMethod(pattern string, controller Controller, method string) {
s.doBindControllerMethod(pattern, controller, method, nil)
}
// 绑定控制器(RESTFul)控制器需要实现gmvc.Controller接口
// 方法会识别HTTP方法并做REST绑定处理例如Post方法会绑定到HTTP POST的方法请求处理Delete方法会绑定到HTTP DELETE的方法请求处理
// 因此只会绑定HTTP Method对应的方法其他方法不会自动注册绑定
// 这种方式绑定的控制器每一次请求都会初始化一个新的控制器对象进行处理,对应不同的请求会话
func (s *Server) BindControllerRest(pattern string, controller Controller) {
s.doBindControllerRest(pattern, controller, nil)
}
func (s *Server) doBindController(pattern string, controller Controller, method string, middleware []HandlerFunc) {
// Convert input method to map for convenience and high performance searching.
var methodMap map[string]bool
if len(method) > 0 {
methodMap = make(map[string]bool)
for _, v := range strings.Split(method, ",") {
methodMap[strings.TrimSpace(v)] = true
}
}
// 当pattern中的method为all时去掉该method以便于后续方法判断
domain, method, path, err := s.parsePattern(pattern)
if err != nil {
@ -30,17 +59,9 @@ func (s *Server) BindController(pattern string, c Controller, methods ...string)
if strings.EqualFold(method, gDEFAULT_METHOD) {
pattern = s.serveHandlerKey("", path, domain)
}
methodMap := (map[string]bool)(nil)
if len(methods) > 0 {
methodMap = make(map[string]bool)
for _, v := range strings.Split(methods[0], ",") {
methodMap[strings.TrimSpace(v)] = true
}
}
// 遍历控制器获取方法列表并构造成uri
m := make(handlerMap)
v := reflect.ValueOf(c)
v := reflect.ValueOf(controller)
t := v.Type()
sname := t.Elem().Name()
pkgPath := t.Elem().PkgPath()
@ -77,6 +98,7 @@ func (s *Server) BindController(pattern string, c Controller, methods ...string)
name: mname,
reflect: v.Elem().Type(),
},
middleware: middleware,
}
// 如果方法中带有Index方法那么额外自动增加一个路由规则匹配主URI
// 例如: pattern为/user, 那么会同时注册/user及/user/index
@ -95,16 +117,16 @@ func (s *Server) BindController(pattern string, c Controller, methods ...string)
name: mname,
reflect: v.Elem().Type(),
},
middleware: middleware,
}
}
}
s.bindHandlerByMap(m)
}
// 绑定路由到指定的方法执行, 第三个参数method仅支持一个方法注册不支持多个并且区分大小写。
func (s *Server) BindControllerMethod(pattern string, c Controller, method string) {
func (s *Server) doBindControllerMethod(pattern string, controller Controller, method string, middleware []HandlerFunc) {
m := make(handlerMap)
v := reflect.ValueOf(c)
v := reflect.ValueOf(controller)
t := v.Type()
sname := t.Elem().Name()
mname := strings.TrimSpace(method)
@ -132,18 +154,15 @@ func (s *Server) BindControllerMethod(pattern string, c Controller, method strin
name: mname,
reflect: v.Elem().Type(),
},
middleware: middleware,
}
s.bindHandlerByMap(m)
}
// 绑定控制器(RESTFul)控制器需要实现gmvc.Controller接口
// 方法会识别HTTP方法并做REST绑定处理例如Post方法会绑定到HTTP POST的方法请求处理Delete方法会绑定到HTTP DELETE的方法请求处理
// 因此只会绑定HTTP Method对应的方法其他方法不会自动注册绑定
// 这种方式绑定的控制器每一次请求都会初始化一个新的控制器对象进行处理,对应不同的请求会话
func (s *Server) BindControllerRest(pattern string, c Controller) {
func (s *Server) doBindControllerRest(pattern string, controller Controller, middleware []HandlerFunc) {
// 遍历控制器获取方法列表并构造成uri
m := make(handlerMap)
v := reflect.ValueOf(c)
v := reflect.ValueOf(controller)
t := v.Type()
sname := t.Elem().Name()
pkgPath := t.Elem().PkgPath()
@ -172,6 +191,7 @@ func (s *Server) BindControllerRest(pattern string, c Controller) {
name: mname,
reflect: v.Elem().Type(),
},
middleware: middleware,
}
}
s.bindHandlerByMap(m)

View File

@ -8,38 +8,33 @@ package ghttp
import (
"bytes"
"reflect"
"runtime"
"github.com/gogf/gf/debug/gdebug"
"strings"
"github.com/gogf/gf/os/glog"
"github.com/gogf/gf/text/gstr"
)
// 注意该方法是直接绑定函数的内存地址,执行的时候直接执行该方法,不会存在初始化新的控制器逻辑
func (s *Server) BindHandler(pattern string, handler HandlerFunc) {
s.bindHandlerItem(pattern, &handlerItem{
itemName: runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name(),
itemType: gHANDLER_TYPE_HANDLER,
itemFunc: handler,
})
s.doBindHandler(pattern, handler, nil)
}
// 绑定URI到操作函数/方法
// pattern的格式形如/user/list, put:/user, delete:/user, post:/user@johng.cn
// 支持RESTful的请求格式具体业务逻辑由绑定的处理方法来执行
func (s *Server) bindHandlerItem(pattern string, item *handlerItem) {
if s.Status() == SERVER_STATUS_RUNNING {
glog.Error("server handlers cannot be changed while running")
return
}
s.setHandler(pattern, item)
func (s *Server) doBindHandler(pattern string, handler HandlerFunc, middleware []HandlerFunc) {
s.setHandler(pattern, &handlerItem{
itemName: gdebug.FuncPath(handler),
itemType: gHANDLER_TYPE_HANDLER,
itemFunc: handler,
middleware: middleware,
})
}
// 通过映射数组绑定URI到操作函数/方法
func (s *Server) bindHandlerByMap(m handlerMap) {
for p, h := range m {
s.bindHandlerItem(p, h)
s.setHandler(p, h)
}
}
@ -48,8 +43,8 @@ func (s *Server) bindHandlerByMap(m handlerMap) {
// 规则2pattern中的URI包含{.method}关键字,则替换该关键字为方法名称;
// 规则2如果不满足规则1那么直接将防发明附加到pattern中的URI后面
func (s *Server) mergeBuildInNameToPattern(pattern string, structName, methodName string, allowAppend bool) string {
structName = s.nameToUrlPart(structName)
methodName = s.nameToUrlPart(methodName)
structName = s.nameToUri(structName)
methodName = s.nameToUri(methodName)
pattern = strings.Replace(pattern, "{.struct}", structName, -1)
if strings.Index(pattern, "{.method}") != -1 {
return strings.Replace(pattern, "{.method}", methodName, -1)
@ -75,7 +70,7 @@ func (s *Server) mergeBuildInNameToPattern(pattern string, structName, methodNam
// 规则1: 不处理名称以原有名称构建成URI
// 规则2: 仅转为小写,单词间不使用连接符号
// 规则3: 采用驼峰命名方式
func (s *Server) nameToUrlPart(name string) string {
func (s *Server) nameToUri(name string) string {
switch s.config.NameToUriType {
case URI_TYPE_FULLNAME:
return name

View File

@ -19,7 +19,35 @@ import (
// 绑定对象到URI请求处理中会自动识别方法名称并附加到对应的URI地址后面
// 第三个参数methods用以指定需要注册的方法支持多个方法名称多个方法以英文“,”号分隔,区分大小写
func (s *Server) BindObject(pattern string, obj interface{}, methods ...string) {
func (s *Server) BindObject(pattern string, object interface{}, method ...string) {
bindMethod := ""
if len(method) > 0 {
bindMethod = method[0]
}
s.doBindObject(pattern, object, bindMethod, nil)
}
// 绑定对象到URI请求处理中会自动识别方法名称并附加到对应的URI地址后面
// 第三个参数method仅支持一个方法注册不支持多个并且区分大小写。
func (s *Server) BindObjectMethod(pattern string, object interface{}, method string) {
s.doBindObjectMethod(pattern, object, method, nil)
}
// 绑定对象到URI请求处理中会自动识别方法名称并附加到对应的URI地址后面,
// 需要注意对象方法的定义必须按照 ghttp.HandlerFunc 来定义
func (s *Server) BindObjectRest(pattern string, object interface{}) {
s.doBindObjectRest(pattern, object, nil)
}
func (s *Server) doBindObject(pattern string, object interface{}, method string, middleware []HandlerFunc) {
// Convert input method to map for convenience and high performance searching.
var methodMap map[string]bool
if len(method) > 0 {
methodMap = make(map[string]bool)
for _, v := range strings.Split(method, ",") {
methodMap[strings.TrimSpace(v)] = true
}
}
// 当pattern中的method为all时去掉该method以便于后续方法判断
domain, method, path, err := s.parsePattern(pattern)
if err != nil {
@ -30,15 +58,8 @@ func (s *Server) BindObject(pattern string, obj interface{}, methods ...string)
pattern = s.serveHandlerKey("", path, domain)
}
methodMap := (map[string]bool)(nil)
if len(methods) > 0 {
methodMap = make(map[string]bool)
for _, v := range strings.Split(methods[0], ",") {
methodMap[strings.TrimSpace(v)] = true
}
}
m := make(handlerMap)
v := reflect.ValueOf(obj)
v := reflect.ValueOf(object)
t := v.Type()
sname := t.Elem().Name()
initFunc := (func(*Request))(nil)
@ -78,11 +99,12 @@ func (s *Server) BindObject(pattern string, obj interface{}, methods ...string)
}
key := s.mergeBuildInNameToPattern(pattern, sname, mname, true)
m[key] = &handlerItem{
itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname),
itemType: gHANDLER_TYPE_OBJECT,
itemFunc: itemFunc,
initFunc: initFunc,
shutFunc: shutFunc,
itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname),
itemType: gHANDLER_TYPE_OBJECT,
itemFunc: itemFunc,
initFunc: initFunc,
shutFunc: shutFunc,
middleware: middleware,
}
// 如果方法中带有Index方法那么额外自动增加一个路由规则匹配主URI。
// 注意当pattern带有内置变量时不会自动加该路由。
@ -93,11 +115,12 @@ func (s *Server) BindObject(pattern string, obj interface{}, methods ...string)
k = "/" + k
}
m[k] = &handlerItem{
itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname),
itemType: gHANDLER_TYPE_OBJECT,
itemFunc: itemFunc,
initFunc: initFunc,
shutFunc: shutFunc,
itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname),
itemType: gHANDLER_TYPE_OBJECT,
itemFunc: itemFunc,
initFunc: initFunc,
shutFunc: shutFunc,
middleware: middleware,
}
}
}
@ -106,9 +129,9 @@ func (s *Server) BindObject(pattern string, obj interface{}, methods ...string)
// 绑定对象到URI请求处理中会自动识别方法名称并附加到对应的URI地址后面
// 第三个参数method仅支持一个方法注册不支持多个并且区分大小写。
func (s *Server) BindObjectMethod(pattern string, obj interface{}, method string) {
func (s *Server) doBindObjectMethod(pattern string, object interface{}, method string, middleware []HandlerFunc) {
m := make(handlerMap)
v := reflect.ValueOf(obj)
v := reflect.ValueOf(object)
t := v.Type()
sname := t.Elem().Name()
mname := strings.TrimSpace(method)
@ -139,21 +162,20 @@ func (s *Server) BindObjectMethod(pattern string, obj interface{}, method string
}
key := s.mergeBuildInNameToPattern(pattern, sname, mname, false)
m[key] = &handlerItem{
itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname),
itemType: gHANDLER_TYPE_OBJECT,
itemFunc: itemFunc,
initFunc: initFunc,
shutFunc: shutFunc,
itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname),
itemType: gHANDLER_TYPE_OBJECT,
itemFunc: itemFunc,
initFunc: initFunc,
shutFunc: shutFunc,
middleware: middleware,
}
s.bindHandlerByMap(m)
}
// 绑定对象到URI请求处理中会自动识别方法名称并附加到对应的URI地址后面,
// 需要注意对象方法的定义必须按照 ghttp.HandlerFunc 来定义
func (s *Server) BindObjectRest(pattern string, obj interface{}) {
func (s *Server) doBindObjectRest(pattern string, object interface{}, middleware []HandlerFunc) {
m := make(handlerMap)
v := reflect.ValueOf(obj)
v := reflect.ValueOf(object)
t := v.Type()
sname := t.Elem().Name()
initFunc := (func(*Request))(nil)
@ -184,11 +206,12 @@ func (s *Server) BindObjectRest(pattern string, obj interface{}) {
}
key := s.mergeBuildInNameToPattern(mname+":"+pattern, sname, mname, false)
m[key] = &handlerItem{
itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname),
itemType: gHANDLER_TYPE_OBJECT,
itemFunc: itemFunc,
initFunc: initFunc,
shutFunc: shutFunc,
itemName: fmt.Sprintf(`%s.%s.%s`, pkgPath, objName, mname),
itemType: gHANDLER_TYPE_OBJECT,
itemFunc: itemFunc,
initFunc: initFunc,
shutFunc: shutFunc,
middleware: middleware,
}
}
s.bindHandlerByMap(m)

View File

@ -131,7 +131,7 @@ func Test_BindMiddleware_Basic3(t *testing.T) {
})
}
func Test_BindMiddleware_Must_Be_Called(t *testing.T) {
func Test_BindMiddleware_Basic4(t *testing.T) {
p := ports.PopRand()
s := g.Server(p)
s.Group("/", func(group *ghttp.RouterGroup) {
@ -157,9 +157,9 @@ func Test_BindMiddleware_Must_Be_Called(t *testing.T) {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
gtest.Assert(client.GetContent("/"), "12")
gtest.Assert(client.GetContent("/"), "Not Found")
gtest.Assert(client.GetContent("/test"), "1test2")
gtest.Assert(client.PutContent("/test/none"), "12")
gtest.Assert(client.PutContent("/test/none"), "Not Found")
})
}
@ -188,7 +188,7 @@ func Test_Middleware_With_Static(t *testing.T) {
gtest.Assert(client.GetContent("/"), "index")
gtest.Assert(client.GetContent("/test.html"), "test")
gtest.Assert(client.GetContent("/none"), "12")
gtest.Assert(client.GetContent("/none"), "Not Found")
gtest.Assert(client.GetContent("/user/list"), "1list2")
})
}
@ -214,7 +214,7 @@ func Test_Middleware_Status(t *testing.T) {
client := ghttp.NewClient()
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
gtest.Assert(client.GetContent("/"), "404")
gtest.Assert(client.GetContent("/"), "Not Found")
gtest.Assert(client.GetContent("/user/list"), "200")
resp, err := client.Get("/")
@ -268,7 +268,7 @@ func Test_Middleware_Hook_With_Static(t *testing.T) {
time.Sleep(100 * time.Millisecond)
gtest.Assert(a.Len(), 4)
gtest.Assert(client.GetContent("/none"), "a12b")
gtest.Assert(client.GetContent("/none"), "ab")
time.Sleep(100 * time.Millisecond)
gtest.Assert(a.Len(), 6)
@ -610,7 +610,7 @@ func Test_Middleware_CORSAndAuth(t *testing.T) {
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
gtest.Assert(client.GetContent("/"), "Not Found")
gtest.Assert(client.GetContent("/api.v2"), "Forbidden")
gtest.Assert(client.GetContent("/api.v2"), "Not Found")
gtest.Assert(client.GetContent("/api.v2/user/list"), "Forbidden")
gtest.Assert(client.GetContent("/api.v2/user/list", "token=123456"), "list")
})

View File

@ -67,17 +67,17 @@ func Test_Router_Group_Group(t *testing.T) {
client.SetPrefix(fmt.Sprintf("http://127.0.0.1:%d", p))
gtest.Assert(client.GetContent("/"), "Not Found")
gtest.Assert(client.GetContent("/api.v2"), "12")
gtest.Assert(client.GetContent("/api.v2"), "Not Found")
gtest.Assert(client.GetContent("/api.v2/test"), "1test2")
gtest.Assert(client.GetContent("/api.v2/hook"), "hook any12")
gtest.Assert(client.GetContent("/api.v2/hook/name"), "hook namehook any12")
gtest.Assert(client.GetContent("/api.v2/hook/name/any"), "hook any12")
gtest.Assert(client.GetContent("/api.v2/hook"), "hook any")
gtest.Assert(client.GetContent("/api.v2/hook/name"), "hook namehook any")
gtest.Assert(client.GetContent("/api.v2/hook/name/any"), "hook any")
gtest.Assert(client.GetContent("/api.v2/order/list"), "1list2")
gtest.Assert(client.GetContent("/api.v2/order/update"), "12")
gtest.Assert(client.GetContent("/api.v2/order/update"), "Not Found")
gtest.Assert(client.PutContent("/api.v2/order/update"), "1update2")
gtest.Assert(client.GetContent("/api.v2/user/drop"), "12")
gtest.Assert(client.GetContent("/api.v2/user/drop"), "Not Found")
gtest.Assert(client.DeleteContent("/api.v2/user/drop"), "1drop2")
gtest.Assert(client.GetContent("/api.v2/user/edit"), "12")
gtest.Assert(client.GetContent("/api.v2/user/edit"), "Not Found")
gtest.Assert(client.PostContent("/api.v2/user/edit"), "1edit2")
gtest.Assert(client.GetContent("/api.v2/user/info"), "1info2")
})