完整结构
大约 5 分钟
完整结构
lee
package lee
import (
"html/template"
"net/http"
"strings"
)
// HandleFunc is the signature shared by route handlers and middleware: both
// receive the per-request *Context.
type HandleFunc func(*Context)
// H is a shorthand map for building JSON payloads, e.g. H{"message": msg}.
type H map[string]interface{}
// Engine is the framework entry point. It implements http.Handler through
// ServeHTTP and embeds the root route group, so group methods such as Use,
// Group and Static can be called on the engine directly.
type Engine struct {
*RGroup // root group (empty prefix), created in New
router *router // routing tries and handler table shared by every group
groups []*RGroup // all groups of this engine, root included; scanned per request to collect middleware
htmlTemplates *template.Template // templates parsed by LoadHTMLGlob; used by Context.HTML
funcMap template.FuncMap // template functions set via SetFuncMap; applied when LoadHTMLGlob parses
}
// New creates an Engine with an empty router and a root route group whose
// prefix is "". The root group is also the first entry of the engine's group
// list, so its middleware applies to every request.
func New() *Engine {
	e := new(Engine)
	e.router = newRouter()
	root := &RGroup{engine: e}
	e.RGroup = root
	e.groups = append(e.groups, root)
	return e
}
// Default returns a new Engine with the Logger and Recovery middleware
// installed on the root group, in that order.
func Default() *Engine {
	engine := New()
	builtins := []HandleFunc{Logger(), Recovery()}
	engine.Use(builtins...)
	return engine
}
// ServeHTTP implements http.Handler. It gathers the middleware of every group
// whose prefix covers the request path, appends the matched route handler (or
// a 404 handler when no route matches) and runs the resulting chain.
func (e *Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	var middlewares []HandleFunc
	reqPath := r.URL.Path
	for _, group := range e.groups {
		// A group prefix must end on a path-segment boundary: "/v1" covers
		// "/v1" and "/v1/x" but not "/v10/x" or "/v1abc". The root group ("")
		// covers everything.
		pre := group.pre
		if pre == "" || reqPath == pre ||
			(strings.HasPrefix(reqPath, pre) && (strings.HasSuffix(pre, "/") || reqPath[len(pre)] == '/')) {
			middlewares = append(middlewares, group.middlewares...)
		}
	}

	ctx := newContext(w, r)
	ctx.handlers = middlewares
	ctx.engine = e

	tailNode, params := e.router.getRoute(ctx.Method, ctx.Path)
	if tailNode != nil {
		// Handlers are stored under "<METHOD>_<route pattern>" by router.addRoute.
		key := r.Method + "_" + tailNode.path
		ctx.Params = params
		ctx.handlers = append(ctx.handlers, e.router.handles[key])
	} else {
		ctx.handlers = append(ctx.handlers, func(ctx *Context) {
			ctx.String(http.StatusNotFound, "404 NOT FOUND: %s\n", ctx.Path)
		})
	}
	// Run middleware and the final handler in order.
	ctx.Next()
}
// SetFuncMap records the template functions to register when LoadHTMLGlob
// parses templates. It must be called before LoadHTMLGlob to take effect.
func (e *Engine) SetFuncMap(funcMap template.FuncMap) {
	fm := funcMap
	e.funcMap = fm
}
// LoadHTMLGlob parses every template file matching pattern, with the function
// map from SetFuncMap, and stores the set for Context.HTML. It panics (via
// template.Must) if parsing fails.
func (e *Engine) LoadHTMLGlob(pattern string) {
	base := template.New("").Funcs(e.funcMap)
	parsed, err := base.ParseGlob(pattern)
	e.htmlTemplates = template.Must(parsed, err)
}
// Get registers handleFunc for GET requests on path. This method takes
// precedence over the embedded root group's Get, so path is registered
// verbatim, without any group prefix.
func (e *Engine) Get(path string, handleFunc HandleFunc) {
	r := e.router
	r.addRoute("GET", path, handleFunc)
}
// Post registers handleFunc for POST requests on path. Like Engine.Get, it
// registers path verbatim, without any group prefix.
func (e *Engine) Post(path string, handleFunc HandleFunc) {
	r := e.router
	r.addRoute("POST", path, handleFunc)
}
// Run starts an HTTP server on addr with the engine as its handler. It blocks
// and returns the error reported by http.ListenAndServe.
func (e *Engine) Run(addr string) (err error) {
	err = http.ListenAndServe(addr, e)
	return err
}
group
package lee
import (
"net/http"
"path"
)
// RGroup is a route group: routes and middleware registered on it share the
// prefix pre. Groups are created with Group and all belong to a single Engine.
type RGroup struct {
pre string // full prefix: the parent's prefix concatenated with this group's own (see Group)
middlewares []HandleFunc // middleware run for requests whose path matches pre (see Engine.ServeHTTP)
parent *RGroup // enclosing group; nil for the root group created in New
engine *Engine // the engine every group is attached to
}
// Group creates a child group whose prefix is g's prefix followed by pre, and
// registers it with the engine so its middleware is considered per request.
// The strings are concatenated as-is, so pre normally starts with "/".
func (g *RGroup) Group(pre string) *RGroup {
	child := &RGroup{parent: g, engine: g.engine}
	child.pre = g.pre + pre
	g.engine.groups = append(g.engine.groups, child)
	return child
}
// addRoute registers handler for method on the group prefix followed by pre
// (the route path relative to the group).
func (g *RGroup) addRoute(method string, pre string, handler HandleFunc) {
	g.engine.router.addRoute(method, g.pre+pre, handler)
}
// Get registers handleFunc for GET requests on path, relative to the group prefix.
func (g *RGroup) Get(path string, handleFunc HandleFunc) {
	const method = "GET"
	g.addRoute(method, path, handleFunc)
}
// Post registers handleFunc for POST requests on path, relative to the group prefix.
func (g *RGroup) Post(path string, handleFunc HandleFunc) {
	const method = "POST"
	g.addRoute(method, path, handleFunc)
}
// Use appends middlewares to the group's middleware list, preserving order.
func (g *RGroup) Use(middlewares ...HandleFunc) {
	if len(middlewares) == 0 {
		return
	}
	g.middlewares = append(g.middlewares, middlewares...)
}
// Static serves files from the local directory root under
// <group prefix>/<relativePath>/*filepath using a GET route.
func (g *RGroup) Static(relativePath, root string) {
	pattern := path.Join(relativePath, "/*filepath")
	g.Get(pattern, g.createStaticHandler(relativePath, http.Dir(root)))
}
// createStaticHandler returns a handler that serves files from fs for requests
// under path.Join(g.pre, relativePath). The "filepath" route parameter names
// the file; if it cannot be opened the handler responds 404.
func (g *RGroup) createStaticHandler(relativePath string, fs http.FileSystem) HandleFunc {
	absolutePath := path.Join(g.pre, relativePath)
	// Strip the URL prefix so the file server resolves paths relative to fs.
	fileServer := http.StripPrefix(absolutePath, http.FileServer(fs))
	return func(c *Context) {
		file := c.Param("filepath")
		// Probe that the file exists and is accessible before delegating.
		f, err := fs.Open(file)
		if err != nil {
			c.Status(http.StatusNotFound)
			return
		}
		// The handle was only a probe; close it right away so each request
		// does not leak a file descriptor. Its Close error is irrelevant here.
		_ = f.Close()
		fileServer.ServeHTTP(c.Writer, c.Req)
	}
}
router
package lee
import "strings"
// router holds one routing trie per HTTP method plus the handler table.
type router struct {
handles map[string]HandleFunc // handlers keyed by method + "_" + route pattern, e.g. "GET_/p/:lang"
roots map[string]*node // trie root per HTTP method, e.g. roots["GET"]
}
// newRouter returns a router with empty trie-root and handler maps.
func newRouter() *router {
	r := new(router)
	r.handles = make(map[string]HandleFunc)
	r.roots = make(map[string]*node)
	return r
}
// parsePath splits a slash-separated path into its non-empty segments. A
// segment starting with '*' is kept and ends parsing, because a catch-all
// wildcard absorbs the rest of the path. The result is never nil (an empty,
// non-nil slice for "" or "/"), which router.addRoute relies on.
func parsePath(path string) []string {
	parts := make([]string, 0)
	for _, s := range strings.Split(path, "/") {
		if s == "" {
			continue // skip empty segments from leading, trailing or doubled slashes
		}
		parts = append(parts, s)
		if s[0] == '*' {
			break
		}
	}
	return parts
}
// addRoute inserts path into the trie for method, creating that trie on first
// use, and stores handler under the key method + "_" + path.
func (r *router) addRoute(method string, path string, handler HandleFunc) {
	root, exists := r.roots[method]
	if !exists {
		root = &node{}
		r.roots[method] = root
	}
	if parts := parsePath(path); parts != nil {
		root.insert(path, parts, 0)
	}
	r.handles[method+"_"+path] = handler
}
// getRoute finds the trie node matching path for method and extracts route
// parameters from it. A ":name" segment maps name to the matching request
// segment; a "*name" segment maps name to the remaining request segments
// joined with "/". It returns (nil, nil) when the method has no routes or
// nothing matches.
func (r *router) getRoute(method string, path string) (*node, map[string]string) {
	root, ok := r.roots[method]
	if !ok {
		return nil, nil
	}
	searchParts := parsePath(path)
	matched := root.search(searchParts, 0)
	if matched == nil {
		return nil, nil
	}
	params := make(map[string]string)
	// Walk the registered pattern (not the request) to learn which segments are dynamic.
	for i, part := range parsePath(matched.path) {
		switch {
		case part[0] == ':':
			params[part[1:]] = searchParts[i]
		case part[0] == '*' && len(part) > 1:
			params[part[1:]] = strings.Join(searchParts[i:], "/")
			return matched, params
		}
	}
	return matched, params
}
context
package lee
import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)
// Context carries the state of a single request as it passes through the
// middleware chain and the route handler.
type Context struct {
Writer http.ResponseWriter
Req *http.Request
StatusCode int // last code passed to Status; 0 if Status was never called
Path string // request URL path, copied in newContext
Method string // request method, copied in newContext
Params map[string]string // route parameters filled by the router, e.g. "lang" for a ":lang" segment
// middleware chain state
handlers []HandleFunc // middleware followed by the route (or 404) handler
index int // index of the handler currently running; starts at -1 (see Next)
engine *Engine // owning engine; HTML reads its loaded templates
}
// newContext builds a Context for req. index starts at -1 so that the first
// call to Next runs handler 0.
func newContext(w http.ResponseWriter, req *http.Request) *Context {
	c := &Context{Writer: w, Req: req, index: -1}
	c.Path, c.Method = req.URL.Path, req.Method
	return c
}
// Next runs the remaining handlers in order, starting with the one after the
// current index. A middleware that calls Next runs the rest of the chain
// before continuing with its own code; one that does not call Next lets this
// loop continue with the next handler once it returns.
func (ctx *Context) Next() {
	total := len(ctx.handlers)
	for ctx.index++; ctx.index < total; ctx.index++ {
		ctx.handlers[ctx.index](ctx)
	}
}
// PostForm returns the first value of form field k via http.Request.FormValue.
// Note that FormValue also consults URL query values, so this is not POST-only.
func (ctx *Context) PostForm(k string) string {
	req := ctx.Req
	return req.FormValue(k)
}
// Query returns the first value of URL query parameter k, or "" if absent.
func (ctx *Context) Query(k string) string {
	values := ctx.Req.URL.Query()
	return values.Get(k)
}
// Param returns route parameter k captured by the router (e.g. "lang" for a
// ":lang" segment), or "" if it is not present.
func (ctx *Context) Param(k string) string {
	if value, ok := ctx.Params[k]; ok {
		return value
	}
	return ""
}
// Status records code on the context and writes it as the response status.
func (ctx *Context) Status(code int) {
	ctx.StatusCode = code
	w := ctx.Writer
	w.WriteHeader(code)
}
// SetHeader sets response header k to v, replacing any existing values. It
// has no effect on the wire once the status has been written.
func (ctx *Context) SetHeader(k, v string) {
	header := ctx.Writer.Header()
	header.Set(k, v)
}
// JSON marshals body and writes it with status code and Content-Type
// application/json. Marshalling happens first, so if it fails a clean 500
// response is sent instead of an error appended to a committed response.
func (ctx *Context) JSON(code int, body interface{}) {
	payload, err := json.Marshal(body)
	if err != nil {
		ctx.StatusCode = http.StatusInternalServerError
		http.Error(ctx.Writer, err.Error(), http.StatusInternalServerError)
		return
	}
	ctx.SetHeader("Content-Type", "application/json")
	ctx.Status(code)
	// The status line is already sent; a write failure (e.g. client gone)
	// cannot be reported in this response, so it is deliberately ignored.
	_, _ = ctx.Writer.Write(payload)
}
// String writes a text/plain response with status code whose body is
// fmt-formatted from format and v.
func (ctx *Context) String(code int, format string, v ...interface{}) {
	ctx.SetHeader("Content-Type", "text/plain")
	ctx.Status(code)
	// Once the status is sent, a write failure can no longer be reported to the
	// client (calling http.Error here would only trigger a superfluous
	// WriteHeader), so the error is deliberately ignored.
	_, _ = fmt.Fprintf(ctx.Writer, format, v...)
}
// Fail aborts the handler chain and responds with status code and the JSON
// body {"message": msg}. Moving index past the last handler makes every
// enclosing Next loop stop, so no later handler runs after a failure.
func (ctx *Context) Fail(code int, msg string) {
	ctx.index = len(ctx.handlers)
	// JSON writes the status itself; calling Status here as well would issue
	// a second, superfluous WriteHeader.
	ctx.JSON(code, H{"message": msg})
}
func (ctx *Context) HTML(code int, name string, data interface{}) {
ctx.SetHeader("Content-Type", "text/html")
ctx.Status(code)
if err := ctx.engine.htmlTemplates.ExecuteTemplate(ctx.Writer, name, data); err != nil {
ctx.Fail(500, err.Error())
}
}
trie
package lee
import "strings"
// node is one segment of a per-method routing trie.
type node struct {
	path     string  // full route pattern; set only on nodes that terminate a registered route
	part     string  // the route segment held by this node, e.g. "user", ":id" or "*filepath"
	children []*node // child segments
	isWild   bool    // true when part starts with ':' or '*' (a dynamic segment)
}

