跳至主要內容

完整结构

程序员李某某 · 阅读约 5 分钟

完整结构

lee

package lee

import (
	"html/template"
	"net/http"
	"strings"
)

type (
	// HandleFunc is the signature shared by route handlers and middlewares.
	HandleFunc func(*Context)

	// H is a shorthand map type for building JSON objects, e.g. H{"message": msg}.
	H map[string]interface{}
)

// Engine is the framework entry point; it implements http.Handler via ServeHTTP.
// It embeds the root RGroup, so Use/Group/Static can be called on it directly.
type Engine struct {
	// root route group (empty prefix)
	*RGroup
	// per-method route tries plus the handler table
	router        *router
	// every group created through New or Group, root group first
	groups        []*RGroup
	htmlTemplates *template.Template // for html render
	funcMap       template.FuncMap   // for html render
} // struct

// New creates a bare Engine with an empty router and a root group (empty
// prefix) that is also registered in the engine's group list.
func New() *Engine {
	e := &Engine{router: newRouter()}
	root := &RGroup{engine: e}
	e.RGroup = root
	e.groups = append(e.groups, root)
	return e
}

// Default creates an Engine via New and installs the Logger and Recovery
// middlewares on the root group, in that order.
func Default() *Engine {
	engine := New()
	engine.Use(Logger(), Recovery())
	return engine
}

// ServeHTTP implements http.Handler. It collects the middlewares of every group
// whose prefix covers the request path, then appends either the matched route
// handler or a 404 handler, and runs the whole chain through Context.Next.
func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	var middlewares []HandleFunc
	for _, group := range e.groups {
		if groupPrefixMatches(r.URL.Path, group.pre) {
			middlewares = append(middlewares, group.middlewares...)
		}
	}

	ctx := newContext(w, r)
	ctx.handlers = middlewares
	ctx.engine = e

	tailNode, params := e.router.getRoute(ctx.Method, ctx.Path)
	if tailNode != nil {
		// Handlers are stored under "<METHOD>_<registered pattern>" (see router.addRoute).
		key := r.Method + "_" + tailNode.path
		ctx.Params = params // expose path parameters to the handlers
		ctx.handlers = append(ctx.handlers, e.router.handles[key])
	} else {
		ctx.handlers = append(ctx.handlers, func(ctx *Context) {
			ctx.String(http.StatusNotFound, "404 NOT FOUND: %s\n", ctx.Path)
		})
	}

	ctx.Next() // run the middlewares and the final handler in order
}

// groupPrefixMatches reports whether urlPath falls under the group prefix pre.
// Unlike a bare strings.HasPrefix it only matches on a segment boundary, so a
// "/v1" group covers "/v1" and "/v1/users" but not "/v10/users". The root
// group (empty prefix) and prefixes ending in "/" match by plain prefix.
func groupPrefixMatches(urlPath, pre string) bool {
	if !strings.HasPrefix(urlPath, pre) {
		return false
	}
	if pre == "" || strings.HasSuffix(pre, "/") || len(urlPath) == len(pre) {
		return true
	}
	return urlPath[len(pre)] == '/'
}

// SetFuncMap stores the template functions to be attached by a later LoadHTMLGlob call.
func (e *Engine) SetFuncMap(funcMap template.FuncMap) { e.funcMap = funcMap }

// LoadHTMLGlob parses every template file matching pattern, with the functions
// from SetFuncMap available. It panics (via template.Must) on a parse error.
func (e *Engine) LoadHTMLGlob(pattern string) {
	base := template.New("").Funcs(e.funcMap)
	e.htmlTemplates = template.Must(base.ParseGlob(pattern))
}

// Get registers handleFunc for GET requests on pattern. It goes straight to
// the router, so no group prefix is applied.
func (e *Engine) Get(pattern string, handleFunc HandleFunc) {
	r := e.router
	r.addRoute(http.MethodGet, pattern, handleFunc)
}

