Skip to content

Commit

Permalink
Context tools (#19)
Browse files Browse the repository at this point in the history
* add context transformation logic
  • Loading branch information
pansbro12 authored Apr 8, 2024
1 parent fd68fe3 commit f1032b5
Show file tree
Hide file tree
Showing 4 changed files with 318 additions and 0 deletions.
84 changes: 84 additions & 0 deletions context/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package context

import (
"context"
"net/http"
"strings"

"github.com/utilitywarehouse/castle-go"
http_internal "github.com/utilitywarehouse/castle-go/http"
)

type contextKey string

func (c contextKey) String() string {
return "castle context key " + string(c)
}

var castleCtxKey = contextKey("castle_context")

// ToCtxFromRequest adds the token and other request information (i.e. castle context) to the context.
func ToCtxFromRequest(ctx context.Context, r *http.Request) context.Context {
castleCtx := castle.Context{
RequestToken: func() string {
// grab the token from header if it exists
if tkn := tokenFromHTTPHeader(r.Header); tkn != "" {
return tkn
}

// otherwise, try grabbing it from form
return tokenFromHTTPForm(r)
}(),
IP: http_internal.IPFromRequest(r),
Headers: FilterHeaders(r.Header), // pass in as much context as possible
}
return context.WithValue(ctx, castleCtxKey, castleCtx)
}

func FromCtx(ctx context.Context) *castle.Context {
castleCtx, ok := ctx.Value(castleCtxKey).(castle.Context)
if ok {
return &castleCtx
}
return nil
}

func tokenFromHTTPHeader(header http.Header) string {
// recommended header name
if t := header.Get("X-Castle-Request-Token"); t != "" {
return t
}
// header name used in the frontends
if t := header.Get("Castle-Token"); t != "" {
return t
}
return ""
}

func tokenFromHTTPForm(r *http.Request) string {
// ParseForm is idempotent, so it's safe to call from anywhere
if err := r.ParseForm(); err != nil {
return ""
}

return r.Form.Get("castle_request_token")
}

func FilterHeaders(hs http.Header) map[string]string {
castleHeaders := make(map[string]string)
for key, value := range hs {
// Ensure cookies or authorization are never sent along.
// Everything else is fair game.
if _, ok := disallowedHeaders[strings.ToLower(key)]; ok {
continue
}
// View: https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html
castleHeaders[key] = strings.Join(value, ", ")
}
return castleHeaders
}

var disallowedHeaders = map[string]struct{}{
"cookie": {},
"authorization": {},
}
69 changes: 69 additions & 0 deletions context/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package context

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"

"github.com/utilitywarehouse/castle-go"
)

func TestToCtxFromRequest(t *testing.T) {
tests := map[string]struct {
input http.Request
expected castle.Context
}{
"castle token on header": {
input: func() http.Request {
req := httptest.NewRequest(http.MethodPost, "http://example.com", nil)
req.Header.Set("X-Castle-Request-Token", "foo")
req.RemoteAddr = "2.2.2.2"
return *req
}(),
expected: castle.Context{
IP: "2.2.2.2",
Headers: map[string]string{"X-Castle-Request-Token": "foo"},
RequestToken: "foo",
},
},
"castle token in form": {
input: func() http.Request {
req := httptest.NewRequest(http.MethodPost, "http://example.com/bar?castle_request_token=bar", nil)
req.RemoteAddr = "2.2.2.2"
return *req
}(),
expected: castle.Context{
IP: "2.2.2.2",
Headers: map[string]string{},
RequestToken: "bar",
},
},
"no castle token": {
input: func() http.Request {
req := http.Request{}
req.RemoteAddr = "2.2.2.2"

return req
}(),
expected: castle.Context{
IP: "2.2.2.2",
Headers: map[string]string{},
RequestToken: "",
},
},
}
for name, test := range tests {
test := test

t.Run(name, func(t *testing.T) {
ctx := context.Background()

gotCtx := ToCtxFromRequest(ctx, &test.input)
got := FromCtx(gotCtx)
assert.Equal(t, test.expected, *got)
})
}
}
90 changes: 90 additions & 0 deletions http/ip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package http

import (
"errors"
"fmt"
"net"
"net/http"
"strings"
)

var cidrs []*net.IPNet

