diff --git a/pkg/download/buffer.go b/pkg/download/buffer.go index 7633d3b..d6cd492 100644 --- a/pkg/download/buffer.go +++ b/pkg/download/buffer.go @@ -73,7 +73,6 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e defer br.done() firstChunkResp, err := m.DoRequest(ctx, 0, m.chunkSize()-1, url) if err != nil { - br.err = err firstReqResultCh <- firstReqResult{err: err} return err } @@ -145,7 +144,6 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e defer br.done() resp, err := m.DoRequest(ctx, start, end, trueURL) if err != nil { - br.err = err return err } defer resp.Body.Close() @@ -157,6 +155,10 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e return newChanMultiReader(readersCh), fileSize, nil } +func (m *BufferMode) Wait() error { + return m.sem.Wait() +} + func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL string) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, "GET", trueURL, nil) if err != nil { diff --git a/pkg/download/buffer_unit_test.go b/pkg/download/buffer_unit_test.go index 6ff4c78..3e778e3 100644 --- a/pkg/download/buffer_unit_test.go +++ b/pkg/download/buffer_unit_test.go @@ -2,15 +2,18 @@ package download import ( "context" + "fmt" "io" "math/rand" "net/http" "net/http/httptest" "net/url" + "strings" "testing" "testing/fstest" "github.com/dustin/go-humanize" + "github.com/jarcoal/httpmock" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -42,8 +45,6 @@ func newTestServer(t *testing.T, content []byte) *httptest.Server { return server } -// TODO: Implement the test -// func TestGetFileSizeFromContentRange(t *testing.T) {} func TestFileToBufferChunkCountExceedsMaxChunks(t *testing.T) { contentSize := int64(humanize.KiByte) content := generateTestContent(contentSize) @@ -116,9 +117,51 @@ func TestFileToBufferChunkCountExceedsMaxChunks(t *testing.T) { require.NoError(t, err) data, err := io.ReadAll(download) assert.NoError(t, err) + err = bufferMode.Wait() + assert.NoError(t, err) assert.Equal(t, contentSize, size) assert.Equal(t, len(content), len(data)) assert.Equal(t, content, data) }) } } + +func TestWaitReturnsErrorWhenRequestFails(t *testing.T) { + mockTransport := httpmock.NewMockTransport() + opts := Options{ + Client: client.Options{Transport: mockTransport}, + ChunkSize: 2, + } + expectedErr := fmt.Errorf("Expected error in chunk 3") + mockTransport.RegisterResponder("GET", "http://test.example/hello.txt", + func(req *http.Request) (*http.Response, error) { + rangeHeader := req.Header.Get("Range") + var body string + switch rangeHeader { + case "bytes=0-1": + body = "he" + case "bytes=2-3": + body = "ll" + case "bytes=4-5": + body = "o " + case "bytes=6-7": + return nil, expectedErr + default: + return nil, fmt.Errorf("should't see this error") + } + resp := httpmock.NewStringResponse(http.StatusPartialContent, body) + resp.Request = req + resp.Header.Add("Content-Range", strings.Replace(rangeHeader, "=", " ", 1)+"/8") + resp.ContentLength = 2 + resp.Header.Add("Content-Length", "2") + return resp, nil + }) + bufferMode := GetBufferMode(opts) + download, _, err := bufferMode.Fetch(context.Background(), "http://test.example/hello.txt") + // No error here, because the first chunk was fetched successfully + require.NoError(t, err) + // the read might or might not return an error + _, _ = io.ReadAll(download) + err = bufferMode.Wait() + assert.ErrorIs(t, err, expectedErr) +} diff --git a/pkg/download/buffered_reader.go b/pkg/download/buffered_reader.go index 6994ea0..d57c080 100644 --- a/pkg/download/buffered_reader.go +++ b/pkg/download/buffered_reader.go @@ -15,7 +15,6 @@ type bufferedReader struct { // ready channel is closed when we're ready to read ready chan struct{} buf *bytes.Buffer - err error pool *bufferPool } @@ -36,9 +35,6 @@ func newBufferedReader(pool *bufferPool) *bufferedReader { // pool. func (b *bufferedReader) Read(buf []byte) (int, error) { <-b.ready - if b.err != nil { - return 0, b.err - } n, err := b.buf.Read(buf) // If we've read all the data, if b.buf.Len() == 0 && b.buf != emptyBuffer { @@ -59,17 +55,14 @@ func (b *bufferedReader) downloadBody(resp *http.Response) error { expectedBytes := resp.ContentLength if expectedBytes > int64(b.buf.Cap()) { - b.err = fmt.Errorf("Tried to download 0x%x bytes to a 0x%x-sized buffer", expectedBytes, b.buf.Cap()) - return b.err + return fmt.Errorf("Tried to download 0x%x bytes to a 0x%x-sized buffer", expectedBytes, b.buf.Cap()) } n, err := b.buf.ReadFrom(resp.Body) if err != nil && err != io.EOF { - b.err = fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err) - return b.err + return fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err) } if n != expectedBytes { - b.err = fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String()) - return b.err + return fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String()) } return nil } diff --git a/pkg/download/consistent_hashing.go b/pkg/download/consistent_hashing.go index d353361..0f9cf54 100644 --- a/pkg/download/consistent_hashing.go +++ b/pkg/download/consistent_hashing.go @@ -116,9 +116,10 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io defer br.done() firstChunkResp, err := m.DoRequest(ctx, 0, m.chunkSize()-1, urlString) if err != nil { - br.err = err firstReqResultCh <- firstReqResult{err: err} - return err + // The error will be handled by the firstReqResultCh consumer, + // and may be recoverable; so we return nil to the errGroup + return nil } defer firstChunkResp.Body.Close() @@ -217,7 +218,6 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, urlString) } if err != nil { - br.err = err return err } } @@ -232,6 +232,10 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io return newChanMultiReader(readersCh), fileSize, nil } +func (m *ConsistentHashingMode) Wait() error { + return m.sem.Wait() +} + func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, urlString string) (*http.Response, error) { chContext := context.WithValue(ctx, config.ConsistentHashingStrategyKey, true) req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil) diff --git a/pkg/download/consistent_hashing_test.go b/pkg/download/consistent_hashing_test.go index 2751971..2ac80a2 100644 --- a/pkg/download/consistent_hashing_test.go +++ b/pkg/download/consistent_hashing_test.go @@ -270,6 +270,8 @@ func TestConsistentHashing(t *testing.T) { require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) + err = strategy.Wait() + require.NoError(t, err) assert.Equal(t, tc.expectedOutput, string(bytes)) }) @@ -321,6 +323,8 @@ func TestConsistentHashingPathBased(t *testing.T) { require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) + err = strategy.Wait() + require.NoError(t, err) assert.Equal(t, tc.expectedOutput, string(bytes)) }) @@ -352,6 +356,8 @@ func TestConsistentHashRetries(t *testing.T) { require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) + err = strategy.Wait() + require.NoError(t, err) // with a functional hostnames[0], we'd see 0344760706165500, but instead we // should fall back to this. Note that each 0 value has been changed to a @@ -387,6 +393,8 @@ func TestConsistentHashRetriesMissingHostname(t *testing.T) { require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) + err = strategy.Wait() + require.NoError(t, err) // with a functional hostnames[0], we'd see 0344760706165500, but instead we // should fall back to this. Note that each 0 value has been changed to a @@ -421,6 +429,8 @@ func TestConsistentHashRetriesTwoHosts(t *testing.T) { require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) + err = strategy.Wait() + require.NoError(t, err) assert.Equal(t, "0000000000000000", string(bytes)) } @@ -448,6 +458,8 @@ func TestConsistentHashingHasFallback(t *testing.T) { require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) + err = strategy.Wait() + require.NoError(t, err) assert.Equal(t, "0000000000000000", string(bytes)) } @@ -476,6 +488,10 @@ func (s *testStrategy) Fetch(ctx context.Context, url string) (io.Reader, int64, return io.NopCloser(strings.NewReader("00")), -1, nil } +func (s *testStrategy) Wait() error { + return nil +} + func (s *testStrategy) DoRequest(ctx context.Context, start, end int64, url string) (*http.Response, error) { s.mut.Lock() s.doRequestCalledCount++ diff --git a/pkg/download/strategy.go b/pkg/download/strategy.go index a430dd2..e716363 100644 --- a/pkg/download/strategy.go +++ b/pkg/download/strategy.go @@ -15,6 +15,9 @@ type Strategy interface { // This is the primary method that should be called to initiate a download of a file. Fetch(ctx context.Context, url string) (result io.Reader, fileSize int64, err error) + // Wait waits until all requests have completed, and returns the first error encountered, if any. + Wait() error + // DoRequest sends an HTTP GET request with a specified range of bytes to the given URL using the provided context. // It returns the HTTP response and any error encountered during the request. It is intended that Fetch calls DoRequest // and that each chunk is downloaded with a call to DoRequest. DoRequest is exposed so that consistent-hashing can diff --git a/pkg/pget.go b/pkg/pget.go index 3267174..84b0e03 100644 --- a/pkg/pget.go +++ b/pkg/pget.go @@ -53,6 +53,11 @@ func (g *Getter) DownloadFile(ctx context.Context, url string, dest string) (int if err != nil { return fileSize, 0, fmt.Errorf("error writing file: %w", err) } + err = g.Downloader.Wait() + if err != nil { + return fileSize, 0, err + } + // writeElapsed := time.Since(writeStartTime) totalElapsed := time.Since(downloadStartTime)