Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup #97

Merged
merged 3 commits into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions cmd/multifile/multifile.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,18 @@ func multifileExecute(ctx context.Context, manifest manifest) error {
return err
}

// Get the resolution overrides
resolveOverrides, err := config.ResolveOverridesToMap(viper.GetStringSlice(config.OptResolve))
if err != nil {
return fmt.Errorf("error parsing resolve overrides: %w", err)
}

clientOpts := client.Options{
MaxConnPerHost: viper.GetInt(config.OptMaxConnPerHost),
ForceHTTP2: viper.GetBool(config.OptForceHTTP2),
MaxRetries: viper.GetInt(config.OptRetries),
ConnectTimeout: viper.GetDuration(config.OptConnTimeout),
MaxConnPerHost: viper.GetInt(config.OptMaxConnPerHost),
ForceHTTP2: viper.GetBool(config.OptForceHTTP2),
MaxRetries: viper.GetInt(config.OptRetries),
ConnectTimeout: viper.GetDuration(config.OptConnTimeout),
ResolveOverrides: resolveOverrides,
}
downloadOpts := download.Options{
MaxConcurrency: viper.GetInt(config.OptConcurrency),
Expand Down
13 changes: 9 additions & 4 deletions cmd/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,16 @@ func rootExecute(ctx context.Context, urlString, dest string) error {
return fmt.Errorf("error parsing minimum chunk size: %w", err)
}

resolveOverrides, err := config.ResolveOverridesToMap(viper.GetStringSlice(config.OptResolve))
if err != nil {
return fmt.Errorf("error parsing resolve overrides: %w", err)
}
clientOpts := client.Options{
ForceHTTP2: viper.GetBool(config.OptForceHTTP2),
MaxRetries: viper.GetInt(config.OptRetries),
ConnectTimeout: viper.GetDuration(config.OptConnTimeout),
MaxConnPerHost: viper.GetInt(config.OptMaxConnPerHost),
ForceHTTP2: viper.GetBool(config.OptForceHTTP2),
MaxRetries: viper.GetInt(config.OptRetries),
ConnectTimeout: viper.GetDuration(config.OptConnTimeout),
MaxConnPerHost: viper.GetInt(config.OptMaxConnPerHost),
ResolveOverrides: resolveOverrides,
}

downloadOpts := download.Options{
Expand Down
44 changes: 24 additions & 20 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (

"github.com/hashicorp/go-retryablehttp"

"github.com/replicate/pget/pkg/config"
"github.com/replicate/pget/pkg/logging"
"github.com/replicate/pget/pkg/version"
)
Expand All @@ -35,23 +34,29 @@ func (c *HTTPClient) Do(req *http.Request) (*http.Response, error) {
}

type Options struct {
ForceHTTP2 bool
MaxConnPerHost int
MaxRetries int
ConnectTimeout time.Duration
ForceHTTP2 bool
MaxConnPerHost int
MaxRetries int
ConnectTimeout time.Duration
ResolveOverrides map[string]string
}

// NewHTTPClient factory function returns a new http.Client with the appropriate settings and can limit number of clients
// per host if the OptMaxConnPerHost option is set.
func NewHTTPClient(opts Options) *HTTPClient {
disableKeepAlives := opts.ForceHTTP2

transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: transportDialContext(&net.Dialer{
dialer := &transportDialer{
DNSOverrideMap: opts.ResolveOverrides,
Dialer: &net.Dialer{
Timeout: opts.ConnectTimeout,
KeepAlive: 30 * time.Second,
}),
},
}

transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialer.DialContext,
ForceAttemptHTTP2: opts.ForceHTTP2,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
Expand Down Expand Up @@ -126,19 +131,18 @@ func checkRedirectFunc(req *http.Request, via []*http.Request) error {
return nil
}

// transportDialContext is a wrapper around net.Dialer that allows for overriding DNS lookups via the values passed to
// `--resolve` argument.
func transportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) {
logger := logging.GetLogger()
type transportDialer struct {
DNSOverrideMap map[string]string
Dialer *net.Dialer
}

// Allow for overriding DNS lookups in the dialer without impacting Host and SSL resolution
return func(ctx context.Context, network, addr string) (net.Conn, error) {
if addrOverride := config.HostToIPResolutionMap[addr]; addrOverride != "" {
logger.Debug().Str("addr", addr).Str("override", addrOverride).Msg("DNS Override")
addr = addrOverride
}
return dialer.DialContext(ctx, network, addr)
func (d *transportDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
logger := logging.GetLogger()
if addrOverride := d.DNSOverrideMap[addr]; addrOverride != "" {
logger.Debug().Str("addr", addr).Str("override", addrOverride).Msg("DNS Override")
addr = addrOverride
}
return d.Dialer.DialContext(ctx, network, addr)
}

func GetSchemeHostKey(urlString string) (string, error) {
Expand Down
38 changes: 20 additions & 18 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,12 @@ type DeprecatedFlag struct {
Msg string
}

// HostToIPResolutionMap is a map of hostnames to IP addresses
// TODO: Eliminate this global variable
var HostToIPResolutionMap = make(map[string]string)

func PersistentStartupProcessFlags() error {
if viper.GetBool(OptVerbose) {
viper.Set(OptLoggingLevel, "debug")
}
setLogLevel(viper.GetString(OptLoggingLevel))
if err := convertResolveHostsToMap(); err != nil {
return err
}
return nil

}

func HideFlags(cmd *cobra.Command, flags ...string) error {
Expand Down Expand Up @@ -109,34 +101,44 @@ func setLogLevel(logLevel string) {
}
}

func convertResolveHostsToMap() error {
func ResolveOverridesToMap(resolveOverrides []string) (map[string]string, error) {
logger := logging.GetLogger()
for _, resolveHost := range viper.GetStringSlice(OptResolve) {
resolveOverrideMap := make(map[string]string)

if len(resolveOverrides) == 0 {
return nil, nil
}

for _, resolveHost := range resolveOverrides {
split := strings.SplitN(resolveHost, ":", 3)
if len(split) != 3 {
return fmt.Errorf("invalid resolve host format, expected <hostname>:port:<ip>, got: %s", resolveHost)
return nil, fmt.Errorf("invalid resolve host format, expected <hostname>:port:<ip>, got: %s", resolveHost)
}
host, port, addr := split[0], split[1], split[2]
if net.ParseIP(host) != nil {
return fmt.Errorf("invalid hostname specified, looks like an IP address: %s", host)
return nil, fmt.Errorf("invalid hostname specified, looks like an IP address: %s", host)
}
hostPort := net.JoinHostPort(host, port)
if _, ok := HostToIPResolutionMap[hostPort]; ok {
return fmt.Errorf("duplicate host:port specified: %s", host)
if override, ok := resolveOverrideMap[hostPort]; ok {
if override == net.JoinHostPort(addr, port) {
// duplicate entry, ignore
continue
}
return nil, fmt.Errorf("duplicate host:port specified: %s", host)
}
if net.ParseIP(addr) == nil {
return fmt.Errorf("invalid IP address: %s", addr)
return nil, fmt.Errorf("invalid IP address: %s", addr)
}
HostToIPResolutionMap[hostPort] = net.JoinHostPort(addr, port)
resolveOverrideMap[hostPort] = net.JoinHostPort(addr, port)
}
if logger.GetLevel() == zerolog.DebugLevel {
logger := logging.GetLogger()

for key, elem := range HostToIPResolutionMap {
for key, elem := range resolveOverrideMap {
logger.Debug().Str("host_port", key).Str("resolve_target", elem).Msg("Config")
}
}
return nil
return resolveOverrideMap, nil
}

// GetConsumer returns the consumer specified by the user on the command line
Expand Down
26 changes: 9 additions & 17 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package config

import (
"strings"
"testing"

"github.com/rs/zerolog"
Expand Down Expand Up @@ -29,35 +28,28 @@ func TestSetLogLevel(t *testing.T) {
}
}

func TestConvertResolveHostsToMap(t *testing.T) {
defer func() {
HostToIPResolutionMap = map[string]string{}
viper.Reset()
}()

func TestResolveOverrides(t *testing.T) {
testCases := []struct {
name string
resolve []string
expected map[string]string
err bool
}{
{"empty", []string{}, map[string]string{}, false},
{"empty", []string{}, nil, false},
{"single", []string{"example.com:80:127.0.0.1"}, map[string]string{"example.com:80": "127.0.0.1:80"}, false},
{"multiple", []string{"example.com:80:127.0.0.1", "example.com:443:127.0.0.1"}, map[string]string{"example.com:80": "127.0.0.1:80", "example.com:443": "127.0.0.1:443"}, false},
{"invalid ip", []string{"example.com:80:InvalidIPAddr"}, map[string]string{}, true},
{"duplicate host", []string{"example.com:80:127.0.0.1", "example.com:80:127.0.0.2"}, map[string]string{"example.com:80": "127.0.0.1:80"}, true},
{"invalid format", []string{"example.com:80"}, map[string]string{}, true},
{"invalid hostname format, is IP Addr", []string{"127.0.0.1:443:127.0.0.2"}, map[string]string{}, true},
{"invalid ip", []string{"example.com:80:InvalidIPAddr"}, nil, true},
{"duplicate host different target", []string{"example.com:80:127.0.0.1", "example.com:80:127.0.0.2"}, nil, true},
{"duplicate host same target", []string{"example.com:80:127.0.0.1", "example.com:80:127.0.0.1"}, map[string]string{"example.com:80": "127.0.0.1:80"}, false},
{"invalid format", []string{"example.com:80"}, nil, true},
{"invalid hostname format, is IP Addr", []string{"127.0.0.1:443:127.0.0.2"}, nil, true},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
HostToIPResolutionMap = map[string]string{}
viper.Set(OptResolve, strings.Join(tc.resolve, " "))
err := convertResolveHostsToMap()
resolveOverrides, err := ResolveOverridesToMap(tc.resolve)
assert.Equal(t, tc.err, err != nil)
assert.Equal(t, tc.expected, HostToIPResolutionMap)
viper.Reset()
assert.Equal(t, tc.expected, resolveOverrides)
})
}
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/download/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ func (m *BufferMode) fileToBuffer(ctx context.Context, url string) (*bytes.Buffe
if i == numChunks-1 {
end = fileSize - 1
}
resp, err := m.doRequest(ctx, start, end, trueURL)
if err != nil {
return nil, -1, err
}
errGroup.Go(func() error {
resp, err := m.doRequest(ctx, start, end, trueURL)
if err != nil {
return err
}
return m.downloadChunk(resp, data[start:end+1])
})
}
Expand Down