func init() {
maxCidrBlocks := []string{
"127.0.0.1/8", // localhost
"10.0.0.0/8", // 24-bit block
"172.16.0.0/12", // 20-bit block
"192.168.0.0/16", // 16-bit block
"169.254.0.0/16", // link local address
"::1/128", // localhost IPv6
"fc00::/7", // unique local address IPv6
"fe80::/10", // link local address IPv6
}

cidrs = make([]*net.IPNet, len(maxCidrBlocks))
for i, maxCidrBlock := range maxCidrBlocks {
_, cidr, err := net.ParseCIDR(maxCidrBlock)
if err != nil {
panic(fmt.Sprintf("failed to parse CIDR block %q: %v", maxCidrBlock, err))
}
cidrs[i] = cidr
}
}

// IPFromRequest return client's real public IP address from http request headers.
func IPFromRequest(r *http.Request) string {
// If we have it, return this first.
//
// https://developers.cloudflare.com/fundamentals/get-started/reference/http-request-headers/#cf-connecting-ip
if ip := r.Header.Get("Cf-Connecting-Ip"); ip != "" {
return ip
}

// If we have it, try to return the first global address in X-Forwarded-For
for _, ip := range strings.Split(r.Header.Get("X-Forwarded-For"), ",") {
ip = strings.TrimSpace(ip)
isPrivate, err := isPrivateAddress(ip)
if !isPrivate && err == nil {
return ip
}
}

// Check X-Real-Ip header next
if ip := r.Header.Get("X-Real-Ip"); ip != "" {
return ip
}

// If all else fails, return the remote address
//
// If there are colon in remote address, remove the port number
// otherwise, return remote address as is
var ip string
if strings.ContainsRune(r.RemoteAddr, ':') {
ip, _, _ = net.SplitHostPort(r.RemoteAddr) //nolint:errcheck
} else {
ip = r.RemoteAddr
}
return ip
}

// isPrivateAddress works by checking if the address is under private CIDR blocks.
// List of private CIDR blocks can be seen on :
//
// https://en.wikipedia.org/wiki/Private_network
//
// https://en.wikipedia.org/wiki/Link-local_address
func isPrivateAddress(address string) (bool, error) {
ipAddress := net.ParseIP(address)
if ipAddress == nil {
return false, errors.New("address is not valid")
}

for i := range cidrs {
if cidrs[i].Contains(ipAddress) {
return true, nil
}
}

return false, nil
}
75 changes: 75 additions & 0 deletions http/ip_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package http_test

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"

http_internal "github.com/utilitywarehouse/castle-go/http"
)

func TestIPFromRequest(t *testing.T) {
tests := map[string]struct {
input *http.Request
expected string
}{
"empty": {
input: httpRequest(map[string]string{
"X-Real-Ip": "",
"X-Forwarded-For": "",
"Cf-Connecting-Ip": "",
}, ""),
expected: "",
},
"cf-connecting-ip": {
input: httpRequest(map[string]string{
"X-Real-Ip": "foo",
"X-Forwarded-For": "bar",
"Cf-Connecting-Ip": "cf-connecting-ip",
}, "foobar"),
expected: "cf-connecting-ip",
},
"x-forwarded-for": {
input: httpRequest(map[string]string{
"X-Real-Ip": "foo",
"X-Forwarded-For": "127.0.0.1, 109.14.23.2",
"Cf-Connecting-Ip": "",
}, "foobar"),
expected: "109.14.23.2",
},
"x-real-ip": {
input: httpRequest(map[string]string{
"X-Real-Ip": "x-real-ip",
"X-Forwarded-For": "",
"Cf-Connecting-Ip": "",
}, "foobar"),
expected: "x-real-ip",
},
"remote-addr": {
input: httpRequest(map[string]string{
"X-Real-Ip": "",
"X-Forwarded-For": "",
"Cf-Connecting-Ip": "",
}, "remote-addr:8080"),
expected: "remote-addr",
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
got := http_internal.IPFromRequest(test.input)
assert.Equal(t, test.expected, got)
})
}
}

func httpRequest(headers map[string]string, remoteAddr string) *http.Request {
r := &http.Request{
RemoteAddr: remoteAddr,
Header: make(http.Header),
}
for k, v := range headers {
r.Header.Set(k, v)
}
return r
}

0 comments on commit f1032b5

Please sign in to comment.