Skip to content

Commit

Permalink
Added new header drop feature
Browse files Browse the repository at this point in the history
cleaned up regex handling for case sensitivity (had to remove viper) and also added query param to all rewrites.

Signed-off-by: Dave Shanley <[email protected]>
  • Loading branch information
daveshanley committed Jul 6, 2023
1 parent e05f703 commit 2c5fea8
Show file tree
Hide file tree
Showing 11 changed files with 273 additions and 207 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ paths:
'^/pb33f/(\w+)/test/': ''
```
## Dropping certain headers
To prevent certain headers from being proxies, you can drop them using the `headers` config property and the `drop` property
which is an array of headers to drop from all outbound requests..

```yaml
headers:
drop:
- Origin
```

## Command Line Interface

### Available Flags
Expand Down
192 changes: 77 additions & 115 deletions cmd/root_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@ package cmd

import (
"embed"
"github.com/mitchellh/mapstructure"
"gopkg.in/yaml.v3"
"net/url"
"os"

"github.com/pb33f/wiretap/shared"
"github.com/pterm/pterm"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)

var (
Expand All @@ -33,89 +32,24 @@ var (

configFlag, _ := cmd.Flags().GetString("config")

if configFlag == "" {
pterm.Info.Println("Attempting to locate wiretap configuration...")
viper.SetConfigFile(".wiretap")
viper.SetConfigType("env")
viper.AddConfigPath("$HOME/.wiretap")
viper.AddConfigPath(".")
} else {
viper.SetConfigFile(configFlag)
}

cerr := viper.ReadInConfig()
if cerr != nil && configFlag != "" {
pterm.Error.Printf("No wiretap configuration located. Using defaults: %s\n", cerr.Error())
}
if cerr != nil && configFlag == "" {
pterm.Info.Println("No wiretap configuration located. Using defaults.")
}
if cerr == nil {
pterm.Info.Printf("Located configuration file at: %s\n", viper.ConfigFileUsed())
}

var spec string
var port string
var monitorPort string
var wsPort string
var staticDir string
var pathConfigurations map[string]*shared.WiretapPathConfig

var redirectHost string
var redirectPort string
var redirectScheme string
var redirectBasePath string
var redirectURL string
var globalAPIDelay int

// extract from wiretap environment variables.
if viper.IsSet("PORT") {
port = viper.GetString("PORT")
}

if viper.IsSet("SPEC") {
spec = viper.GetString("SPEC")
}

if viper.IsSet("MONITOR_PORT") {
monitorPort = viper.GetString("MONITOR_PORT")
}

if viper.IsSet("WEBSOCKET_PORT") {
wsPort = viper.GetString("WEBSOCKET_PORT")
}

if viper.IsSet("STATIC_DIR") {
staticDir = viper.GetString("STATIC_DIR")
}

if viper.IsSet("PATHS") {
paths := viper.Get("PATHS")
var pc map[string]*shared.WiretapPathConfig
err := mapstructure.Decode(paths, &pc)
if err != nil {
pterm.Error.Printf("Unable to decode paths from configuration: %s\n", err.Error())
} else {
// print out the path configurations.
printLoadedPathConfigurations(pc)
pathConfigurations = pc
}
}

if viper.IsSet("REDIRECT_URL") {
redirectURL = viper.GetString("REDIRECT_URL")
}

if viper.IsSet("GLOBAL_API_DELAY") {
globalAPIDelay = viper.GetInt("GLOBAL_API_DELAY")
}

portFlag, _ := cmd.Flags().GetString("port")
if portFlag != "" {
port = portFlag
} else {
if port == "" {
port = "9090" // default
}
port = "9090" // default
}

specFlag, _ := cmd.Flags().GetString("spec")
Expand All @@ -127,9 +61,7 @@ var (
if monitorPortFlag != "" {
monitorPort = monitorPortFlag
} else {
if monitorPort == "" {
monitorPort = "9091" // default
}
monitorPort = "9091" // default
}

staticDirFlag, _ := cmd.Flags().GetString("static")
Expand All @@ -141,18 +73,11 @@ var (
if wsPortFlag != "" {
wsPort = wsPortFlag
} else {
if wsPort == "" {
wsPort = "9092" // default
}
wsPort = "9092" // default
}

redirectURLFlag, _ := cmd.Flags().GetString("url")
if redirectURLFlag != "" {

if pathConfigurations != nil {
// warn the user that the path configurations will trump the switch
pterm.Warning.Println("Using the --url flag will be *overridden* by the path configuration 'target' setting")
}
redirectURL = redirectURLFlag
}

Expand All @@ -161,6 +86,27 @@ var (
globalAPIDelay = globalAPIDelayFlag
}

var config shared.WiretapConfiguration
if configFlag != "" {

cBytes, err := os.ReadFile(configFlag)
if err != nil {
pterm.Error.Printf("Failed to read wiretap configuration '%s': %s\n", configFlag, err.Error())
return err
}
err = yaml.Unmarshal(cBytes, &config)
if err != nil {
pterm.Error.Printf("Failed to parse wiretap configuration '%s': %s\n", configFlag, err.Error())
return err
}
pterm.Info.Printf("Loaded wiretap configuration '%s'...\n\n", configFlag)
if config.RedirectURL != "" {
redirectURL = config.RedirectURL
}
} else {
pterm.Info.Println("No wiretap configuration located. Using defaults")
}

if spec == "" {
pterm.Warning.Println("No OpenAPI specification provided. " +
"Please provide a path to an OpenAPI specification using the --spec or -s flags.")
Expand All @@ -177,48 +123,64 @@ var (
return nil
}

if redirectURL != "" {
parsedURL, e := url.Parse(redirectURL)
if e != nil {
pterm.Println()
pterm.Error.Printf("URL is not valid. "+
"Please provide a valid URL to redirect to. %s cannot be parsed\n\n", redirectURL)
pterm.Println()
return nil
}
if parsedURL.Scheme == "" || parsedURL.Host == "" {
pterm.Println()
pterm.Error.Printf("URL is not valid. "+
"Please provide a valid URL to redirect to. %s cannot be parsed\n\n", redirectURL)
pterm.Println()
return nil
}
redirectHost = parsedURL.Hostname()
redirectPort = parsedURL.Port()
redirectScheme = parsedURL.Scheme
redirectBasePath = parsedURL.Path
parsedURL, e := url.Parse(redirectURL)
if e != nil {
pterm.Println()
pterm.Error.Printf("URL is not valid. "+
"Please provide a valid URL to redirect to. %s cannot be parsed\n\n", redirectURL)
pterm.Println()
return nil
}
if parsedURL.Scheme == "" || parsedURL.Host == "" {
pterm.Println()
pterm.Error.Printf("URL is not valid. "+
"Please provide a valid URL to redirect to. %s cannot be parsed\n\n", redirectURL)
pterm.Println()
return nil
}
redirectHost = parsedURL.Hostname()
redirectPort = parsedURL.Port()
redirectScheme = parsedURL.Scheme
redirectBasePath = parsedURL.Path

config := shared.WiretapConfiguration{
Contract: spec,
RedirectURL: redirectURL,
RedirectHost: redirectHost,
RedirectBasePath: redirectBasePath,
RedirectPort: redirectPort,
RedirectProtocol: redirectScheme,
Port: port,
MonitorPort: monitorPort,
GlobalAPIDelay: globalAPIDelay,
WebSocketPort: wsPort,
StaticDir: staticDir,
PathConfigurations: pathConfigurations,
FS: FS,
config.Contract = spec
config.RedirectURL = redirectURL
config.RedirectHost = redirectHost
config.RedirectBasePath = redirectBasePath
config.RedirectPort = redirectPort
config.RedirectProtocol = redirectScheme
if config.Port == "" {
config.Port = port
}
if config.MonitorPort == "" {
config.MonitorPort = monitorPort
}
if config.WebSocketPort == "" {
config.WebSocketPort = wsPort
}
if config.GlobalAPIDelay == 0 {
config.GlobalAPIDelay = globalAPIDelay
}
if config.StaticDir == "" {
config.StaticDir = staticDir
}
config.FS = FS

if len(pathConfigurations) > 0 {
if len(config.PathConfigurations) > 0 {
printLoadedPathConfigurations(config.PathConfigurations)
config.CompilePaths()
}

if config.Headers != nil && len(config.Headers.DropHeaders) > 0 {

pterm.Info.Printf("Dropping the following %d %s:\n", len(config.Headers.DropHeaders),
shared.Pluralize(len(config.Headers.DropHeaders), "header", "headers"))
for _, header := range config.Headers.DropHeaders {
pterm.Printf("🗑️ %s\n", pterm.LightMagenta(header))
}
pterm.Println()
}

// ready to boot, let's go!
_, pErr := runWiretapService(&config)

Expand Down
51 changes: 47 additions & 4 deletions config/paths_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/pb33f/wiretap/shared"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
"strings"
"testing"
)
Expand Down Expand Up @@ -36,7 +37,6 @@ paths:
PathConfigurations: pc,
}

// compile paths
wcConfig.CompilePaths()

res := FindPaths("/pb33f/test/123", wcConfig)
Expand Down Expand Up @@ -74,7 +74,6 @@ paths:
PathConfigurations: pc,
}

// compile paths
wcConfig.CompilePaths()

path := RewritePath("/pb33f/test/123/slap/a/chap", wcConfig)
Expand Down Expand Up @@ -106,7 +105,6 @@ paths:
PathConfigurations: pc,
}

// compile paths
wcConfig.CompilePaths()

path := RewritePath("/pb33f/cakes/test/123/smelly/jelly", wcConfig)
Expand Down Expand Up @@ -138,10 +136,55 @@ paths:
PathConfigurations: pc,
}

// compile paths
wcConfig.CompilePaths()

path := RewritePath("/pb33f/cakes/test/lemons/321/smelly/jelly", wcConfig)
assert.Equal(t, "https://localhost:9093/slippy/cakes/whip/321/lemons/smelly/jelly", path)

}

func TestRewritePath_Secure_With_Variables_CaseSensitive(t *testing.T) {

config := `
paths:
/en-US/burgerd/__raw/*:
target: localhost:80
pathRewrite:
'^/en-US/burgerd/__raw/(\w+)/nobody/': '$1/-/'
/en-US/burgerd/services/*:
target: locahost:80
pathRewrite:
'^/en-US/burgerd/services': '/services'`

var c shared.WiretapConfiguration
_ = yaml.Unmarshal([]byte(config), &c)

c.CompilePaths()

path := RewritePath("/en-US/burgerd/__raw/noKetchupPlease/nobody/", &c)
assert.Equal(t, "http://localhost:80/noKetchupPlease/-/", path)

}

func TestRewritePath_Secure_With_Variables_CaseSensitive_AndQuery(t *testing.T) {

config := `
paths:
/en-US/burgerd/__raw/*:
target: localhost:80
pathRewrite:
'^/en-US/burgerd/__raw/(\w+)/nobody/': '$1/-/'
/en-US/burgerd/services/*:
target: locahost:80
pathRewrite:
'^/en-US/burgerd/services': '/services'`

var c shared.WiretapConfiguration
_ = yaml.Unmarshal([]byte(config), &c)

c.CompilePaths()

path := RewritePath("/en-US/burgerd/__raw/noKetchupPlease/nobody/yummy/yum?onions=true", &c)
assert.Equal(t, "http://localhost:80/noKetchupPlease/-/yummy/yum?onions=true", path)

}
5 changes: 4 additions & 1 deletion daemon/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ func (ws *WiretapService) callAPI(req *http.Request) (*http.Response, error) {
replaced := config.RewritePath(req.URL.Path, wiretapConfig)
if replaced != "" {
newUrl, _ := url.Parse(replaced)
pterm.Info.Printf("[wiretap] Re-writing path '%s' to '%s'\n", req.URL.Path, replaced)
if req.URL.RawQuery != "" {
newUrl.RawQuery = req.URL.RawQuery
}
pterm.Info.Printf("[wiretap] Re-writing path '%s' to '%s'\n", req.URL.String(), newUrl.String())
req.URL = newUrl
}

Expand Down
Loading

0 comments on commit 2c5fea8

Please sign in to comment.