跳至主要內容

第三天:动态路由

程序员李某某原创Golang框架源码路由大约 9 分钟

第三天:动态路由

Tire 树介绍

之前,我们用了一个非常简单的map结构存储了路由表,使用map存储键值对,索引非常高效,但是有一个弊端,键值对的存储的方式,只能用来索引静态路由。那如果我们想支持类似于/hello/:name这样的动态路由怎么办呢?所谓动态路由,即一条路由规则可以匹配某一类型而非某一条固定的路由。例如/hello/:name,可以匹配/hello/li、 hello/wang等。

动态路由有很多种实现方式,支持的规则、性能等有很大的差异。例如

  • 开源的路由实现gorouter支持在路由规则中嵌入正则表达式,例如/p/[0-9A-Za-z]+,即路径中的参数仅匹配数字和字母;
  • 另一个开源实现httprouter就不支持正则表达式。
  • 著名的Web开源框架gin 在早期的版本,并没有实现自己的路由,而是直接使用了httprouter,后来不知道什么原因,放弃了httprouter,自己实现了一个版本。

实现动态路由最常用的数据结构,被称为前缀树(Trie树)。看到名字你大概也能知道前缀树长啥样了:每一个节点的所有的子节点都拥有相同的前缀。这种结构非常适用于路由匹配,比如我们定义了如下路由规则:

/:lang/doc /:lang/tutorial /:lang/intro /about /p/blog /p/related 我们用前缀树来表示,是这样的。 HTTP请求的路径恰好是由/分隔的多段构成的,因此,每一段可以作为前缀树的一个节点。我们通过树结构查询,如果中间某一层的节点都不满足条件,那么就说明没有匹配到的路由,查询结束。

接下来我们实现的动态路由具备以下两个功能。