// Post registers handleFunc for POST requests on pattern. It goes straight to
// the router, so no group prefix is applied.
func (e *Engine) Post(pattern string, handleFunc HandleFunc) {
	r := e.router
	r.addRoute(http.MethodPost, pattern, handleFunc)
}

// Run starts an HTTP server on addr with the engine as its handler. It blocks
// until the server stops and returns the error from http.ListenAndServe.
func (e *Engine) Run(addr string) error {
	err := http.ListenAndServe(addr, e)
	return err
}

group

package lee

import (
	"net/http"
	"path"
)

// RGroup is a route group: routes registered through it get its prefix, and
// its middlewares run for requests whose path falls under that prefix
// (see Engine.ServeHTTP).
type RGroup struct {
	pre         string       // path prefix (parent prefix + own prefix)
	middlewares []HandleFunc // middlewares registered via Use
	parent      *RGroup      // group this one was created from (nil for the root)
	engine      *Engine      // every group hangs off the same Engine
}

// Group creates a child group whose prefix is g's prefix followed by pre, and
// registers it with the engine so its middlewares take part in request matching.
func (g *RGroup) Group(pre string) *RGroup {
	child := &RGroup{
		pre:    g.pre + pre, // prefixes are concatenated as-is
		parent: g,
		engine: g.engine,
	}
	g.engine.groups = append(g.engine.groups, child)
	return child
}

// addRoute registers handler for method on the group prefix followed by comp.
func (g *RGroup) addRoute(method string, comp string, handler HandleFunc) {
	pattern := g.pre + comp
	g.engine.router.addRoute(method, pattern, handler)
}

// Get registers handleFunc for GET requests on the group prefix + pattern.
func (g *RGroup) Get(pattern string, handleFunc HandleFunc) {
	const method = "GET"
	g.addRoute(method, pattern, handleFunc)
}

// Post registers handleFunc for POST requests on the group prefix + pattern.
func (g *RGroup) Post(pattern string, handleFunc HandleFunc) {
	const method = "POST"
	g.addRoute(method, pattern, handleFunc)
}

// Use appends middlewares to the group, in the given order.
func (g *RGroup) Use(middlewares ...HandleFunc) {
	for _, mw := range middlewares {
		g.middlewares = append(g.middlewares, mw)
	}
}

// Static serves files from the directory root under relativePath (relative to
// the group prefix) by registering a GET route "<relativePath>/*filepath".
func (g *RGroup) Static(relativePath, root string) {
	pattern := path.Join(relativePath, "/*filepath")
	g.Get(pattern, g.createStaticHandler(relativePath, http.Dir(root)))
}

// createStaticHandler returns a handler that serves the file named by the
// "filepath" route parameter from fs. The absolute URL prefix (group prefix +
// relativePath) is stripped before the request reaches http.FileServer.
func (g *RGroup) createStaticHandler(relativePath string, fs http.FileSystem) HandleFunc {
	absolutePath := path.Join(g.pre, relativePath)
	fileServer := http.StripPrefix(absolutePath, http.FileServer(fs))
	return func(c *Context) {
		file := c.Param("filepath")
		// Check that the file exists and that we are allowed to open it.
		f, err := fs.Open(file)
		if err != nil {
			c.Status(http.StatusNotFound)
			return
		}
		// The probe handle is not used for serving. Close it, or every request
		// leaks a file descriptor. A close error on a read-only probe is irrelevant.
		_ = f.Close()

		fileServer.ServeHTTP(c.Writer, c.Req)
	}
}

router

package lee

import "strings"

// router keeps one routing trie per HTTP method together with the handlers.
type router struct {
	// handles maps "<METHOD>_<pattern>" (the key built in addRoute) to its handler.
	handles map[string]HandleFunc
	// roots maps an HTTP method such as "GET" to the root node of its trie.
	roots   map[string]*node
}

// newRouter returns a router with empty, ready-to-use trie and handler maps.
func newRouter() *router {
	r := new(router)
	r.roots = make(map[string]*node)
	r.handles = make(map[string]HandleFunc)
	return r
}

