diff --git a/main_test.go b/main_test.go index 7e721a29..c907ec14 100644 --- a/main_test.go +++ b/main_test.go @@ -896,6 +896,64 @@ func TestServe(t *testing.T) { }, startHTTP, }, + { + "http allow CORS false", + "testdata/http.yml", + func(t *testing.T) { + // TODO: rework this test because it doesn't fails when it should + // cf the discussion in https://github.com/ContentSquare/chproxy/pull/489 + q := "cors" + req, err := http.NewRequest("GET", "http://127.0.0.1:9090?query="+url.QueryEscape(q), nil) + checkErr(t, err) + resp, err := http.DefaultClient.Do(req) + checkErr(t, err) + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d; expected: %d", resp.StatusCode, http.StatusOK) + } + defer resp.Body.Close() + checkHeader(t, resp, "Access-Control-Allow-Origin", "*") + }, + startHTTP, + }, + { + "http allow CORS true without request Origin header", + "testdata/http.allow.cors.yml", + func(t *testing.T) { + // TODO: rework this test because it doesn't fails when it should + // cf the discussion in https://github.com/ContentSquare/chproxy/pull/489 + q := "cors" + req, err := http.NewRequest("GET", "http://127.0.0.1:9090?query="+url.QueryEscape(q), nil) + checkErr(t, err) + resp, err := http.DefaultClient.Do(req) + checkErr(t, err) + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d; expected: %d", resp.StatusCode, http.StatusOK) + } + defer resp.Body.Close() + checkHeader(t, resp, "Access-Control-Allow-Origin", "*") + }, + startHTTP, + }, + { + "http allow CORS true with request Origin header", + "testdata/http.allow.cors.yml", + func(t *testing.T) { + // TODO: rework this test because it doesn't fails when it should + // cf the discussion in https://github.com/ContentSquare/chproxy/pull/489 + q := "cors" + req, err := http.NewRequest("GET", "http://127.0.0.1:9090?query="+url.QueryEscape(q), nil) + checkErr(t, err) + req.Header.Set("Origin", "http://example.com") + resp, err := http.DefaultClient.Do(req) + checkErr(t, err) + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d; expected: %d", resp.StatusCode, http.StatusOK) + } + defer resp.Body.Close() + checkHeader(t, resp, "Access-Control-Allow-Origin", "http://example.com") + }, + startHTTP, + }, } // Wait until CHServer starts. @@ -1108,6 +1166,8 @@ func fakeCHHandler(w http.ResponseWriter, r *http.Request) { // execute sleep 1.5 sec time.Sleep(1500 * time.Millisecond) fmt.Fprint(w, b) + case q == "cors": + w.Header().Set("Access-Control-Allow-Origin", "*") default: if strings.Contains(string(query), killQueryPattern) { fakeCHState.kill() diff --git a/proxy.go b/proxy.go index 15ca73ea..c02dbd8a 100644 --- a/proxy.go +++ b/proxy.go @@ -107,14 +107,6 @@ func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { log.Debugf("%s: request start", s) requestSum.With(s.labels).Inc() - if s.user.allowCORS { - origin := req.Header.Get("Origin") - if len(origin) == 0 { - origin = "*" - } - rw.Header().Set("Access-Control-Allow-Origin", origin) - } - req.Body = &statReadCloser{ ReadCloser: req.Body, bytesRead: requestBodyBytes.With(s.labels), @@ -149,6 +141,14 @@ func (rp *reverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { rp.proxyRequest(s, srw, srw, req) } + if s.user.allowCORS { + origin := req.Header.Get("Origin") + if len(origin) == 0 { + origin = "*" + } + rw.Header().Set("Access-Control-Allow-Origin", origin) + } + // It is safe calling getQuerySnippet here, since the request // has been already read in proxyRequest or serveFromCache. query := getQuerySnippet(req) diff --git a/testdata/http.allow.cors.yml b/testdata/http.allow.cors.yml new file mode 100644 index 00000000..db445b56 --- /dev/null +++ b/testdata/http.allow.cors.yml @@ -0,0 +1,15 @@ +log_debug: true +server: + http: + listen_addr: ":9090" + allowed_networks: ["127.0.0.1/24"] + +users: + - name: "default" + to_cluster: "default" + to_user: "default" + allow_cors: true + +clusters: + - name: "default" + nodes: ["127.0.0.1:18124"] \ No newline at end of file