-
Notifications
You must be signed in to change notification settings - Fork 0
/
forcessl_test.go
49 lines (41 loc) · 1.04 KB
/
forcessl_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
package forcesslheroku
import (
"net/http"
"net/http/httptest"
"testing"
)
// testCases is the table that drives TestForceSsl.
//
// Each entry gives:
//   - goEnv: the value the stubbed getenv returns for the goEnviron key.
//   - proto: the value sent in the request header named by
//     xForwardedProtoHeader.
//   - expectLoc: the exact Location header the response must carry.
//
// An empty expectLoc means the response must carry no Location header.
var testCases = []struct {
	goEnv     string
	proto     string
	expectLoc string
}{
	{
		goEnv:     "production",
		proto:     "http",
		expectLoc: "https://example.com/test",
	},
	{
		goEnv:     "production",
		proto:     "https",
		expectLoc: "",
	},
	{
		goEnv:     "test",
		proto:     "http",
		expectLoc: "",
	},
	{
		goEnv:     "test",
		proto:     "https",
		expectLoc: "",
	},
}
func TestForceSsl(t *testing.T) {
noopHandler := func(w http.ResponseWriter, r *http.Request) {}
forceSsl := ForceSsl(http.HandlerFunc(noopHandler))
for _, tt := range testCases {
getenv = func(key string) string {
switch key {
case goEnviron:
return tt.goEnv
default:
return ""
}
}
t.Run(tt.goEnv+"_"+tt.proto, func(t *testing.T) {
req := httptest.NewRequest("", "/test", nil)
req.Header.Add(xForwardedProtoHeader, tt.proto)
res := httptest.NewRecorder()
forceSsl.ServeHTTP(res, req)
if location := res.Header().Get("Location"); location != tt.expectLoc {
t.Errorf("expected Location header '%s', got '%s'",
tt.expectLoc, location)
}
})
}
}