// parsePath splits a URL path or route pattern into its non-empty segments.
// Parsing stops after the first segment that starts with '*', because a
// catch-all wildcard swallows the rest of the path.
//
// The result is never nil. "/" and "" yield an empty, non-nil slice, and
// router.addRoute relies on that to register the root pattern.
func parsePath(path string) []string {
	parts := make([]string, 0)
	for _, s := range strings.Split(path, "/") {
		if s == "" {
			continue // skip empty segments from leading, trailing or doubled slashes
		}
		parts = append(parts, s)
		if s[0] == '*' {
			break // '*' ends the pattern
		}
	}
	return parts
}

// addRoute inserts the pattern path into the trie for method, creating that
// trie on first use, and stores handler under the key "<method>_<path>".
func (r *router) addRoute(method string, path string, handler HandleFunc) {
	root, exists := r.roots[method]
	if !exists {
		root = &node{}
		r.roots[method] = root
	}

	if parts := parsePath(path); parts != nil {
		root.insert(path, parts, 0)
	}

	r.handles[method+"_"+path] = handler
}

// getRoute finds the registered pattern matching the request path for method.
// It returns the trie node holding the pattern and the extracted parameters:
// ":name" segments map name to the matching request segment, and "*name" maps
// name to the rest of the request path joined with "/". It returns (nil, nil)
// when the method has no routes or nothing matches.
func (r *router) getRoute(method string, path string) (*node, map[string]string) {
	root, ok := r.roots[method]
	if !ok {
		return nil, nil
	}

	requestParts := parsePath(path)
	matched := root.search(requestParts, 0)
	if matched == nil {
		return nil, nil
	}

	params := make(map[string]string)
	for i, patternPart := range parsePath(matched.path) {
		if patternPart[0] == ':' {
			params[patternPart[1:]] = requestParts[i]
		} else if patternPart[0] == '*' && len(patternPart) > 1 {
			params[patternPart[1:]] = strings.Join(requestParts[i:], "/")
			break
		}
	}
	return matched, params
}

context

package lee

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// Context carries the state of a single request through the handler chain.
type Context struct {
	Writer     http.ResponseWriter
	Req        *http.Request
	StatusCode int               // last status passed to Status
	Path       string            // copy of Req.URL.Path
	Method     string            // copy of Req.Method
	Params     map[string]string // route parameters extracted by the router

	// middleware chain: handlers to run and the index of the one currently running
	handlers []HandleFunc
	index    int

	// owning engine (used to reach the HTML templates)
	engine *Engine
}

// newContext wraps a request/response pair. index starts at -1 so that the
// first call to Next runs handler 0.
func newContext(w http.ResponseWriter, req *http.Request) *Context {
	c := &Context{Writer: w, Req: req, index: -1}
	c.Path, c.Method = req.URL.Path, req.Method
	return c
}

// Next runs the remaining handlers in order. A middleware may call Next itself
// to run the rest of the chain and then do work afterwards. Because index is
// shared, handlers already run by a nested call are not run again.
func (ctx *Context) Next() {
	total := len(ctx.handlers)
	for ctx.index++; ctx.index < total; ctx.index++ {
		ctx.handlers[ctx.index](ctx)
	}
}

// PostForm returns the form value for k via Request.FormValue.
// NOTE: FormValue also falls back to URL query parameters.
func (ctx *Context) PostForm(k string) string {
	req := ctx.Req
	return req.FormValue(k)
}

// Query returns the first URL query value for k, or "" if absent.
func (ctx *Context) Query(k string) string {
	values := ctx.Req.URL.Query()
	return values.Get(k)
}

// Param returns the route parameter k, or "" if the route has no such parameter.
func (ctx *Context) Param(k string) string {
	value := ctx.Params[k]
	return value
}

// Status records code in StatusCode and writes it as the response status.
// It must be called at most once per response, before any body is written.
func (ctx *Context) Status(code int) {
	ctx.StatusCode = code
	ctx.Writer.WriteHeader(ctx.StatusCode)
}

