diff --git a/cmd/connect/main.go b/cmd/connect/main.go index 8735651ae..6bb5fb432 100644 --- a/cmd/connect/main.go +++ b/cmd/connect/main.go @@ -5,12 +5,13 @@ import ( "errors" "fmt" "net/http" - //nolint: gosec _ "net/http/pprof" "os" "os/signal" "path/filepath" + "regexp" + "strings" "syscall" "time" @@ -322,6 +323,14 @@ func runOracle() error { } } + // check that the marketmap endpoint they provided is correct. + if marketMapProvider == marketmap.Name { + mmEndpoint := cfg.Providers[marketMapProvider].API.Endpoints[0].URL + if err := isValidGRPCEndpoint(mmEndpoint); err != nil { + return err + } + } + var marketCfg mmtypes.MarketMap if marketCfgPath != "" { marketCfg, err = mmtypes.ReadMarketMapFromFile(marketCfgPath) @@ -467,3 +476,27 @@ func overwriteMarketMapEndpoint(cfg config.OracleConfig, overwrite string) (conf return cfg, fmt.Errorf("no market-map provider found in config") } + +// isValidGRPCEndpoint checks that the string s is a valid gRPC endpoint. (doesn't start with http, ends with a port). +func isValidGRPCEndpoint(s string) error { + if strings.HasPrefix(s, "http") { + return fmt.Errorf("expected gRPC endpoint but got HTTP endpoint %q. Please provide a gRPC endpoint (e.g. some.host:9090)", s) + } + if !hasPort(s) { + // they might do something like foo.bar:hello + // so lets just take the bit before foo.bar for the example in the error. + example := strings.Split(s, ":")[0] + return fmt.Errorf("invalid gRPC endpoint %q. Must specify port (e.g. %s:9090)", s, example) + } + return nil +} + +// hasPort reports whether s contains `:` followed by numbers. +func hasPort(s string) bool { + // matches anything that has `:` and some numbers after. + pattern := `:[0-9]+$` + + regex := regexp.MustCompile(pattern) + + return regex.MatchString(s) +} diff --git a/cmd/connect/main_test.go b/cmd/connect/main_test.go new file mode 100644 index 000000000..d55f9ea1b --- /dev/null +++ b/cmd/connect/main_test.go @@ -0,0 +1,62 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCheckMarketMapEndpoint(t *testing.T) { + tests := []struct { + name string + endpoint string + wantErr bool + errMsg string + }{ + { + name: "Valid gRPC endpoint", + endpoint: "example.com:8080", + wantErr: false, + }, + { + name: "Valid IP address endpoint", + endpoint: "192.168.1.1:9090", + wantErr: false, + }, + { + name: "HTTP endpoint", + endpoint: "http://example.com:8080", + wantErr: true, + errMsg: `expected gRPC endpoint but got HTTP endpoint "http://example.com:8080". Please provide a gRPC endpoint (e.g. some.host:9090)`, + }, + { + name: "HTTPS endpoint", + endpoint: "https://example.com:8080", + wantErr: true, + errMsg: `expected gRPC endpoint but got HTTP endpoint "https://example.com:8080". Please provide a gRPC endpoint (e.g. some.host:9090)`, + }, + { + name: "Missing port", + endpoint: "example.com", + wantErr: true, + errMsg: `invalid gRPC endpoint "example.com". Must specify port (e.g. example.com:9090)`, + }, + { + name: "Invalid port format", + endpoint: "example.com:port", + wantErr: true, + errMsg: `invalid gRPC endpoint "example.com:port". Must specify port (e.g. example.com:9090)`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := isValidGRPCEndpoint(tt.endpoint) + if tt.wantErr { + require.EqualError(t, err, tt.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/providers/factories/oracle/marketmap.go b/providers/factories/oracle/marketmap.go index f673be748..32f14cdb4 100644 --- a/providers/factories/oracle/marketmap.go +++ b/providers/factories/oracle/marketmap.go @@ -49,11 +49,11 @@ func MarketMapProviderFactory( return nil, err } - switch name := cfg.Name; { - case name == dydx.Name: + switch cfg.Name { + case dydx.Name: apiDataHandler, err = dydx.NewAPIHandler(logger, cfg.API) ids = []types.Chain{{ChainID: dydx.ChainID}} - case name == dydx.SwitchOverAPIHandlerName: + case dydx.SwitchOverAPIHandlerName: marketMapFetcher, err = dydx.NewDefaultSwitchOverMarketMapFetcher( logger, cfg.API, @@ -61,7 +61,7 @@ func MarketMapProviderFactory( apiMetrics, ) ids = []types.Chain{{ChainID: dydx.ChainID}} - case name == dydx.ResearchAPIHandlerName || name == dydx.ResearchCMCAPIHandlerName: + case dydx.ResearchAPIHandlerName, dydx.ResearchCMCAPIHandlerName: marketMapFetcher, err = dydx.DefaultDYDXResearchMarketMapFetcher( requestHandler, apiMetrics,