// matchChild returns the child whose segment is exactly part, or nil. It is
// used while inserting routes: a literal segment must not merge into an
// existing wildcard child, because that would overwrite the wildcard route's
// stored path and hijack it (e.g. "/hello/b" taking over "/hello/:name").
func (n *node) matchChild(part string) *node {
	for _, child := range n.children {
		if child.part == part {
			return child
		}
	}
	return nil
}

// insert adds the route path, pre-split into parts, below n. height is the
// index of the segment handled at this level; the node reached after the last
// segment stores the full pattern in path.
func (n *node) insert(path string, parts []string, height int) {
	if len(parts) == height {
		// Only the terminal node records the pattern, because intermediate
		// segments alone cannot tell which parts of a request were dynamic.
		n.path = path
		return
	}
	part := parts[height]
	child := n.matchChild(part)
	if child == nil {
		child = &node{part: part, isWild: part[0] == ':' || part[0] == '*'}
		n.children = append(n.children, child)
	}
	child.insert(path, parts, height+1)
}

// search returns the terminal node matching the request segments parts, or
// nil. Matching stops at a '*' node, which absorbs all remaining segments. A
// node without a stored path is not a registered route, so it does not match.
func (n *node) search(parts []string, height int) *node {
	if len(parts) == height || strings.HasPrefix(n.part, "*") {
		if n.path == "" {
			return nil
		}
		return n
	}
	part := parts[height]
	// Try candidates in priority order and backtrack when a branch dead-ends.
	for _, child := range n.matchChildren(part) {
		if result := child.search(parts, height+1); result != nil {
			return result
		}
	}
	return nil
}