// SetHeader sets response header k to v. It only has an effect before Status.
func (ctx *Context) SetHeader(k, v string) {
	header := ctx.Writer.Header()
	header.Set(k, v)
}

// JSON marshals body and writes it with status code and an application/json
// content type. If body cannot be marshaled, it sends a plain-text 500 instead.
func (ctx *Context) JSON(code int, body interface{}) {
	// Marshal before touching the response. Once the status is written, a
	// marshal failure can no longer be reported with a proper 500.
	data, err := json.Marshal(body)
	if err != nil {
		ctx.StatusCode = http.StatusInternalServerError
		http.Error(ctx.Writer, err.Error(), http.StatusInternalServerError)
		return
	}

	ctx.SetHeader("Content-Type", "application/json")
	ctx.Status(code)
	// The status has already been sent, so a write error (typically a client
	// that went away) cannot be turned into an error response anymore.
	_, _ = ctx.Writer.Write(data)
}

// String writes a text/plain response with status code, formatting the body
// with fmt-style format and v.
func (ctx *Context) String(code int, format string, v ...interface{}) {
	ctx.SetHeader("Content-Type", "text/plain")
	ctx.Status(code)
	// The status has already been sent, so a write error (typically a client
	// that went away) cannot be turned into an error response. Calling
	// http.Error here would only cause a superfluous WriteHeader.
	_, _ = fmt.Fprintf(ctx.Writer, format, v...)
}

// Fail writes a JSON error body {"message": msg} with status code.
func (ctx *Context) Fail(code int, msg string) {
	// JSON sets the status itself. Calling Status here as well would write the
	// header twice ("superfluous response.WriteHeader call").
	ctx.JSON(code, H{"message": msg})
}

// HTML renders the template name (loaded via Engine.LoadHTMLGlob) with data
// and writes it as text/html with status code. On a template error it
// responds with a JSON 500 via Fail, and nothing from the failed render is sent.
func (ctx *Context) HTML(code int, name string, data interface{}) {
	if ctx.engine == nil || ctx.engine.htmlTemplates == nil {
		// ExecuteTemplate on a nil *template.Template would panic.
		ctx.Fail(http.StatusInternalServerError, "html templates not loaded; call LoadHTMLGlob first")
		return
	}

	// Render into a buffer first. That way an execution error can still
	// produce a clean 500 instead of JSON appended to a partial HTML page.
	var buf bytes.Buffer
	if err := ctx.engine.htmlTemplates.ExecuteTemplate(&buf, name, data); err != nil {
		ctx.Fail(http.StatusInternalServerError, err.Error())
		return
	}

	ctx.SetHeader("Content-Type", "text/html")
	ctx.Status(code)
	// The status has already been sent, so a write error cannot be reported to the client.
	_, _ = ctx.Writer.Write(buf.Bytes())
}

trie

package lee

import "strings"

// node is one segment of a routing trie. The router keeps one trie per HTTP
// method. A pattern such as "/p/:lang/doc" is stored as the chain
// "p" -> ":lang" -> "doc", and the full pattern is recorded on the last node.
type node struct {
	path     string  // full pattern; set only on the node where a registered pattern ends
	part     string  // segment held by this node, e.g. "p", ":lang" or "*filepath"
	children []*node // child segments, in insertion order
	isWild   bool    // true when part starts with ':' or '*' (dynamic segment)
}

// matchChild returns the child whose segment is exactly part, or nil.
//
// It is used only while inserting patterns, so it compares segments literally.
// Returning a wildcard child here would graft a static pattern like "/p/abc"
// under an existing ":lang" node. After that, "/xyz/abc" would also resolve
// to "/p/abc", and "/hello/b" would overwrite "/hello/:name".
func (n *node) matchChild(part string) *node {
	for _, child := range n.children {
		if child.part == part {
			return child
		}
	}
	return nil
}

