Skip to content

Commit

Permalink
fix: error handling for invalid gRPC endpoints (#757)
Browse files Browse the repository at this point in the history
  • Loading branch information
technicallyty committed Sep 18, 2024
1 parent 4afa6c6 commit 04c8ba5
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 5 deletions.
35 changes: 34 additions & 1 deletion cmd/connect/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import (
"errors"
"fmt"
"net/http"

//nolint: gosec
_ "net/http/pprof"
"os"
"os/signal"
"path/filepath"
"regexp"
"strings"
"syscall"
"time"

Expand Down Expand Up @@ -322,6 +323,14 @@ func runOracle() error {
}
}

// check that the marketmap endpoint they provided is correct.
if marketMapProvider == marketmap.Name {
mmEndpoint := cfg.Providers[marketMapProvider].API.Endpoints[0].URL
if err := isValidGRPCEndpoint(mmEndpoint); err != nil {
return err
}
}

var marketCfg mmtypes.MarketMap
if marketCfgPath != "" {
marketCfg, err = mmtypes.ReadMarketMapFromFile(marketCfgPath)
Expand Down Expand Up @@ -467,3 +476,27 @@ func overwriteMarketMapEndpoint(cfg config.OracleConfig, overwrite string) (conf

return cfg, fmt.Errorf("no market-map provider found in config")
}

// isValidGRPCEndpoint checks that the string s is a valid gRPC endpoint. (doesn't start with http, ends with a port).
func isValidGRPCEndpoint(s string) error {
if strings.HasPrefix(s, "http") {
return fmt.Errorf("expected gRPC endpoint but got HTTP endpoint %q. Please provide a gRPC endpoint (e.g. some.host:9090)", s)
}
if !hasPort(s) {
// they might do something like foo.bar:hello
// so lets just take the bit before foo.bar for the example in the error.
example := strings.Split(s, ":")[0]
return fmt.Errorf("invalid gRPC endpoint %q. Must specify port (e.g. %s:9090)", s, example)
}
return nil
}

// hasPort reports whether s contains `:` followed by numbers.
func hasPort(s string) bool {
// matches anything that has `:` and some numbers after.
pattern := `:[0-9]+$`

regex := regexp.MustCompile(pattern)

return regex.MatchString(s)
}
62 changes: 62 additions & 0 deletions cmd/connect/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package main

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestCheckMarketMapEndpoint(t *testing.T) {
tests := []struct {
name string
endpoint string
wantErr bool
errMsg string
}{
{
name: "Valid gRPC endpoint",
endpoint: "example.com:8080",
wantErr: false,
},
{
name: "Valid IP address endpoint",
endpoint: "192.168.1.1:9090",
wantErr: false,
},
{
name: "HTTP endpoint",
endpoint: "http://example.com:8080",
wantErr: true,
errMsg: `expected gRPC endpoint but got HTTP endpoint "http://example.com:8080". Please provide a gRPC endpoint (e.g. some.host:9090)`,
},
{
name: "HTTPS endpoint",
endpoint: "https://example.com:8080",
wantErr: true,
errMsg: `expected gRPC endpoint but got HTTP endpoint "https://example.com:8080". Please provide a gRPC endpoint (e.g. some.host:9090)`,
},
{
name: "Missing port",
endpoint: "example.com",
wantErr: true,
errMsg: `invalid gRPC endpoint "example.com". Must specify port (e.g. example.com:9090)`,
},
{
name: "Invalid port format",
endpoint: "example.com:port",
wantErr: true,
errMsg: `invalid gRPC endpoint "example.com:port". Must specify port (e.g. example.com:9090)`,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := isValidGRPCEndpoint(tt.endpoint)
if tt.wantErr {
require.EqualError(t, err, tt.errMsg)
} else {
require.NoError(t, err)
}
})
}
}
8 changes: 4 additions & 4 deletions providers/factories/oracle/marketmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,19 @@ func MarketMapProviderFactory(
return nil, err
}

switch name := cfg.Name; {
case name == dydx.Name:
switch cfg.Name {
case dydx.Name:
apiDataHandler, err = dydx.NewAPIHandler(logger, cfg.API)
ids = []types.Chain{{ChainID: dydx.ChainID}}
case name == dydx.SwitchOverAPIHandlerName:
case dydx.SwitchOverAPIHandlerName:
marketMapFetcher, err = dydx.NewDefaultSwitchOverMarketMapFetcher(
logger,
cfg.API,
requestHandler,
apiMetrics,
)
ids = []types.Chain{{ChainID: dydx.ChainID}}
case name == dydx.ResearchAPIHandlerName || name == dydx.ResearchCMCAPIHandlerName:
case dydx.ResearchAPIHandlerName, dydx.ResearchCMCAPIHandlerName:
marketMapFetcher, err = dydx.DefaultDYDXResearchMarketMapFetcher(
requestHandler,
apiMetrics,
Expand Down

0 comments on commit 04c8ba5

Please sign in to comment.