// matchChildren returns every child that can match part: exact matches first,
// then dynamic (wildcard) children, so a static route wins over a parameter.
func (n *node) matchChildren(part string) []*node {
	nodes := make([]*node, 0, len(n.children))
	var wild []*node
	for _, child := range n.children {
		switch {
		case child.part == part:
			nodes = append(nodes, child)
		case child.isWild:
			wild = append(wild, child)
		}
	}
	return append(nodes, wild...)
}
logger
package lee
import (
"log"
"time"
)
// Logger returns middleware that runs the rest of the chain and then logs the
// recorded status code, the request URI and how long the chain took.
func Logger() HandleFunc {
	return func(c *Context) {
		start := time.Now()
		c.Next()
		elapsed := time.Since(start)
		log.Printf("############# logger [%d] %s in %v", c.StatusCode, c.Req.RequestURI, elapsed)
	}
}
recovery
package lee
import (
"fmt"
"log"
"net/http"
"runtime"
"strings"
)
// Recovery returns middleware that recovers from a panic raised later in the
// chain. It logs the panic value with a stack trace, responds 500 with
// {"message": "Internal Server Error"}, and stops the remaining handlers.
func Recovery() HandleFunc {
	return func(c *Context) {
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}
			// %v renders any panic value (string, error, int, ...) readably;
			// %s would print "%!s(...)" for non-string values.
			message := fmt.Sprintf("%v", rec)
			log.Printf("%s\n\n", trace(message))
			c.Fail(http.StatusInternalServerError, "Internal Server Error")
			// Without this, the outer Next loop would resume with the handler
			// after the one that panicked.
			c.index = len(c.handlers)
		}()
		c.Next()
	}
}
// trace returns message followed by "\nTraceback:" and one "\n\t<file>:<line>"
// entry per stack frame, for debugging. runtime.Callers skips 3 frames
// (runtime.Callers itself, trace, and trace's caller), so the traceback starts
// above the function that called trace. At most 32 frames are captured.
func trace(message string) string {
	var pcs [32]uintptr
	n := runtime.Callers(3, pcs[:])
	var str strings.Builder
	str.WriteString(message + "\nTraceback:")
	// CallersFrames converts return PCs into call-site file/line pairs and
	// expands inlined calls; FuncForPC on raw PCs can do neither.
	frames := runtime.CallersFrames(pcs[:n])
	for more := n > 0; more; {
		var frame runtime.Frame
		frame, more = frames.Next()
		fmt.Fprintf(&str, "\n\t%s:%d", frame.File, frame.Line)
	}
	return str.String()
}