// insert adds the pattern path, already split into parts, below n. height is
// the index of the part handled at this level (0 for the root call).
func (n *node) insert(path string, parts []string, height int) {
	if len(parts) == height {
		// Only the terminal node records the full pattern. Intermediate nodes
		// keep path == "" so search can tell mere prefixes from routes.
		n.path = path
		return
	}
	part := parts[height]
	child := n.matchChild(part) // reuse an existing node for this exact segment
	if child == nil {
		child = &node{part: part, isWild: part[0] == ':' || part[0] == '*'}
		n.children = append(n.children, child)
	}
	child.insert(path, parts, height+1)
}

// search returns the node ending a registered pattern that matches the
// request segments parts, starting at index height, or nil if none does.
// A '*' node matches the rest of the path. Static children are tried before
// dynamic ones, and the search backtracks when a branch dead-ends.
func (n *node) search(parts []string, height int) *node {
	if len(parts) == height || strings.HasPrefix(n.part, "*") {
		// An empty path means this node is only a prefix of longer patterns.
		if n.path == "" {
			return nil
		}
		return n
	}

	part := parts[height]
	for _, child := range n.matchChildren(part) {
		if result := child.search(parts, height+1); result != nil {
			return result
		}
	}
	return nil
}

// matchChildren returns the children that can match the request segment part.
// Exact (static) matches come first and wildcard children after them, so a
// static route such as "/hello/b" wins over "/hello/:name".
func (n *node) matchChildren(part string) []*node {
	var exact, wild []*node
	for _, child := range n.children {
		switch {
		case child.part == part:
			exact = append(exact, child)
		case child.isWild:
			wild = append(wild, child)
		}
	}
	return append(exact, wild...)
}

logger

package lee

import (
	"log"
	"time"
)

// Logger returns a middleware that runs the rest of the chain and then logs
// the status code, request URI and elapsed time.
func Logger() HandleFunc {
	return func(c *Context) {
		start := time.Now()
		c.Next() // run downstream handlers first so StatusCode reflects their result
		elapsed := time.Since(start)
		log.Printf("############# logger [%d] %s in %v", c.StatusCode, c.Req.RequestURI, elapsed)
	}
}

recovery

package lee

import (
	"fmt"
	"log"
	"net/http"
	"runtime"
	"strings"
)

// Recovery returns a middleware that recovers from panics in downstream
// handlers. It logs the panic value with a stack trace, stops the rest of the
// chain, and responds with a JSON 500 via Fail.
func Recovery() HandleFunc {
	return func(c *Context) {
		defer func() {
			if err := recover(); err != nil {
				// %v formats any panic value; %s would render e.g. panic(42) as "%!s(int=42)".
				message := fmt.Sprintf("%v", err)
				log.Printf("%s\n\n", trace(message))
				// Abort the chain. Otherwise the enclosing Next loop would resume
				// with the handler that follows the one that panicked.
				c.index = len(c.handlers)
				c.Fail(http.StatusInternalServerError, "Internal Server Error")
			}
		}()

		c.Next()
	}
}

// trace returns message followed by a "Traceback:" section listing, one
// "file:line" per line, the call stack above trace's caller (up to 32 frames).
// It is meant to be called from the deferred recover function in Recovery.
func trace(message string) string {
	var pcs [32]uintptr
	// Skip runtime.Callers, trace itself, and trace's caller.
	n := runtime.Callers(3, pcs[:])

	var str strings.Builder
	str.WriteString(message + "\nTraceback:")
	// CallersFrames expands inlined calls and converts return addresses into
	// call sites. Indexing the raw PCs with FuncForPC gets both wrong.
	frames := runtime.CallersFrames(pcs[:n])
	for more := n > 0; more; {
		var frame runtime.Frame
		frame, more = frames.Next()
		fmt.Fprintf(&str, "\n\t%s:%d", frame.File, frame.Line)
	}
	return str.String()
}

上次编辑于:
贡献者: ext.liyuanhao3