forked from castle/castle-go
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add context transformation logic
- Loading branch information
Showing
4 changed files
with
318 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
package context | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"strings" | ||
|
||
"github.com/utilitywarehouse/castle-go" | ||
http_internal "github.com/utilitywarehouse/castle-go/http" | ||
) | ||
|
||
type contextKey string | ||
|
||
func (c contextKey) String() string { | ||
return "castle context key " + string(c) | ||
} | ||
|
||
var castleCtxKey = contextKey("castle_context") | ||
|
||
// ToCtxFromRequest adds the token and other request information (i.e. castle context) to the context. | ||
func ToCtxFromRequest(ctx context.Context, r *http.Request) context.Context { | ||
castleCtx := castle.Context{ | ||
RequestToken: func() string { | ||
// grab the token from header if it exists | ||
if tkn := tokenFromHTTPHeader(r.Header); tkn != "" { | ||
return tkn | ||
} | ||
|
||
// otherwise, try grabbing it from form | ||
return tokenFromHTTPForm(r) | ||
}(), | ||
IP: http_internal.IPFromRequest(r), | ||
Headers: FilterHeaders(r.Header), // pass in as much context as possible | ||
} | ||
return context.WithValue(ctx, castleCtxKey, castleCtx) | ||
} | ||
|
||
func FromCtx(ctx context.Context) *castle.Context { | ||
castleCtx, ok := ctx.Value(castleCtxKey).(castle.Context) | ||
if ok { | ||
return &castleCtx | ||
} | ||
return nil | ||
} | ||
|
||
func tokenFromHTTPHeader(header http.Header) string { | ||
// recommended header name | ||
if t := header.Get("X-Castle-Request-Token"); t != "" { | ||
return t | ||
} | ||
// header name used in the frontends | ||
if t := header.Get("Castle-Token"); t != "" { | ||
return t | ||
} | ||
return "" | ||
} | ||
|
||
func tokenFromHTTPForm(r *http.Request) string { | ||
// ParseForm is idempotent, so it's safe to call from anywhere | ||
if err := r.ParseForm(); err != nil { | ||
return "" | ||
} | ||
|
||
return r.Form.Get("castle_request_token") | ||
} | ||
|
||
func FilterHeaders(hs http.Header) map[string]string { | ||
castleHeaders := make(map[string]string) | ||
for key, value := range hs { | ||
// Ensure cookies or authorization are never sent along. | ||
// Everything else is fair game. | ||
if _, ok := disallowedHeaders[strings.ToLower(key)]; ok { | ||
continue | ||
} | ||
// View: https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html | ||
castleHeaders[key] = strings.Join(value, ", ") | ||
} | ||
return castleHeaders | ||
} | ||
|
||
var disallowedHeaders = map[string]struct{}{ | ||
"cookie": {}, | ||
"authorization": {}, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
package context | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
|
||
"github.com/utilitywarehouse/castle-go" | ||
) | ||
|
||
func TestToCtxFromRequest(t *testing.T) { | ||
tests := map[string]struct { | ||
input http.Request | ||
expected castle.Context | ||
}{ | ||
"castle token on header": { | ||
input: func() http.Request { | ||
req := httptest.NewRequest(http.MethodPost, "http://example.com", nil) | ||
req.Header.Set("X-Castle-Request-Token", "foo") | ||
req.RemoteAddr = "2.2.2.2" | ||
return *req | ||
}(), | ||
expected: castle.Context{ | ||
IP: "2.2.2.2", | ||
Headers: map[string]string{"X-Castle-Request-Token": "foo"}, | ||
RequestToken: "foo", | ||
}, | ||
}, | ||
"castle token in form": { | ||
input: func() http.Request { | ||
req := httptest.NewRequest(http.MethodPost, "http://example.com/bar?castle_request_token=bar", nil) | ||
req.RemoteAddr = "2.2.2.2" | ||
return *req | ||
}(), | ||
expected: castle.Context{ | ||
IP: "2.2.2.2", | ||
Headers: map[string]string{}, | ||
RequestToken: "bar", | ||
}, | ||
}, | ||
"no castle token": { | ||
input: func() http.Request { | ||
req := http.Request{} | ||
req.RemoteAddr = "2.2.2.2" | ||
|
||
return req | ||
}(), | ||
expected: castle.Context{ | ||
IP: "2.2.2.2", | ||
Headers: map[string]string{}, | ||
RequestToken: "", | ||
}, | ||
}, | ||
} | ||
for name, test := range tests { | ||
test := test | ||
|
||
t.Run(name, func(t *testing.T) { | ||
ctx := context.Background() | ||
|
||
gotCtx := ToCtxFromRequest(ctx, &test.input) | ||
got := FromCtx(gotCtx) | ||
assert.Equal(t, test.expected, *got) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
package http | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"net" | ||
"net/http" | ||
"strings" | ||
) | ||
|
||
var cidrs []*net.IPNet | ||
|
||
func init() { | ||
maxCidrBlocks := []string{ | ||
"127.0.0.1/8", // localhost | ||
"10.0.0.0/8", // 24-bit block | ||
"172.16.0.0/12", // 20-bit block | ||
"192.168.0.0/16", // 16-bit block | ||
"169.254.0.0/16", // link local address | ||
"::1/128", // localhost IPv6 | ||
"fc00::/7", // unique local address IPv6 | ||
"fe80::/10", // link local address IPv6 | ||
} | ||
|
||
cidrs = make([]*net.IPNet, len(maxCidrBlocks)) | ||
for i, maxCidrBlock := range maxCidrBlocks { | ||
_, cidr, err := net.ParseCIDR(maxCidrBlock) | ||
if err != nil { | ||
panic(fmt.Sprintf("failed to parse CIDR block %q: %v", maxCidrBlock, err)) | ||
} | ||
cidrs[i] = cidr | ||
} | ||
} | ||
|
||
// IPFromRequest return client's real public IP address from http request headers. | ||
func IPFromRequest(r *http.Request) string { | ||
// If we have it, return this first. | ||
// | ||
// https://developers.cloudflare.com/fundamentals/get-started/reference/http-request-headers/#cf-connecting-ip | ||
if ip := r.Header.Get("Cf-Connecting-Ip"); ip != "" { | ||
return ip | ||
} | ||
|
||
// If we have it, try to return the first global address in X-Forwarded-For | ||
for _, ip := range strings.Split(r.Header.Get("X-Forwarded-For"), ",") { | ||
ip = strings.TrimSpace(ip) | ||
isPrivate, err := isPrivateAddress(ip) | ||
if !isPrivate && err == nil { | ||
return ip | ||
} | ||
} | ||
|
||
// Check X-Real-Ip header next | ||
if ip := r.Header.Get("X-Real-Ip"); ip != "" { | ||
return ip | ||
} | ||
|
||
// If all else fails, return the remote address | ||
// | ||
// If there are colon in remote address, remove the port number | ||
// otherwise, return remote address as is | ||
var ip string | ||
if strings.ContainsRune(r.RemoteAddr, ':') { | ||
ip, _, _ = net.SplitHostPort(r.RemoteAddr) //nolint:errcheck | ||
} else { | ||
ip = r.RemoteAddr | ||
} | ||
return ip | ||
} | ||
|
||
// isPrivateAddress works by checking if the address is under private CIDR blocks. | ||
// List of private CIDR blocks can be seen on : | ||
// | ||
// https://en.wikipedia.org/wiki/Private_network | ||
// | ||
// https://en.wikipedia.org/wiki/Link-local_address | ||
func isPrivateAddress(address string) (bool, error) { | ||
ipAddress := net.ParseIP(address) | ||
if ipAddress == nil { | ||
return false, errors.New("address is not valid") | ||
} | ||
|
||
for i := range cidrs { | ||
if cidrs[i].Contains(ipAddress) { | ||
return true, nil | ||
} | ||
} | ||
|
||
return false, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package http_test | ||
|
||
import ( | ||
"net/http" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
|
||
http_internal "github.com/utilitywarehouse/castle-go/http" | ||
) | ||
|
||
func TestIPFromRequest(t *testing.T) { | ||
tests := map[string]struct { | ||
input *http.Request | ||
expected string | ||
}{ | ||
"empty": { | ||
input: httpRequest(map[string]string{ | ||
"X-Real-Ip": "", | ||
"X-Forwarded-For": "", | ||
"Cf-Connecting-Ip": "", | ||
}, ""), | ||
expected: "", | ||
}, | ||
"cf-connecting-ip": { | ||
input: httpRequest(map[string]string{ | ||
"X-Real-Ip": "foo", | ||
"X-Forwarded-For": "bar", | ||
"Cf-Connecting-Ip": "cf-connecting-ip", | ||
}, "foobar"), | ||
expected: "cf-connecting-ip", | ||
}, | ||
"x-forwarded-for": { | ||
input: httpRequest(map[string]string{ | ||
"X-Real-Ip": "foo", | ||
"X-Forwarded-For": "127.0.0.1, 109.14.23.2", | ||
"Cf-Connecting-Ip": "", | ||
}, "foobar"), | ||
expected: "109.14.23.2", | ||
}, | ||
"x-real-ip": { | ||
input: httpRequest(map[string]string{ | ||
"X-Real-Ip": "x-real-ip", | ||
"X-Forwarded-For": "", | ||
"Cf-Connecting-Ip": "", | ||
}, "foobar"), | ||
expected: "x-real-ip", | ||
}, | ||
"remote-addr": { | ||
input: httpRequest(map[string]string{ | ||
"X-Real-Ip": "", | ||
"X-Forwarded-For": "", | ||
"Cf-Connecting-Ip": "", | ||
}, "remote-addr:8080"), | ||
expected: "remote-addr", | ||
}, | ||
} | ||
for name, test := range tests { | ||
t.Run(name, func(t *testing.T) { | ||
got := http_internal.IPFromRequest(test.input) | ||
assert.Equal(t, test.expected, got) | ||
}) | ||
} | ||
} | ||
|
||
func httpRequest(headers map[string]string, remoteAddr string) *http.Request { | ||
r := &http.Request{ | ||
RemoteAddr: remoteAddr, | ||
Header: make(http.Header), | ||
} | ||
for k, v := range headers { | ||
r.Header.Set(k, v) | ||
} | ||
return r | ||
} |