diff --git a/middleware_test.go b/middleware_test.go index 5d680bb..7f57d7f 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -107,29 +107,44 @@ func TestMiddleware_Recoverer(t *testing.T) { func TestWrap(t *testing.T) { handler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { t.Logf("%s", r.URL.String()) - assert.Equal(t, "/something/1/2", r.URL.Path) }) mw1 := func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.URL.Path += "/1" + w.Header().Set("X-MW1", "1") h.ServeHTTP(w, r) }) } mw2 := func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.URL.Path += "/2" + w.Header().Set("X-MW2", "2") h.ServeHTTP(w, r) }) } - h := Wrap(handler, mw1, mw2) - ts := httptest.NewServer(h) - defer ts.Close() + t.Run("no middleware", func(t *testing.T) { + h := Wrap(handler) + ts := httptest.NewServer(h) + defer ts.Close() - resp, err := http.Get(ts.URL + "/something") - require.NoError(t, err) - assert.Equal(t, 200, resp.StatusCode) + resp, err := http.Get(ts.URL + "/something") + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, "", resp.Header.Get("X-MW1")) + assert.Equal(t, "", resp.Header.Get("X-MW2")) + }) + + t.Run("with middleware", func(t *testing.T) { + h := Wrap(handler, mw1, mw2) + ts := httptest.NewServer(h) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/something") + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, "1", resp.Header.Get("X-MW1")) + assert.Equal(t, "2", resp.Header.Get("X-MW2")) + }) } func TestHeaders(t *testing.T) {