diff --git a/cmd/multifile/manifest.go b/cmd/multifile/manifest.go index b3b1eb3..c00ed8b 100644 --- a/cmd/multifile/manifest.go +++ b/cmd/multifile/manifest.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/fs" + netUrl "net/url" "os" "strings" @@ -14,6 +15,7 @@ import ( pget "github.com/replicate/pget/pkg" "github.com/replicate/pget/pkg/cli" "github.com/replicate/pget/pkg/config" + "github.com/replicate/pget/pkg/logging" ) // A manifest is a file consisting of pairs of URLs and paths: @@ -27,6 +29,8 @@ import ( // When we parse a manifest, we group by URL base (ie scheme://hostname) so that // all URLs that may share a connection are grouped. +var errDupeURLDestCombo = errors.New("duplicate destination with different URLs") + func manifestFile(manifestPath string) (*os.File, error) { if manifestPath == "-" { return os.Stdin, nil @@ -41,7 +45,7 @@ func manifestFile(manifestPath string) (*os.File, error) { return file, err } -func parseLine(line string) (urlString, dest string, err error) { +func parseLine(line string) (url, dest string, err error) { fields := strings.Fields(line) if len(fields) != 2 { return "", "", fmt.Errorf("error parsing manifest invalid line format `%s`", line) @@ -49,18 +53,19 @@ func parseLine(line string) (urlString, dest string, err error) { return fields[0], fields[1], nil } -func checkSeenDestinations(destinations map[string]string, dest string, urlString string) error { +func checkSeenDestinations(destinations map[string]string, dest string, url string) error { if seenURL, ok := destinations[dest]; ok { - if seenURL != urlString { - return fmt.Errorf("duplicate destination %s with different urls: %s and %s", dest, seenURL, urlString) + if seenURL != url { + return fmt.Errorf("duplicate destination %s with different urls: %s and %s", dest, seenURL, url) } else { - return fmt.Errorf("duplicate entry: %s %s", urlString, dest) + return errDupeURLDestCombo } } return nil } func parseManifest(file io.Reader) (pget.Manifest, error) { + logger := logging.GetLogger() seenDestinations := make(map[string]string) manifest := make(pget.Manifest, 0) @@ -83,6 +88,13 @@ func parseManifest(file io.Reader) (pget.Manifest, error) { if consumer != config.ConsumerNull { err = checkSeenDestinations(seenDestinations, dest, urlString) if err != nil { + if errors.Is(err, errDupeURLDestCombo) { + logger.Warn(). + Str("url", urlString). + Str("destination", dest). + Msg("Parse Manifest: Skip Duplicate URL/Destination") + continue + } return nil, err } seenDestinations[dest] = urlString @@ -92,12 +104,17 @@ func parseManifest(file io.Reader) (pget.Manifest, error) { return nil, err } } + if valid, err := validURL(urlString); !valid { + return nil, fmt.Errorf("error parsing manifest invalid URL: %s: %w", urlString, err) - manifest, err = manifest.AddEntry(urlString, dest) - if err != nil { - return nil, fmt.Errorf("error adding url: %w", err) } + manifest = manifest.AddEntry(urlString, dest) } return manifest, nil } + +func validURL(s string) (bool, error) { + _, err := netUrl.Parse(s) + return err == nil, err +} diff --git a/cmd/multifile/manifest_test.go b/cmd/multifile/manifest_test.go index 632e79e..967ba85 100644 --- a/cmd/multifile/manifest_test.go +++ b/cmd/multifile/manifest_test.go @@ -7,8 +7,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - pget "github.com/replicate/pget/pkg" ) // validManifest is a valid manifest file with additional empty lines @@ -43,45 +41,22 @@ func TestParseLine(t *testing.T) { assert.Error(t, err) } -func TestCheckSeenDests(t *testing.T) { - seenDests := map[string]string{ +func TestCheckSeenDestinations(t *testing.T) { + seenDestinations := map[string]string{ "/tmp/file1.txt": "https://example.com/file1.txt", } // a different destination is fine - err := checkSeenDestinations(seenDests, "/tmp/file2.txt", "https://example.com/file2.txt") - assert.NoError(t, err) + err := checkSeenDestinations(seenDestinations, "/tmp/file2.txt", "https://example.com/file2.txt") + require.NoError(t, err) // the same destination with a different URL is not fine - err = checkSeenDestinations(seenDests, "/tmp/file1.txt", "https://example.com/file2.txt") + err = checkSeenDestinations(seenDestinations, "/tmp/file1.txt", "https://example.com/file2.txt") assert.Error(t, err) - // the same destination with the same URL is also not fine - err = checkSeenDestinations(seenDests, "/tmp/file1.txt", "https://example.com/file1.txt") - assert.Error(t, err) -} - -func TestAddEntry(t *testing.T) { - entries := make(pget.Manifest, 0) - - entries, err := entries.AddEntry("https://example.com/file1.txt", "/tmp/file1.txt") - require.NoError(t, err) - assert.Len(t, entries, 1) - assert.Equal(t, "https://example.com/file1.txt", entries[0].URL()) - assert.Equal(t, "/tmp/file1.txt", entries[0].Dest) - - entries, err = entries.AddEntry("https://example.com/file2.txt", "/tmp/file2.txt") - require.NoError(t, err) - assert.Len(t, entries, 2) - assert.Equal(t, "https://example.com/file2.txt", entries[1].URL()) - assert.Equal(t, "/tmp/file2.txt", entries[1].Dest) - - entries, err = entries.AddEntry("https://example2.com/file3.txt", "/tmp/file3.txt") - require.NoError(t, err) - assert.Len(t, entries, 3) - assert.Equal(t, "https://example2.com/file3.txt", entries[2].URL()) - assert.Equal(t, "/tmp/file3.txt", entries[2].Dest) - + // the same destination with the same URL is fine, we raise a specific error to detect and skip + err = checkSeenDestinations(seenDestinations, "/tmp/file1.txt", "https://example.com/file1.txt") + assert.ErrorIs(t, err, errDupeURLDestCombo) } func TestParseManifest(t *testing.T) { diff --git a/cmd/multifile/multifile.go b/cmd/multifile/multifile.go index 8617aa4..0cb461e 100644 --- a/cmd/multifile/multifile.go +++ b/cmd/multifile/multifile.go @@ -160,7 +160,7 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error { throughput := float64(totalFileSize) / elapsedTime.Seconds() logger := logging.GetLogger() logger.Info(). - Int("file_count", numEntries(manifest)). + Int("file_count", len(manifest)). Str("total_bytes_downloaded", humanize.Bytes(uint64(totalFileSize))). Str("throughput", fmt.Sprintf("%s/s", humanize.Bytes(uint64(throughput)))). Str("elapsed_time", fmt.Sprintf("%.3fs", elapsedTime.Seconds())). @@ -168,8 +168,3 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error { return nil } - -func numEntries(manifest pget.Manifest) int { - return len(manifest) - -} diff --git a/pkg/pget.go b/pkg/pget.go index 0743908..3267174 100644 --- a/pkg/pget.go +++ b/pkg/pget.go @@ -3,7 +3,6 @@ package pget import ( "context" "fmt" - netUrl "net/url" "sync/atomic" "time" @@ -26,23 +25,15 @@ type Options struct { } type ManifestEntry struct { - parsedURL *netUrl.URL - Dest string -} - -func (m ManifestEntry) URL() string { - return m.parsedURL.String() + URL string + Dest string } // A Manifest is a slice of ManifestEntry, with a helper method to add entries type Manifest []ManifestEntry -func (m Manifest) AddEntry(url string, destination string) (Manifest, error) { - parsed, err := netUrl.Parse(url) - if err != nil { - return nil, fmt.Errorf("error parsing url %s: %w", url, err) - } - return append(m, ManifestEntry{parsedURL: parsed, Dest: destination}), nil +func (m Manifest) AddEntry(url string, destination string) Manifest { + return append(m, ManifestEntry{URL: url, Dest: destination}) } func (g *Getter) DownloadFile(ctx context.Context, url string, dest string) (int64, time.Duration, error) { @@ -113,7 +104,7 @@ func (g *Getter) downloadFilesFromManifest(ctx context.Context, eg *errgroup.Gro for _, entry := range entries { // Avoid the `entry` loop variable being captured by the // goroutine by creating new variables - url, dest := entry.URL(), entry.Dest + url, dest := entry.URL, entry.Dest logger.Debug().Str("url", url).Str("dest", dest).Msg("Queueing Download") eg.Go(func() error { diff --git a/pkg/pget_test.go b/pkg/pget_test.go index 027ab35..58334f4 100644 --- a/pkg/pget_test.go +++ b/pkg/pget_test.go @@ -161,7 +161,7 @@ func testDownloadMultipleFiles(opts download.Options, sizes []int64, t *testing. manifest := make(pget.Manifest, 0) for _, srcFilename := range srcFilenames { - manifest, err = manifest.AddEntry(ts.URL+"/"+srcFilename, filepath.Join(outputDir, srcFilename)) + manifest = manifest.AddEntry(ts.URL+"/"+srcFilename, filepath.Join(outputDir, srcFilename)) require.NoError(t, err) } @@ -196,3 +196,18 @@ func TestDownloadFive10MFiles(t *testing.T) { 10 * humanize.MiByte, }, t) } + +func TestManifest_AddEntry(t *testing.T) { + entries := make(pget.Manifest, 0) + + entries = entries.AddEntry("https://example.com/file1.txt", "/tmp/file1.txt") + assert.Len(t, entries, 1) + entries = entries.AddEntry("https://example.org/file2.txt", "/tmp/file2.txt") + assert.Len(t, entries, 2) + + assert.Equal(t, "https://example.com/file1.txt", entries[0].URL) + assert.Equal(t, "/tmp/file1.txt", entries[0].Dest) + assert.Equal(t, "https://example.org/file2.txt", entries[1].URL) + assert.Equal(t, "/tmp/file2.txt", entries[1].Dest) + +}