参数匹配:。例如 /p/:lang/doc,可以匹配 /p/c/doc 和 /p/go/doc。 通配*。例如 /static/*filepath,可以匹配/static/fav.ico,也可以匹配/static/js/jQuery.js,这种模式常用于静态服务器,能够递归地匹配子路径。

Tire 树实现

  • 首先需要解析一下路径,用/分割

      func parsePath(path string) []string {
          split := strings.Split(path, "/")
    
          parts := make([]string, 0)
          if len(split) == 0 {
              return nil
          }
          for _, s := range split {
              if s != "" {
                  parts = append(parts, s)
              } else {
                  continue    // 第一个元素是空格,跳过
              }
              if s[0] == '*' {
                  break // *号结束
              }
          }
          return parts
      }
    
  • 我们的目标是,当调用e.Get("/sayHello", handleFunc)时,会将分割好的路径存在Tire树中,我们先回顾性调用关系

    • func (e *Engine) Get(path string, handleFunc HandleFunc)
    • func (r *router) addRoute(method string, path string, handler HandleFunc)
    • 就是这里了,解析已经有了,就剩下存储了
  • 设计下存储结构

    • 树结构,肯定有指向孩子节点的指针
    • 也肯定得有存储值的字段
    • 还需要一个字段标记这个节点是否是动态的
    • 还需要一个字段存储完整路径,用来获取处理函数的key
      type node struct {
          path     string  // 完整路径,只在最后一层存储,其他父节点为 nil
          part     string  // 当前节点存储的路由的一部分
          children []*node // 子节点
          isWild   bool    // 该路径片段是否是动态的,part 含有 : 或 * 时为 true
      }
    
  • 插入

    // 将路径和处理函数的map封装到router中,并增加存储 node 的字段
    type router struct {
          handles map[string]HandleFunc
          roots   map[string]*node
      }
    
      func (r *router) addRoute(method string, path string, handler HandleFunc) {
    
          parts := parsePath(path)
    
          if _, ok := r.roots[method]; !ok { // 每个请求方法构建一个树
              r.roots[method] = &node{}
          }
          if parts != nil {
              r.roots[method].insert(path, parts, 0)
          }
    
          key := method + "_" + path
          r.handles[key] = handler
      } 
    
      func (n *node) insert(path string, parts []string, height int) {
          if len(parts) == height { // 递归结束条件
              n.path = path // 最后一层的子节点,用于存储完整路径(因为中间不知哪个部分是动态的)
              return
          }
          // 取出当前层
          part := parts[height]
          child := n.matchChild(part) // 在现有树中查是否已经存在
          if child == nil {           // 没有就创建一个新的子节点
              child = &node{part: part, isWild: part[0] == ':' || part[0] == '*'}
              n.children = append(n.children, child)
          }
          child.insert(path, parts, height+1) // 通过这个子节点再从下一层匹配
      }
    
      func (n *node) matchChild(part string) *node {
          for _, child := range n.children {
              // 能完全匹配上就返回,匹配到任何动态的也都返回
              if child.part == part || child.isWild {
                  return child
              }
          }
          return nil
      }    
    
  • 插入写完啦,接下来就是获取了,我们在gin中获取路径参数是在context中获取的,所以我们需要给context接一个字段来存储,

      type Context struct {
          Writer     http.ResponseWriter
          Req        *http.Request
          StatusCode int
          Path       string
          Method     string
          Params     map[string]string
      }
    
  • 当触发func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request)时,我们就应该获取对应的参数,并且映射到对应的处理函数上

      func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    
          ctx := &Context{
              Req:    r,
              Writer: w,
              Path:   r.URL.Path,
              Method: r.Method,
          }
          // 获取对应的参数 和 包含映射路径的尾部节点
          tailNode, params := e.router.getRoute(ctx.Method, ctx.Path)
          if tailNode != nil {
              ctx.Params = params                 // 存到 context
              k := r.Method + "_" + tailNode.path // 处理函数的key
              if handleFunc, ok := e.router.handles[k]; ok {
                  handleFunc(ctx)
              } else {
                  ctx.String(http.StatusNotFound, "404 NOT FOUND: %s\n", ctx.Path)
              }
          }
      }  
    
    
      func (r *router) getRoute(method string, path string) (*node, map[string]string) {
          // 真实url切割
          searchParts := parsePath(path)
          params := make(map[string]string)
          // 对应请求方法的树
          root, ok := r.roots[method]
    
          if !ok {
              return nil, nil
          }
    
          // 树中匹配到的路径链表的最后一个节点
          n := root.search(searchParts, 0)
    
          if n != nil {
              // 控制层的动态路由url
              parts := parsePath(n.path)
              for index, part := range parts {
                  if part[0] == ':' {
                      // 封装动态路由参数的映射:k-控制层的动态路由去掉:后的, v-真实url中对应的值
                      params[part[1:]] = searchParts[index]
                  }
                  if part[0] == '*' && len(part) > 1 {
                      // 封装动态路由参数的映射:k-控制层的动态路由去掉:后的, v-真实url中对应的值(该部分及之后的路径)
                      params[part[1:]] = strings.Join(searchParts[index:], "/")
                      break
                  }
              }
              return n, params
          }
    
          return nil, nil
      }
    
      func (n *node) search(parts []string, height int) *node {
          // 递归结束条件:到最后一层的子节点了,或者遇到*开头的了
          if len(parts) == height || strings.HasPrefix(n.part, "*") {
              // 空串说明没有这个路径信息,url错误
              if n.path == "" {
                  return nil
              }
              return n
          }
    
          part := parts[height]
          children := n.matchChildren(part)
    
          for _, child := range children {
              // 接着匹配下一层
              result := child.search(parts, height+1)
              if result != nil {
                  return result
              }
          }
    
          return nil
      }
    
      func (n *node) matchChildren(part string) []*node {
          nodes := make([]*node, 0)
          for _, child := range n.children {
              // 能完全匹配上就返回,匹配到任何动态的也都返回
              if child.part == part || child.isWild {
                  nodes = append(nodes, child)
              }
          }
          return nodes
      }    
    
  • 为了方便读取值,给context绑定一个获取方法

      func (ctx *Context) Param(k string) string {
          return ctx.Params[k]
      }
    

文件拆分

随着功能的复杂,我们先拆分下文件

./lee
    - lee.go
    - rotuer.go
    - trie.go

lee.go

package lee

import (
	"encoding/json"
	"fmt"
	"net/http"
)

type HandleFunc func(*Context)
type H map[string]interface{}

type Engine struct {
	router *router
} // 结构体

func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {

	ctx := &Context{
		Req:    r,
		Writer: w,
		Path:   r.URL.Path,
		Method: r.Method,
	}
	tailNode, params := e.router.getRoute(ctx.Method, ctx.Path)
	if tailNode != nil {
		ctx.Params = params                 // 存到 context
		k := r.Method + "_" + tailNode.path // 处理函数的key
		if handleFunc, ok := e.router.handles[k]; ok {
			handleFunc(ctx)
		} else {
			ctx.String(http.StatusNotFound, "404 NOT FOUND: %s\n", ctx.Path)
		}
	}
}

func New() *Engine {
	return &Engine{router: newRouter()}
}

func (e *Engine) Get(path string, handleFunc HandleFunc) {
	e.router.addRoute("GET", path, handleFunc)
}

func (e *Engine) Post(path string, handleFunc HandleFunc) {
	e.router.addRoute("POST", path, handleFunc)
}

func (e *Engine) Run(addr string) (err error) {
	return http.ListenAndServe(addr, e)
}

type Context struct {
	Writer     http.ResponseWriter
	Req        *http.Request
	StatusCode int
	Path       string
	Method     string
	Params     map[string]string
}

func (ctx *Context) PostForm(k string) string {
	return ctx.Req.FormValue(k)
}

func (ctx *Context) Query(k string) string {
	return ctx.Req.URL.Query().Get(k)
}

func (ctx *Context) Status(code int) {
	ctx.StatusCode = code
	ctx.Writer.WriteHeader(code)
}

func (ctx *Context) SetHeader(k, v string) {
	ctx.Writer.Header().Set(k, v)
}

func (ctx *Context) JSON(code int, body interface{}) {
	ctx.SetHeader("Content-Type", "application/json")
	ctx.Status(code)
	if bytes, err := json.Marshal(body); err != nil {
		http.Error(ctx.Writer, err.Error(), 500)
	} else {
		_, err = ctx.Writer.Write(bytes)
		if err != nil {
			http.Error(ctx.Writer, err.Error(), 500)
		}
	}
}

func (ctx *Context) String(code int, format string, v ...interface{}) {
	ctx.SetHeader("Content-Type", "text/plain")
	ctx.Status(code)
	if _, err := ctx.Writer.Write([]byte(fmt.Sprintf(format, v...))); err != nil {
		http.Error(ctx.Writer, err.Error(), 500)
	}
}

rotuer.go

package lee

import "strings"

type router struct {
	handles map[string]HandleFunc
	roots   map[string]*node
}

func newRouter() *router {
	return &router{
		roots:   make(map[string]*node),
		handles: make(map[string]HandleFunc),
	}

}

func parsePath(path string) []string {
	split := strings.Split(path, "/")

	parts := make([]string, 0)
	if len(split) == 0 {
		return nil
	}
	for _, s := range split {
		if s != "" {
			parts = append(parts, s)
		} else {
			continue
		}
		if s[0] == '*' {
			break // *号结束
		}
	}
	return parts
}

func (r *router) addRoute(method string, path string, handler HandleFunc) {

	parts := parsePath(path)

	if _, ok := r.roots[method]; !ok { // 每个请求方法构建一个树
		r.roots[method] = &node{}
	}
	if parts != nil {
		r.roots[method].insert(path, parts, 0)
	}

	key := method + "_" + path
	r.handles[key] = handler
}

func (r *router) getRoute(method string, path string) (*node, map[string]string) {
	// 真实url切割
	searchParts := parsePath(path)
	params := make(map[string]string)
	// 对应请求方法的树
	root, ok := r.roots[method]

	if !ok {
		return nil, nil
	}

	// 树中匹配到的路径链表的最后一个节点
	n := root.search(searchParts, 0)

	if n != nil {
		// 控制层的动态路由url
		parts := parsePath(n.path)
		for index, part := range parts {
			if part[0] == ':' {
				// 封装动态路由参数的映射:k-控制层的动态路由去掉:后的, v-真实url中对应的值
				params[part[1:]] = searchParts[index]
			}
			if part[0] == '*' && len(part) > 1 {
				// 封装动态路由参数的映射:k-控制层的动态路由去掉:后的, v-真实url中对应的值(该部分及之后的路径)
				params[part[1:]] = strings.Join(searchParts[index:], "/")
				break
			}
		}
		return n, params
	}

	return nil, nil
}

trie.go

package lee

import "strings"

type node struct {
	path     string  // 完整路径,只在最后一层的子节点存储
	part     string  // 当前节点存储的路由的一部分
	children []*node // 子节点
	isWild   bool    // 该路径片段是否是动态的,part 含有 : 或 * 时为 true
}

func (n *node) matchChild(part string) *node {
	for _, child := range n.children {
		// 能完全匹配上就返回,匹配到任何动态的也都返回
		if child.part == part || child.isWild {
			return child
		}
	}
	return nil
}

func (n *node) insert(path string, parts []string, height int) {
	if len(parts) == height { // 递归结束条件
		n.path = path // 最后一层的子节点,用于存储完整路径(因为中间不知哪个部分是动态的)
		return
	}
	// 取出当前层
	part := parts[height]
	child := n.matchChild(part) // 在现有树中查是否已经存在
	if child == nil {           // 没有就创建一个新的子节点
		child = &node{part: part, isWild: part[0] == ':' || part[0] == '*'}
		n.children = append(n.children, child)
	}
	child.insert(path, parts, height+1) // 通过这个子节点再从下一层匹配
}

func (n *node) search(parts []string, height int) *node {
	// 递归结束条件:到最后一层的子节点了,或者遇到*开头的了
	if len(parts) == height || strings.HasPrefix(n.part, "*") {
		// 空串说明没有这个路径信息,url错误
		if n.path == "" {
			return nil
		}
		return n
	}

	part := parts[height]
	children := n.matchChildren(part)

	for _, child := range children {
		// 接着匹配下一层
		result := child.search(parts, height+1)
		if result != nil {
			return result
		}
	}

	return nil
}
func (n *node) matchChildren(part string) []*node {
	nodes := make([]*node, 0)
	for _, child := range n.children {
		// 能完全匹配上就返回,匹配到任何动态的也都返回
		if child.part == part || child.isWild {
			nodes = append(nodes, child)
		}
	}
	return nodes
}

上次编辑于:
贡献者: ext.liyuanhao3