Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,8 @@ jobs:
- name: Checkout code
uses: actions/checkout@v2

- name: Test
run: go test ./...

- name: Test (race)
run: go test -race ./...
1 change: 0 additions & 1 deletion compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ const (
encodingHeader = "Content-Encoding"
lenHeader = "Content-Length"

brEnc = "br"
gzEnc = "gzip"
)

Expand Down
3 changes: 3 additions & 0 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,12 @@ func (ghc *groupHandlerChain) Serve(rw http.ResponseWriter, req *http.Request, p
}
}
}

ctx.nextMW = func() {
if catchPanic != nil {
defer catchPanic()
}

for mwIdx < len(ghc.g.mw) && !ctx.done {
h := ghc.g.mw[mwIdx]
mwIdx++
Expand All @@ -153,6 +155,7 @@ func (ghc *groupHandlerChain) Serve(rw http.ResponseWriter, req *http.Request, p
break
}
}

ctx.nextMW = nil
}

Expand Down
5 changes: 2 additions & 3 deletions router/router_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package router

import (
"encoding/json"
"net/http"
"strings"
"testing"
Expand All @@ -19,8 +18,8 @@ func TestRouter(t *testing.T) {
req, _ = http.NewRequest("PATCH", "../"+ep, nil)
r.ServeHTTP(nil, req)
}
j, _ := json.MarshalIndent(r.swagger, "", " ")
t.Log("\n" + string(j))
// j, _ := json.MarshalIndent(r.swagger, "", " ")
// t.Log("\n" + string(j))
}

func TestRouterStar(t *testing.T) {
Expand Down
51 changes: 27 additions & 24 deletions router/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,53 +4,56 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"strings"
)

type nodePart string

func (np nodePart) Name() string { return string(np[1:]) }
func (np nodePart) Type() uint8 { return np[0] }
func (np nodePart) Name() string {
return string(np[1:])
}

func (np nodePart) Type() byte {
return np[0]
}
func (np nodePart) String() string {
if np.Type() == '/' {
return fmt.Sprintf("{%s}", np.Name())
}
return fmt.Sprintf("{%s '%c'}", np.Name(), np.Type())
}

var re = regexp.MustCompile(`([:*/]?[^:*]+)`)

