diff --git a/middleware/helmet_test.go b/middleware/helmet_test.go index 4362074..809b535 100644 --- a/middleware/helmet_test.go +++ b/middleware/helmet_test.go @@ -1,6 +1,7 @@ package middleware import ( + "io" "net/http" "net/http/httptest" "testing" @@ -10,41 +11,42 @@ import ( func TestHelmet(t *testing.T) { r := mux.NewRouter() + r.Use(Helmet(HelmetOption{})) r.Get("/hello", func(writer http.ResponseWriter, request *http.Request) { _, _ = writer.Write([]byte("hello there")) }) - endpoint := httptest.NewRequest(http.MethodGet, "/hello", nil) + srv := httptest.NewServer(r) + defer srv.Close() + w, _ := testRequest(t, srv, "GET", "/hello", nil) - // test endpoint registered/reachable - w := httptest.NewRecorder() - r.ServeHTTP(w, endpoint) - if w.Code != http.StatusOK { - t.Error("not expecting status", w.Code) - return - } - - // no header test - w = httptest.NewRecorder() - r.ServeHTTP(w, endpoint) - csp := w.Header().Get("Content-Security-Policy") + csp := w.Header.Get("Content-Security-Policy") // must not have a csp header, technically no header related to helmet but lets test with one. - if csp != "" { - t.Error("csp header not expected") - } - - // introduce helmet middleware - r.Use(Helmet(HelmetOption{})) - - // header tests.. - w = httptest.NewRecorder() - r.ServeHTTP(w, endpoint) - // csp and other headers are expected - csp = w.Header().Get("Content-Security-Policy") - // fmt.Printf("csp %s", csp) if csp == "" { - t.Error("csp header missing") + t.Error("csp header is expected") } - // TODO need more tests +} + +func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) { + req, err := http.NewRequest(method, ts.URL+path, body) + if err != nil { + t.Fatal(err) + return nil, "" + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + return nil, "" + } + + respBody, err := io.ReadAll(io.Reader(resp.Body)) + if err != nil { + t.Fatal(err) + return nil, "" + } + defer resp.Body.Close() + + return resp, string(respBody) } diff --git a/router.go b/router.go index c2ca975..ad13ced 100644 --- a/router.go +++ b/router.go @@ -141,6 +141,15 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { panic("mux: method ServeHTTP called on nil") } + r.mux.ServeHTTP(w, req) +} + +// Proxy are request and +func (r *Router) Proxy(w http.ResponseWriter, req *http.Request) { + if r == nil { + panic("mux: method ServeHTTP called on nil") + } + h, pattern := r.mux.Handler(req) if pattern == "" { http.Error(w, "Not Found", http.StatusNotFound)