From 9780c0d0fa77ad0f9780dc2121d2220846fd1ecf Mon Sep 17 00:00:00 2001 From: OneOfOne Date: Mon, 13 Oct 2025 16:13:11 -0500 Subject: [PATCH 1/3] chore(cleanup): remove the use of regexp --- .github/workflows/build.yml | 3 +++ compression.go | 1 - router/router_test.go | 5 ++-- router/utils.go | 51 ++++++++++++++++++++----------------- server.go | 3 --- server_test.go | 2 -- utils.go | 1 + 7 files changed, 33 insertions(+), 33 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 682f2fb..459341e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -23,5 +23,8 @@ jobs: - name: Checkout code uses: actions/checkout@v2 + - name: Test + run: go test ./... + - name: Test (race) run: go test -race ./... diff --git a/compression.go b/compression.go index a393657..4cebd0f 100644 --- a/compression.go +++ b/compression.go @@ -13,7 +13,6 @@ const ( encodingHeader = "Content-Encoding" lenHeader = "Content-Length" - brEnc = "br" gzEnc = "gzip" ) diff --git a/router/router_test.go b/router/router_test.go index a446cbd..2386c6d 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -1,7 +1,6 @@ package router import ( - "encoding/json" "net/http" "strings" "testing" @@ -19,8 +18,8 @@ func TestRouter(t *testing.T) { req, _ = http.NewRequest("PATCH", "../"+ep, nil) r.ServeHTTP(nil, req) } - j, _ := json.MarshalIndent(r.swagger, "", " ") - t.Log("\n" + string(j)) + // j, _ := json.MarshalIndent(r.swagger, "", " ") + // t.Log("\n" + string(j)) } func TestRouterStar(t *testing.T) { diff --git a/router/utils.go b/router/utils.go index 0359e1a..89f3877 100644 --- a/router/utils.go +++ b/router/utils.go @@ -4,14 +4,18 @@ import ( "fmt" "io" "net/http" - "regexp" "strings" ) type nodePart string -func (np nodePart) Name() string { return string(np[1:]) } -func (np nodePart) Type() uint8 { return np[0] } +func (np nodePart) Name() string { + return string(np[1:]) +} + +func (np nodePart) Type() byte { + return np[0] +} func (np nodePart) String() string { if np.Type() == '/' { return fmt.Sprintf("{%s}", np.Name()) @@ -19,8 +23,6 @@ func (np nodePart) String() string { return fmt.Sprintf("{%s '%c'}", np.Name(), np.Type()) } -var re = regexp.MustCompile(`([:*/]?[^:*]+)`) - // splitPathToParts takes in a path (ex: /api/v1/someEndpoint/:id/*any) and returns: // // pp -> the longest part before the first param (/api/v1/someEndpoint/:) @@ -28,29 +30,30 @@ var re = regexp.MustCompile(`([:*/]?[^:*]+)`) // num -> number of params (probably not needed...) // stars -> number of stars, basically a sanity check, if it's not 0 or 1 then it's an invalid path func splitPathToParts(p string) (pp string, rest []nodePart, num, stars int) { - parts := re.FindAllString(p, -1) - if len(parts) < 2 { + idx := strings.IndexAny(p, ":*") + if idx == -1 { pp = p - return pp, rest, num, stars + return } + pp = p[:idx] + for part := range strings.SplitSeq(p[idx:], "/") { + if len(part) == 0 { + continue + } + switch part[0] { + case '*': + stars++ + fallthrough + case ':': + num++ + rest = append(rest, nodePart(part)) + default: + rest = append(rest, nodePart("/"+part)) - pp = parts[0] - for _, part := range parts[1:] { - splitPathFn(part, '/', func(sp string, _, _ int) bool { - switch c := sp[0]; c { - case '*': - stars++ - fallthrough - case ':': - num++ - fallthrough - case '/': - rest = append(rest, nodePart(sp)) - } - return false - }) + } } - return pp, rest, num, stars + + return } func splitPathFn(s string, sep uint8, fn func(p string, pidx, idx int) bool) bool { diff --git a/server.go b/server.go index 9774e06..af3d700 100644 --- a/server.go +++ b/server.go @@ -5,7 +5,6 @@ import ( "context" "errors" "fmt" - "io" "log" "net" "net/http" @@ -30,8 +29,6 @@ var DefaultPanicHandler = func(ctx *Context, v any, fr *oerrs.Frame) { _ = ctx.Encode(500, resp) } -var noopLogger = log.New(io.Discard, "", 0) - // DefaultOpts are the default options used for creating new servers. var DefaultOpts = Options{ WriteTimeout: time.Minute, diff --git a/server_test.go b/server_test.go index dc4d2b4..65b3a8c 100644 --- a/server_test.go +++ b/server_test.go @@ -94,8 +94,6 @@ func cmpData(a, b any) bool { return av == bv } -var pong = "pong" - func panicTyped(ctx *Context) (any, error) { panic("well... poo") } diff --git a/utils.go b/utils.go index 477a0f8..905f5aa 100644 --- a/utils.go +++ b/utils.go @@ -117,6 +117,7 @@ func matchStarOrigin(set otk.Set, keys []string, origin string) bool { if !found { continue } + if strings.HasSuffix(origin, orig) { return true } From 08b3de39a66e6031c3b8ffd613e4997f62b9ce60 Mon Sep 17 00:00:00 2001 From: OneOfOne Date: Tue, 14 Oct 2025 14:03:01 -0500 Subject: [PATCH 2/3] fix(cors): make sure to match exact domains --- utils.go | 2 +- utils_test.go | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/utils.go b/utils.go index 905f5aa..bf85eca 100644 --- a/utils.go +++ b/utils.go @@ -118,7 +118,7 @@ func matchStarOrigin(set otk.Set, keys []string, origin string) bool { continue } - if strings.HasSuffix(origin, orig) { + if origin == orig { return true } diff --git a/utils_test.go b/utils_test.go index c8fdaa6..7de0281 100644 --- a/utils_test.go +++ b/utils_test.go @@ -11,4 +11,9 @@ func TestMatchStarOrigin(t *testing.T) { if matchStarOrigin(nil, []string{"example.com"}, "1034.example.com") { t.Fatal("shouldn't match 1034.example.com") } + + if matchStarOrigin(nil, []string{"example.com"}, "1034.evilexample.com") { + t.Fatal("shouldn't match 1034.evilexample.com") + } + } From 1d98c28a3ef58a99d2222f2287826628d09431a5 Mon Sep 17 00:00:00 2001 From: OneOfOne Date: Tue, 14 Oct 2025 17:41:56 -0500 Subject: [PATCH 3/3] fix(cors): return headers for all requests if applicable --- group.go | 3 +++ server_test.go | 23 ++++++++++++++++++++--- utils.go | 16 ++++++++++++---- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/group.go b/group.go index 752a247..31ee9b4 100644 --- a/group.go +++ b/group.go @@ -137,10 +137,12 @@ func (ghc *groupHandlerChain) Serve(rw http.ResponseWriter, req *http.Request, p } } } + ctx.nextMW = func() { if catchPanic != nil { defer catchPanic() } + for mwIdx < len(ghc.g.mw) && !ctx.done { h := ghc.g.mw[mwIdx] mwIdx++ @@ -153,6 +155,7 @@ func (ghc *groupHandlerChain) Serve(rw http.ResponseWriter, req *http.Request, p break } } + ctx.nextMW = nil } diff --git a/server_test.go b/server_test.go index 65b3a8c..1c16e74 100644 --- a/server_test.go +++ b/server_test.go @@ -122,7 +122,9 @@ func TestServer(t *testing.T) { }) JSONGet(srv, "/panic2", panicTyped, true) - srv.AllowCORS("/cors", "GET") + srv.Use(AllowCORS(nil, nil, []string{"example.com"})) + + srv.POST("/cors", func(ctx *Context) Response { return nil }) JSONGet(srv, "/ping/:id", func(ctx *Context) (string, error) { return "pong:" + ctx.Params.Get("id"), nil @@ -306,9 +308,10 @@ func TestServer(t *testing.T) { t.Run("CORS", func(t *testing.T) { var ( client http.Client - req, _ = http.NewRequest(http.MethodOptions, ts.URL+"/cors", nil) + req, _ = http.NewRequest(http.MethodPost, ts.URL+"/cors", nil) ) - req.Header.Add("Origin", "http://localhost") + req.Header.Add("Origin", "https://example.com") + req.Header.Add("Access-Control-Request-Method", "GET") resp, _ := client.Do(req) resp.Body.Close() if resp.Header.Get("Access-Control-Allow-Methods") != "GET" { @@ -316,6 +319,20 @@ func TestServer(t *testing.T) { } }) + t.Run("CORS_BAD_ORIGIN", func(t *testing.T) { + var ( + client http.Client + req, _ = http.NewRequest(http.MethodPost, ts.URL+"/cors", nil) + ) + req.Header.Add("Origin", "https://badexample.com") + req.Header.Add("Access-Control-Request-Method", "GET") + resp, _ := client.Do(req) + resp.Body.Close() + if resp.Header.Get("Access-Control-Allow-Methods") == "GET" { + t.Errorf("unexpected headers: %+v", resp.Header) + } + }) + t.Run("POST", func(t *testing.T) { resp, err := http.Post(ts.URL+"/ping/hello", MimeJSON, strings.NewReader(`{"ping": "world"}`)) if err != nil { diff --git a/utils.go b/utils.go index bf85eca..d6f81a9 100644 --- a/utils.go +++ b/utils.go @@ -131,7 +131,7 @@ func matchStarOrigin(set otk.Set, keys []string, origin string) bool { // If headers is empty, it will respond with the requested headers. // If origins is empty, it will respond with the requested origin. // will automatically install an OPTIONS handler to each passed group. -func AllowCORS(methods, headers, origins []string, groups ...GroupType) Handler { +func AllowCORS(methods, headers, origins []string, groups ...*Group) Handler { ms := strings.Join(methods, ", ") hs := strings.Join(headers, ", ") @@ -156,25 +156,33 @@ func AllowCORS(methods, headers, origins []string, groups ...GroupType) Handler return } - if len(ms) > 0 { + if len(methods) > 0 { wh.Set("Access-Control-Allow-Methods", ms) } else if rm := rh.Get("Access-Control-Request-Method"); rm != "" { wh.Set("Access-Control-Allow-Methods", rm) + } else { + wh.Set("Access-Control-Allow-Methods", "*") } - if len(hs) > 0 { + if len(headers) > 0 { wh.Set("Access-Control-Allow-Headers", hs) } else if rh := rh.Get("Access-Control-Request-Headers"); rh != "" { wh.Set("Access-Control-Allow-Headers", rh) + } else { + wh.Set("Access-Control-Allow-Headers", "*") } wh.Set("Access-Control-Max-Age", "86400") // 24 hours + if ctx.Req.Method == http.MethodOptions { + return RespEmpty + } + return } for _, g := range groups { - g.AddRoute("OPTIONS", "/*x", fn) + g.Use(fn) } return fn