From 2c5fea8737b23926feabb729bd1ed958e7f91ee1 Mon Sep 17 00:00:00 2001 From: Dave Shanley Date: Thu, 6 Jul 2023 17:55:56 -0400 Subject: [PATCH] Added new header drop feature cleaned up regex handling for case sensitivity (had to remove viper) and also added query param to all rewrites. Signed-off-by: Dave Shanley --- README.md | 11 + cmd/root_command.go | 192 +++++++----------- config/paths_test.go | 51 ++++- daemon/api.go | 5 +- daemon/dto.go | 21 +- daemon/handle_request.go | 24 ++- daemon/wiretap_utils.go | 118 ++++++----- shared/config.go | 35 ++-- shared/language.go | 11 + ui/src/components/transaction/request-body.ts | 10 +- ui/src/model/http_transaction.ts | 2 +- 11 files changed, 273 insertions(+), 207 deletions(-) create mode 100644 shared/language.go diff --git a/README.md b/README.md index 1a9f42e..bde0f2c 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,17 @@ paths: '^/pb33f/(\w+)/test/': '' ``` +## Dropping certain headers + +To prevent certain headers from being proxies, you can drop them using the `headers` config property and the `drop` property +which is an array of headers to drop from all outbound requests.. + +```yaml +headers: + drop: + - Origin +``` + ## Command Line Interface ### Available Flags diff --git a/cmd/root_command.go b/cmd/root_command.go index e7d688a..0c25fed 100644 --- a/cmd/root_command.go +++ b/cmd/root_command.go @@ -5,14 +5,13 @@ package cmd import ( "embed" - "github.com/mitchellh/mapstructure" + "gopkg.in/yaml.v3" "net/url" "os" "github.com/pb33f/wiretap/shared" "github.com/pterm/pterm" "github.com/spf13/cobra" - "github.com/spf13/viper" ) var ( @@ -33,33 +32,12 @@ var ( configFlag, _ := cmd.Flags().GetString("config") - if configFlag == "" { - pterm.Info.Println("Attempting to locate wiretap configuration...") - viper.SetConfigFile(".wiretap") - viper.SetConfigType("env") - viper.AddConfigPath("$HOME/.wiretap") - viper.AddConfigPath(".") - } else { - viper.SetConfigFile(configFlag) - } - - cerr := viper.ReadInConfig() - if cerr != nil && configFlag != "" { - pterm.Error.Printf("No wiretap configuration located. Using defaults: %s\n", cerr.Error()) - } - if cerr != nil && configFlag == "" { - pterm.Info.Println("No wiretap configuration located. Using defaults.") - } - if cerr == nil { - pterm.Info.Printf("Located configuration file at: %s\n", viper.ConfigFileUsed()) - } - var spec string var port string var monitorPort string var wsPort string var staticDir string - var pathConfigurations map[string]*shared.WiretapPathConfig + var redirectHost string var redirectPort string var redirectScheme string @@ -67,55 +45,11 @@ var ( var redirectURL string var globalAPIDelay int - // extract from wiretap environment variables. - if viper.IsSet("PORT") { - port = viper.GetString("PORT") - } - - if viper.IsSet("SPEC") { - spec = viper.GetString("SPEC") - } - - if viper.IsSet("MONITOR_PORT") { - monitorPort = viper.GetString("MONITOR_PORT") - } - - if viper.IsSet("WEBSOCKET_PORT") { - wsPort = viper.GetString("WEBSOCKET_PORT") - } - - if viper.IsSet("STATIC_DIR") { - staticDir = viper.GetString("STATIC_DIR") - } - - if viper.IsSet("PATHS") { - paths := viper.Get("PATHS") - var pc map[string]*shared.WiretapPathConfig - err := mapstructure.Decode(paths, &pc) - if err != nil { - pterm.Error.Printf("Unable to decode paths from configuration: %s\n", err.Error()) - } else { - // print out the path configurations. - printLoadedPathConfigurations(pc) - pathConfigurations = pc - } - } - - if viper.IsSet("REDIRECT_URL") { - redirectURL = viper.GetString("REDIRECT_URL") - } - - if viper.IsSet("GLOBAL_API_DELAY") { - globalAPIDelay = viper.GetInt("GLOBAL_API_DELAY") - } - portFlag, _ := cmd.Flags().GetString("port") if portFlag != "" { port = portFlag } else { - if port == "" { - port = "9090" // default - } + port = "9090" // default } specFlag, _ := cmd.Flags().GetString("spec") @@ -127,9 +61,7 @@ var ( if monitorPortFlag != "" { monitorPort = monitorPortFlag } else { - if monitorPort == "" { - monitorPort = "9091" // default - } + monitorPort = "9091" // default } staticDirFlag, _ := cmd.Flags().GetString("static") @@ -141,18 +73,11 @@ var ( if wsPortFlag != "" { wsPort = wsPortFlag } else { - if wsPort == "" { - wsPort = "9092" // default - } + wsPort = "9092" // default } redirectURLFlag, _ := cmd.Flags().GetString("url") if redirectURLFlag != "" { - - if pathConfigurations != nil { - // warn the user that the path configurations will trump the switch - pterm.Warning.Println("Using the --url flag will be *overridden* by the path configuration 'target' setting") - } redirectURL = redirectURLFlag } @@ -161,6 +86,27 @@ var ( globalAPIDelay = globalAPIDelayFlag } + var config shared.WiretapConfiguration + if configFlag != "" { + + cBytes, err := os.ReadFile(configFlag) + if err != nil { + pterm.Error.Printf("Failed to read wiretap configuration '%s': %s\n", configFlag, err.Error()) + return err + } + err = yaml.Unmarshal(cBytes, &config) + if err != nil { + pterm.Error.Printf("Failed to parse wiretap configuration '%s': %s\n", configFlag, err.Error()) + return err + } + pterm.Info.Printf("Loaded wiretap configuration '%s'...\n\n", configFlag) + if config.RedirectURL != "" { + redirectURL = config.RedirectURL + } + } else { + pterm.Info.Println("No wiretap configuration located. Using defaults") + } + if spec == "" { pterm.Warning.Println("No OpenAPI specification provided. " + "Please provide a path to an OpenAPI specification using the --spec or -s flags.") @@ -177,48 +123,64 @@ var ( return nil } - if redirectURL != "" { - parsedURL, e := url.Parse(redirectURL) - if e != nil { - pterm.Println() - pterm.Error.Printf("URL is not valid. "+ - "Please provide a valid URL to redirect to. %s cannot be parsed\n\n", redirectURL) - pterm.Println() - return nil - } - if parsedURL.Scheme == "" || parsedURL.Host == "" { - pterm.Println() - pterm.Error.Printf("URL is not valid. "+ - "Please provide a valid URL to redirect to. %s cannot be parsed\n\n", redirectURL) - pterm.Println() - return nil - } - redirectHost = parsedURL.Hostname() - redirectPort = parsedURL.Port() - redirectScheme = parsedURL.Scheme - redirectBasePath = parsedURL.Path + parsedURL, e := url.Parse(redirectURL) + if e != nil { + pterm.Println() + pterm.Error.Printf("URL is not valid. "+ + "Please provide a valid URL to redirect to. %s cannot be parsed\n\n", redirectURL) + pterm.Println() + return nil } + if parsedURL.Scheme == "" || parsedURL.Host == "" { + pterm.Println() + pterm.Error.Printf("URL is not valid. "+ + "Please provide a valid URL to redirect to. %s cannot be parsed\n\n", redirectURL) + pterm.Println() + return nil + } + redirectHost = parsedURL.Hostname() + redirectPort = parsedURL.Port() + redirectScheme = parsedURL.Scheme + redirectBasePath = parsedURL.Path - config := shared.WiretapConfiguration{ - Contract: spec, - RedirectURL: redirectURL, - RedirectHost: redirectHost, - RedirectBasePath: redirectBasePath, - RedirectPort: redirectPort, - RedirectProtocol: redirectScheme, - Port: port, - MonitorPort: monitorPort, - GlobalAPIDelay: globalAPIDelay, - WebSocketPort: wsPort, - StaticDir: staticDir, - PathConfigurations: pathConfigurations, - FS: FS, + config.Contract = spec + config.RedirectURL = redirectURL + config.RedirectHost = redirectHost + config.RedirectBasePath = redirectBasePath + config.RedirectPort = redirectPort + config.RedirectProtocol = redirectScheme + if config.Port == "" { + config.Port = port + } + if config.MonitorPort == "" { + config.MonitorPort = monitorPort + } + if config.WebSocketPort == "" { + config.WebSocketPort = wsPort } + if config.GlobalAPIDelay == 0 { + config.GlobalAPIDelay = globalAPIDelay + } + if config.StaticDir == "" { + config.StaticDir = staticDir + } + config.FS = FS - if len(pathConfigurations) > 0 { + if len(config.PathConfigurations) > 0 { + printLoadedPathConfigurations(config.PathConfigurations) config.CompilePaths() } + if config.Headers != nil && len(config.Headers.DropHeaders) > 0 { + + pterm.Info.Printf("Dropping the following %d %s:\n", len(config.Headers.DropHeaders), + shared.Pluralize(len(config.Headers.DropHeaders), "header", "headers")) + for _, header := range config.Headers.DropHeaders { + pterm.Printf("🗑️ %s\n", pterm.LightMagenta(header)) + } + pterm.Println() + } + // ready to boot, let's go! _, pErr := runWiretapService(&config) diff --git a/config/paths_test.go b/config/paths_test.go index da9e840..f761b5f 100644 --- a/config/paths_test.go +++ b/config/paths_test.go @@ -8,6 +8,7 @@ import ( "github.com/pb33f/wiretap/shared" "github.com/spf13/viper" "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" "strings" "testing" ) @@ -36,7 +37,6 @@ paths: PathConfigurations: pc, } - // compile paths wcConfig.CompilePaths() res := FindPaths("/pb33f/test/123", wcConfig) @@ -74,7 +74,6 @@ paths: PathConfigurations: pc, } - // compile paths wcConfig.CompilePaths() path := RewritePath("/pb33f/test/123/slap/a/chap", wcConfig) @@ -106,7 +105,6 @@ paths: PathConfigurations: pc, } - // compile paths wcConfig.CompilePaths() path := RewritePath("/pb33f/cakes/test/123/smelly/jelly", wcConfig) @@ -138,10 +136,55 @@ paths: PathConfigurations: pc, } - // compile paths wcConfig.CompilePaths() path := RewritePath("/pb33f/cakes/test/lemons/321/smelly/jelly", wcConfig) assert.Equal(t, "https://localhost:9093/slippy/cakes/whip/321/lemons/smelly/jelly", path) } + +func TestRewritePath_Secure_With_Variables_CaseSensitive(t *testing.T) { + + config := ` +paths: + /en-US/burgerd/__raw/*: + target: localhost:80 + pathRewrite: + '^/en-US/burgerd/__raw/(\w+)/nobody/': '$1/-/' + /en-US/burgerd/services/*: + target: locahost:80 + pathRewrite: + '^/en-US/burgerd/services': '/services'` + + var c shared.WiretapConfiguration + _ = yaml.Unmarshal([]byte(config), &c) + + c.CompilePaths() + + path := RewritePath("/en-US/burgerd/__raw/noKetchupPlease/nobody/", &c) + assert.Equal(t, "http://localhost:80/noKetchupPlease/-/", path) + +} + +func TestRewritePath_Secure_With_Variables_CaseSensitive_AndQuery(t *testing.T) { + + config := ` +paths: + /en-US/burgerd/__raw/*: + target: localhost:80 + pathRewrite: + '^/en-US/burgerd/__raw/(\w+)/nobody/': '$1/-/' + /en-US/burgerd/services/*: + target: locahost:80 + pathRewrite: + '^/en-US/burgerd/services': '/services'` + + var c shared.WiretapConfiguration + _ = yaml.Unmarshal([]byte(config), &c) + + c.CompilePaths() + + path := RewritePath("/en-US/burgerd/__raw/noKetchupPlease/nobody/yummy/yum?onions=true", &c) + assert.Equal(t, "http://localhost:80/noKetchupPlease/-/yummy/yum?onions=true", path) + +} diff --git a/daemon/api.go b/daemon/api.go index 6726092..467ef60 100644 --- a/daemon/api.go +++ b/daemon/api.go @@ -51,7 +51,10 @@ func (ws *WiretapService) callAPI(req *http.Request) (*http.Response, error) { replaced := config.RewritePath(req.URL.Path, wiretapConfig) if replaced != "" { newUrl, _ := url.Parse(replaced) - pterm.Info.Printf("[wiretap] Re-writing path '%s' to '%s'\n", req.URL.Path, replaced) + if req.URL.RawQuery != "" { + newUrl.RawQuery = req.URL.RawQuery + } + pterm.Info.Printf("[wiretap] Re-writing path '%s' to '%s'\n", req.URL.String(), newUrl.String()) req.URL = newUrl } diff --git a/daemon/dto.go b/daemon/dto.go index 62f79a0..c13893e 100644 --- a/daemon/dto.go +++ b/daemon/dto.go @@ -103,14 +103,19 @@ func buildResponse(r *model.Request, response *http.Response) *HttpTransaction { func buildRequest(r *model.Request) *HttpTransaction { - storeManager := bus.GetBus().GetStoreManager() - controlsStore := storeManager.CreateStore(controls.ControlServiceChan) - config, _ := controlsStore.Get(shared.ConfigKey) - - newReq := cloneRequest(r.HttpRequest, - config.(*shared.WiretapConfiguration).RedirectProtocol, - config.(*shared.WiretapConfiguration).RedirectHost, - config.(*shared.WiretapConfiguration).RedirectPort) + config, _ := bus. + GetBus(). + GetStoreManager(). + GetStore(controls.ControlServiceChan). + Get(shared.ConfigKey) + + newReq := cloneRequest(CloneRequest{ + Request: r.HttpRequest, + Protocol: config.(*shared.WiretapConfiguration).RedirectProtocol, + Host: config.(*shared.WiretapConfiguration).RedirectHost, + Port: config.(*shared.WiretapConfiguration).RedirectPort, + DropHeaders: config.(*shared.WiretapConfiguration).Headers.DropHeaders, + }) var requestBody []byte diff --git a/daemon/handle_request.go b/daemon/handle_request.go index f383c34..71121a9 100644 --- a/daemon/handle_request.go +++ b/daemon/handle_request.go @@ -47,15 +47,21 @@ func (ws *WiretapService) handleHttpRequest(request *model.Request) { configStore, _ := ws.controlsStore.Get(shared.ConfigKey) config := configStore.(*shared.WiretapConfiguration) - newReq := cloneRequest(request.HttpRequest, - config.RedirectProtocol, - config.RedirectHost, - config.RedirectPort) - - apiRequest := cloneRequest(request.HttpRequest, - config.RedirectProtocol, - config.RedirectHost, - config.RedirectPort) + newReq := cloneRequest(CloneRequest{ + Request: request.HttpRequest, + Protocol: config.RedirectProtocol, + Host: config.RedirectHost, + Port: config.RedirectPort, + DropHeaders: config.Headers.DropHeaders, + }) + + apiRequest := cloneRequest(CloneRequest{ + Request: request.HttpRequest, + Protocol: config.RedirectProtocol, + Host: config.RedirectHost, + Port: config.RedirectPort, + DropHeaders: config.Headers.DropHeaders, + }) // validate the request go ws.validateRequest(request, newReq, requestValidator, paramValidator, responseValidator) diff --git a/daemon/wiretap_utils.go b/daemon/wiretap_utils.go index fc9913d..c93e960 100644 --- a/daemon/wiretap_utils.go +++ b/daemon/wiretap_utils.go @@ -4,67 +4,85 @@ package daemon import ( - "bytes" - "fmt" - "io" - "net/http" + "bytes" + "fmt" + "io" + "net/http" + "strings" ) func extractHeaders(resp *http.Response) map[string]any { - headers := make(map[string]any) - for k, v := range resp.Header { - headers[k] = v[0] - } - return headers + headers := make(map[string]any) + for k, v := range resp.Header { + headers[k] = v[0] + } + return headers } func reconstructURL(r *http.Request, protocol, host, port string) string { - url := fmt.Sprintf("%s://%s", protocol, host) - if port != "" { - url += fmt.Sprintf(":%s", port) - } - if r.URL.Path != "" { - url += r.URL.Path - } - if r.URL.RawQuery != "" { - url += fmt.Sprintf("?%s", r.URL.RawQuery) - } - return url + url := fmt.Sprintf("%s://%s", protocol, host) + if port != "" { + url += fmt.Sprintf(":%s", port) + } + if r.URL.Path != "" { + url += r.URL.Path + } + if r.URL.RawQuery != "" { + url += fmt.Sprintf("?%s", r.URL.RawQuery) + } + return url } -func cloneRequest(r *http.Request, protocol, host, port string) *http.Request { - // todo: replace with config/server etc. - // todo: check query params +type CloneRequest struct { + Request *http.Request + Protocol string + Host string + Port string + DropHeaders []string +} + +func cloneRequest(request CloneRequest) *http.Request { + // sniff and replace body. + b, _ := io.ReadAll(request.Request.Body) + _ = request.Request.Body.Close() + request.Request.Body = io.NopCloser(bytes.NewBuffer(b)) - // sniff and replace body. - b, _ := io.ReadAll(r.Body) - _ = r.Body.Close() - r.Body = io.NopCloser(bytes.NewBuffer(b)) + // create cloned request + newURL := reconstructURL(request.Request, request.Protocol, request.Host, request.Port) + newReq, _ := http.NewRequest(request.Request.Method, newURL, io.NopCloser(bytes.NewBuffer(b))) - // create cloned request - newURL := reconstructURL(r, protocol, host, port) - newReq, _ := http.NewRequest(r.Method, newURL, io.NopCloser(bytes.NewBuffer(b))) - newReq.Header = r.Header - return newReq + // copy headers, drop those that are specified. + for k, v := range request.Request.Header { + skip := false + for h := range request.DropHeaders { + if strings.EqualFold(request.DropHeaders[h], k) { + skip = true + } + } + if !skip { + newReq.Header.Set(k, v[0]) + } + } + return newReq } func cloneResponse(r *http.Response) *http.Response { - // sniff and replace body. - var b []byte - if r == nil { - return nil // something else went wrong, nothing to do. - } - if r.Body != nil { - b, _ = io.ReadAll(r.Body) - _ = r.Body.Close() - r.Body = io.NopCloser(bytes.NewBuffer(b)) - } - resp := &http.Response{ - StatusCode: r.StatusCode, - Header: r.Header, - } - if r.Body != nil { - resp.Body = io.NopCloser(bytes.NewBuffer(b)) - } - return resp + // sniff and replace body. + var b []byte + if r == nil { + return nil // something else went wrong, nothing to do. + } + if r.Body != nil { + b, _ = io.ReadAll(r.Body) + _ = r.Body.Close() + r.Body = io.NopCloser(bytes.NewBuffer(b)) + } + resp := &http.Response{ + StatusCode: r.StatusCode, + Header: r.Header, + } + if r.Body != nil { + resp.Body = io.NopCloser(bytes.NewBuffer(b)) + } + return resp } diff --git a/shared/config.go b/shared/config.go index 4a209e2..9767bac 100644 --- a/shared/config.go +++ b/shared/config.go @@ -10,18 +10,19 @@ import ( ) type WiretapConfiguration struct { - Contract string `json:"-"` - RedirectHost string `json:"redirectHost,omitempty"` - RedirectPort string `json:"redirectPort,omitempty"` - RedirectBasePath string `json:"redirectBasePath,omitempty"` - RedirectProtocol string `json:"redirectProtocol,omitempty"` - RedirectURL string `json:"redirectURL,omitempty"` - Port string `json:"port,omitempty"` - MonitorPort string `json:"monitorPort,omitempty"` - WebSocketPort string `json:"webSocketPort,omitempty"` - GlobalAPIDelay int `json:"globalAPIDelay,omitempty"` - StaticDir string `json:"staticDir,omitempty"` - PathConfigurations map[string]*WiretapPathConfig `json:"paths,omitempty"` + Contract string `json:"-" yaml:"-"` + RedirectHost string `json:"redirectHost,omitempty" yaml:"redirectHost,omitempty"` + RedirectPort string `json:"redirectPort,omitempty" yaml:"redirectPort,omitempty"` + RedirectBasePath string `json:"redirectBasePath,omitempty" yaml:"redirectBasePath,omitempty"` + RedirectProtocol string `json:"redirectProtocol,omitempty" yaml:"redirectProtocol,omitempty"` + RedirectURL string `json:"redirectURL,omitempty" yaml:"redirectURL,omitempty"` + Port string `json:"port,omitempty" yaml:"port,omitempty"` + MonitorPort string `json:"monitorPort,omitempty" yaml:"monitorPort,omitempty"` + WebSocketPort string `json:"webSocketPort,omitempty" yaml:"webSocketPort,omitempty"` + GlobalAPIDelay int `json:"globalAPIDelay,omitempty" yaml:"globalAPIDelay,omitempty"` + StaticDir string `json:"staticDir,omitempty" yaml:"staticDir,omitempty"` + PathConfigurations map[string]*WiretapPathConfig `json:"paths,omitempty" yaml:"paths,omitempty"` + Headers *WiretapHeaderConfig `json:"headers,omitempty" yaml:"headers,omitempty"` CompiledPaths map[string]*CompiledPath `json:"-"` FS embed.FS `json:"-"` } @@ -34,10 +35,10 @@ func (wtc *WiretapConfiguration) CompilePaths() { } type WiretapPathConfig struct { - Target string `json:"target,omitempty"` - PathRewrite map[string]string `json:"pathRewrite,omitempty"` + Target string `json:"target,omitempty" yaml:"target,omitempty"` + PathRewrite map[string]string `json:"pathRewrite,omitempty" yaml:"pathRewrite,omitempty"` CompiledPath *CompiledPath `json:"-"` - Secure bool `json:"secure,omitempty"` + Secure bool `json:"secure,omitempty" yaml:"secure"` } type CompiledPath struct { @@ -54,6 +55,10 @@ type CompiledPathRewrite struct { CompiledTarget glob.Glob } +type WiretapHeaderConfig struct { + DropHeaders []string `json:"drop,omitempty" yaml:"drop,omitempty"` +} + func (wpc *WiretapPathConfig) Compile(key string) *CompiledPath { cp := &CompiledPath{ PathConfig: wpc, diff --git a/shared/language.go b/shared/language.go new file mode 100644 index 0000000..e9767c9 --- /dev/null +++ b/shared/language.go @@ -0,0 +1,11 @@ +// Copyright 2023 Princess B33f Heavy Industries / Dave Shanley +// SPDX-License-Identifier: MIT + +package shared + +func Pluralize(n int, singular string, plural string) string { + if n == 1 { + return singular + } + return plural +} diff --git a/ui/src/components/transaction/request-body.ts b/ui/src/components/transaction/request-body.ts index 6ae433f..681704b 100644 --- a/ui/src/components/transaction/request-body.ts +++ b/ui/src/components/transaction/request-body.ts @@ -34,10 +34,12 @@ export class RequestBodyViewComponent extends LitElement { parseFormEncodedData(data: string): Map { const map = new Map(); - const pairs = data.split('&'); - for (const pair of pairs) { - const [key, value] = pair.split('='); - map.set(decodeURI(key), decodeURI(value)); + if (data) { + const pairs = data.split('&'); + for (const pair of pairs) { + const [key, value] = pair.split('='); + map.set(decodeURI(key), decodeURI(value)); + } } return map; } diff --git a/ui/src/model/http_transaction.ts b/ui/src/model/http_transaction.ts index 82c7511..b5eb5b1 100644 --- a/ui/src/model/http_transaction.ts +++ b/ui/src/model/http_transaction.ts @@ -154,7 +154,7 @@ export class HttpTransaction extends HttpTransactionBase { } // check if the keyword filter is in the response body. - if (this.httpResponse.responseBody?.toLowerCase().includes(keywordFilter.keyword.toLowerCase())) { + if (this.httpResponse?.responseBody?.toLowerCase().includes(keywordFilter.keyword.toLowerCase())) { return keywordFilter; }