diff --git a/config/paths.go b/config/paths.go index cd1c89f..1265169 100644 --- a/config/paths.go +++ b/config/paths.go @@ -12,9 +12,17 @@ import ( ) const ( - RewriteIdHeader = "RewriteId" + PascalCaseRewriteIdHeader = "Rewriteid" + SnakeCaseRewriteIdHeader = "rewrite_id" + KebabCaseRewriteIdHeader = "Rewrite-Id" ) +var rewriteIdHeaders = []string{ + PascalCaseRewriteIdHeader, + SnakeCaseRewriteIdHeader, + KebabCaseRewriteIdHeader, +} + type PathRewrite struct { RewrittenPath string PathConfiguration *shared.WiretapPathConfig @@ -94,13 +102,49 @@ func rewriteTaget(path string, pathConfig *shared.WiretapPathConfig, configurati } } +func getRewriteIdHeaderValues(req *http.Request) ([]string, bool) { + + // Let's first try to get the header with expected key names + for _, possibleHeaderKey := range rewriteIdHeaders { + + if rewriteIdHeaderValues, ok := req.Header[possibleHeaderKey]; ok { + return rewriteIdHeaderValues, true + } + + if rewriteIdHeaderValues, ok := req.Header[strings.ToLower(possibleHeaderKey)]; ok { + return rewriteIdHeaderValues, true + } + + } + + // Let's now try to ignore case ; this may produce collisions if a user has two headers with similar keys, + // but different capitalization. This is okay, as this is a last ditch effort to find any possible match + var loweredHeaders = make(http.Header) + + for headerKey, headerValues := range req.Header { + for _, headerValue := range headerValues { + loweredHeaders.Set(strings.ToLower(headerKey), headerValue) + } + } + + for _, possibleHeaderKey := range rewriteIdHeaders { + + if rewriteIdHeaderValues, ok := loweredHeaders[strings.ToLower(possibleHeaderKey)]; ok { + return rewriteIdHeaderValues, true + } + + } + + return nil, false +} + func FindPathWithRewriteId(paths []*shared.WiretapPathConfig, req *http.Request) *shared.WiretapPathConfig { if req == nil { return nil } - if rewriteIdHeaderValues, ok := req.Header[RewriteIdHeader]; ok { + if rewriteIdHeaderValues, ok := getRewriteIdHeaderValues(req); ok { for _, pathRewriteConfig := range paths { // Iterate through header values - since it's a multi-value field