// splitPathToParts takes in a path (ex: /api/v1/someEndpoint/:id/*any) and returns:
//
// pp -> the longest part before the first param (/api/v1/someEndpoint/:)
// rest -> all the params (id, any)
// num -> number of params (probably not needed...)
// stars -> number of stars, basically a sanity check, if it's not 0 or 1 then it's an invalid path
func splitPathToParts(p string) (pp string, rest []nodePart, num, stars int) {
parts := re.FindAllString(p, -1)
if len(parts) < 2 {
idx := strings.IndexAny(p, ":*")
if idx == -1 {
pp = p
return pp, rest, num, stars
return
}
pp = p[:idx]
for part := range strings.SplitSeq(p[idx:], "/") {
if len(part) == 0 {
continue
}
switch part[0] {
case '*':
stars++
fallthrough
case ':':
num++
rest = append(rest, nodePart(part))
default:
rest = append(rest, nodePart("/"+part))

pp = parts[0]
for _, part := range parts[1:] {
splitPathFn(part, '/', func(sp string, _, _ int) bool {
switch c := sp[0]; c {
case '*':
stars++
fallthrough
case ':':
num++
fallthrough
case '/':
rest = append(rest, nodePart(sp))
}
return false
})
}
}
return pp, rest, num, stars

return
}

func splitPathFn(s string, sep uint8, fn func(p string, pidx, idx int) bool) bool {
Expand Down
3 changes: 0 additions & 3 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
Expand All @@ -30,8 +29,6 @@ var DefaultPanicHandler = func(ctx *Context, v any, fr *oerrs.Frame) {
_ = ctx.Encode(500, resp)
}

var noopLogger = log.New(io.Discard, "", 0)

// DefaultOpts are the default options used for creating new servers.
var DefaultOpts = Options{
WriteTimeout: time.Minute,
Expand Down
25 changes: 20 additions & 5 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ func cmpData(a, b any) bool {
return av == bv
}

var pong = "pong"

func panicTyped(ctx *Context) (any, error) {
panic("well... poo")
}
Expand Down Expand Up @@ -124,7 +122,9 @@ func TestServer(t *testing.T) {
})

JSONGet(srv, "/panic2", panicTyped, true)
srv.AllowCORS("/cors", "GET")
srv.Use(AllowCORS(nil, nil, []string{"example.com"}))

srv.POST("/cors", func(ctx *Context) Response { return nil })

JSONGet(srv, "/ping/:id", func(ctx *Context) (string, error) {
return "pong:" + ctx.Params.Get("id"), nil
Expand Down Expand Up @@ -308,16 +308,31 @@ func TestServer(t *testing.T) {
t.Run("CORS", func(t *testing.T) {
var (
client http.Client
req, _ = http.NewRequest(http.MethodOptions, ts.URL+"/cors", nil)
req, _ = http.NewRequest(http.MethodPost, ts.URL+"/cors", nil)
)
req.Header.Add("Origin", "http://localhost")
req.Header.Add("Origin", "https://example.com")
req.Header.Add("Access-Control-Request-Method", "GET")
resp, _ := client.Do(req)
resp.Body.Close()
if resp.Header.Get("Access-Control-Allow-Methods") != "GET" {
t.Errorf("unexpected headers: %+v", resp.Header)
}
})

t.Run("CORS_BAD_ORIGIN", func(t *testing.T) {
var (
client http.Client
req, _ = http.NewRequest(http.MethodPost, ts.URL+"/cors", nil)
)
req.Header.Add("Origin", "https://badexample.com")
req.Header.Add("Access-Control-Request-Method", "GET")
resp, _ := client.Do(req)
resp.Body.Close()
if resp.Header.Get("Access-Control-Allow-Methods") == "GET" {
t.Errorf("unexpected headers: %+v", resp.Header)
}
})

t.Run("POST", func(t *testing.T) {
resp, err := http.Post(ts.URL+"/ping/hello", MimeJSON, strings.NewReader(`{"ping": "world"}`))
if err != nil {
Expand Down
19 changes: 14 additions & 5 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ func matchStarOrigin(set otk.Set, keys []string, origin string) bool {
if !found {
continue
}
if strings.HasSuffix(origin, orig) {

if origin == orig {
return true
}

Expand All @@ -130,7 +131,7 @@ func matchStarOrigin(set otk.Set, keys []string, origin string) bool {
// If headers is empty, it will respond with the requested headers.
// If origins is empty, it will respond with the requested origin.
// will automatically install an OPTIONS handler to each passed group.
func AllowCORS(methods, headers, origins []string, groups ...GroupType) Handler {
func AllowCORS(methods, headers, origins []string, groups ...*Group) Handler {
ms := strings.Join(methods, ", ")
hs := strings.Join(headers, ", ")

Expand All @@ -155,25 +156,33 @@ func AllowCORS(methods, headers, origins []string, groups ...GroupType) Handler
return
}

if len(ms) > 0 {
if len(methods) > 0 {
wh.Set("Access-Control-Allow-Methods", ms)
} else if rm := rh.Get("Access-Control-Request-Method"); rm != "" {
wh.Set("Access-Control-Allow-Methods", rm)
} else {
wh.Set("Access-Control-Allow-Methods", "*")
}

if len(hs) > 0 {
if len(headers) > 0 {
wh.Set("Access-Control-Allow-Headers", hs)
} else if rh := rh.Get("Access-Control-Request-Headers"); rh != "" {
wh.Set("Access-Control-Allow-Headers", rh)
} else {
wh.Set("Access-Control-Allow-Headers", "*")
}

wh.Set("Access-Control-Max-Age", "86400") // 24 hours

if ctx.Req.Method == http.MethodOptions {
return RespEmpty
}

return
}

for _, g := range groups {
g.AddRoute("OPTIONS", "/*x", fn)
g.Use(fn)
}

return fn
Expand Down
5 changes: 5 additions & 0 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,9 @@ func TestMatchStarOrigin(t *testing.T) {
if matchStarOrigin(nil, []string{"example.com"}, "1034.example.com") {
t.Fatal("shouldn't match 1034.example.com")
}

if matchStarOrigin(nil, []string{"example.com"}, "1034.evilexample.com") {
t.Fatal("shouldn't match 1034.evilexample.com")
}

}