From c72c38485b2389513b8051912b47e7f63db54641 Mon Sep 17 00:00:00 2001 From: Philip Potter Date: Thu, 7 Mar 2024 15:41:44 +0000 Subject: [PATCH] implement Wait() on Strategy to fetch remaining errors (#179) Currently, we use an errgroup as a semaphore in buffer and consistent hash modes, but we never inspect the results of sem.Wait() to see if there were any errors. This results in errors being swallowed. This implements Wait(), and uses it everywhere we call Fetch(). Now that we have it, we don't need to store errors on bufferedReader so we can get rid of that. --- pkg/download/buffer.go | 6 ++-- pkg/download/buffer_unit_test.go | 47 +++++++++++++++++++++++-- pkg/download/buffered_reader.go | 13 ++----- pkg/download/consistent_hashing.go | 10 ++++-- pkg/download/consistent_hashing_test.go | 16 +++++++++ pkg/download/strategy.go | 3 ++ pkg/pget.go | 5 +++ 7 files changed, 83 insertions(+), 17 deletions(-) diff --git a/pkg/download/buffer.go b/pkg/download/buffer.go index 7633d3b..d6cd492 100644 --- a/pkg/download/buffer.go +++ b/pkg/download/buffer.go @@ -73,7 +73,6 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e defer br.done() firstChunkResp, err := m.DoRequest(ctx, 0, m.chunkSize()-1, url) if err != nil { - br.err = err firstReqResultCh <- firstReqResult{err: err} return err } @@ -145,7 +144,6 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e defer br.done() resp, err := m.DoRequest(ctx, start, end, trueURL) if err != nil { - br.err = err return err } defer resp.Body.Close() @@ -157,6 +155,10 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e return newChanMultiReader(readersCh), fileSize, nil } +func (m *BufferMode) Wait() error { + return m.sem.Wait() +} + func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL string) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, "GET", trueURL, nil) if err != nil { diff --git a/pkg/download/buffer_unit_test.go b/pkg/download/buffer_unit_test.go index 6ff4c78..3e778e3 100644 --- a/pkg/download/buffer_unit_test.go +++ b/pkg/download/buffer_unit_test.go @@ -2,15 +2,18 @@ package download import ( "context" + "fmt" "io" "math/rand" "net/http" "net/http/httptest" "net/url" + "strings" "testing" "testing/fstest" "github.com/dustin/go-humanize" + "github.com/jarcoal/httpmock" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -42,8 +45,6 @@ func newTestServer(t *testing.T, content []byte) *httptest.Server { return server } -// TODO: Implement the test -// func TestGetFileSizeFromContentRange(t *testing.T) {} func TestFileToBufferChunkCountExceedsMaxChunks(t *testing.T) { contentSize := int64(humanize.KiByte) content := generateTestContent(contentSize) @@ -116,9 +117,51 @@ func TestFileToBufferChunkCountExceedsMaxChunks(t *testing.T) { require.NoError(t, err) data, err := io.ReadAll(download) assert.NoError(t, err) + err = bufferMode.Wait() + assert.NoError(t, err) assert.Equal(t, contentSize, size) assert.Equal(t, len(content), len(data)) assert.Equal(t, content, data) }) } } + +func TestWaitReturnsErrorWhenRequestFails(t *testing.T) { + mockTransport := httpmock.NewMockTransport() + opts := Options{ + Client: client.Options{Transport: mockTransport}, + ChunkSize: 2, + } + expectedErr := fmt.Errorf("Expected error in chunk 3") + mockTransport.RegisterResponder("GET", "http://test.example/hello.txt", + func(req *http.Request) (*http.Response, error) { + rangeHeader := req.Header.Get("Range") + var body string + switch rangeHeader { + case "bytes=0-1": + body = "he" + case "bytes=2-3": + body = "ll" + case "bytes=4-5": + body = "o " + case "bytes=6-7": + return nil, expectedErr + default: + return nil, fmt.Errorf("should't see this error") + } + resp := httpmock.NewStringResponse(http.StatusPartialContent, body) + resp.Request = req + resp.Header.Add("Content-Range", strings.Replace(rangeHeader, "=", " ", 1)+"/8") + resp.ContentLength = 2 + resp.Header.Add("Content-Length", "2") + return resp, nil + }) + bufferMode := GetBufferMode(opts) + download, _, err := bufferMode.Fetch(context.Background(), "http://test.example/hello.txt") + // No error here, because the first chunk was fetched successfully + require.NoError(t, err) + // the read might or might not return an error + _, _ = io.ReadAll(download) + err = bufferMode.Wait() + assert.ErrorIs(t, err, expectedErr) +} diff --git a/pkg/download/buffered_reader.go b/pkg/download/buffered_reader.go index 6994ea0..d57c080 100644 --- a/pkg/download/buffered_reader.go +++ b/pkg/download/buffered_reader.go @@ -15,7 +15,6 @@ type bufferedReader struct { // ready channel is closed when we're ready to read ready chan struct{} buf *bytes.Buffer - err error pool *bufferPool } @@ -36,9 +35,6 @@ func newBufferedReader(pool *bufferPool) *bufferedReader { // pool. func (b *bufferedReader) Read(buf []byte) (int, error) { <-b.ready - if b.err != nil { - return 0, b.err - } n, err := b.buf.Read(buf) // If we've read all the data, if b.buf.Len() == 0 && b.buf != emptyBuffer { @@ -59,17 +55,14 @@ func (b *bufferedReader) downloadBody(resp *http.Response) error { expectedBytes := resp.ContentLength if expectedBytes > int64(b.buf.Cap()) { - b.err = fmt.Errorf("Tried to download 0x%x bytes to a 0x%x-sized buffer", expectedBytes, b.buf.Cap()) - return b.err + return fmt.Errorf("Tried to download 0x%x bytes to a 0x%x-sized buffer", expectedBytes, b.buf.Cap()) } n, err := b.buf.ReadFrom(resp.Body) if err != nil && err != io.EOF { - b.err = fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err) - return b.err + return fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err) } if n != expectedBytes { - b.err = fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String()) - return b.err + return fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String()) } return nil } diff --git a/pkg/download/consistent_hashing.go b/pkg/download/consistent_hashing.go index d353361..0f9cf54 100644 --- a/pkg/download/consistent_hashing.go +++ b/pkg/download/consistent_hashing.go @@ -116,9 +116,10 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io defer br.done() firstChunkResp, err := m.DoRequest(ctx, 0, m.chunkSize()-1, urlString) if err != nil { - br.err = err firstReqResultCh <- firstReqResult{err: err} - return err + // The error will be handled by the firstReqResultCh consumer, + // and may be recoverable; so we return nil to the errGroup + return nil } defer firstChunkResp.Body.Close() @@ -217,7 +218,6 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, urlString) } if err != nil { - br.err = err return err } } @@ -232,6 +232,10 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io return newChanMultiReader(readersCh), fileSize, nil } +func (m *ConsistentHashingMode) Wait() error { + return m.sem.Wait() +} + func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, urlString string) (*http.Response, error) { chContext := context.WithValue(ctx, config.ConsistentHashingStrategyKey, true) req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil) diff --git a/pkg/download/consistent_hashing_test.go b/pkg/download/consistent_hashing_test.go index 2751971..2ac80a2 100644 --- a/pkg/download/consistent_hashing_test.go +++ b/pkg/download/consistent_hashing_test.go @@ -270,6 +270,8 @@ func TestConsistentHashing(t *testing.T) { require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) + err = strategy.Wait() + require.NoError(t, err) assert.Equal(t, tc.expectedOutput, string(bytes)) }) @@ -321,6 +323,8 @@ func TestConsistentHashingPathBased(t *testing.T) { require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) + err = strategy.Wait() + require.NoError(t, err) assert.Equal(t, tc.expectedOutput, string(bytes)) }) @@ -352,6 +356,8 @@ func TestConsistentHashRetries(t *testing.T) { require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) + err = strategy.Wait() + require.NoError(t, err) // with a functional hostnames[0], we'd see 0344760706165500, but instead we // should fall back to this. Note that each 0 value has been changed to a @@ -387,6 +393,8 @@ func TestConsistentHashRetriesMissingHostname(t *testing.T) { require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) + err = strategy.Wait() + require.NoError(t, err) // with a functional hostnames[0], we'd see 0344760706165500, but instead we // should fall back to this. Note that each 0 value has been changed to a @@ -421,6 +429,8 @@ func TestConsistentHashRetriesTwoHosts(t *testing.T) { require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) + err = strategy.Wait() + require.NoError(t, err) assert.Equal(t, "0000000000000000", string(bytes)) } @@ -448,6 +458,8 @@ func TestConsistentHashingHasFallback(t *testing.T) { require.NoError(t, err) bytes, err := io.ReadAll(reader) require.NoError(t, err) + err = strategy.Wait() + require.NoError(t, err) assert.Equal(t, "0000000000000000", string(bytes)) } @@ -476,6 +488,10 @@ func (s *testStrategy) Fetch(ctx context.Context, url string) (io.Reader, int64, return io.NopCloser(strings.NewReader("00")), -1, nil } +func (s *testStrategy) Wait() error { + return nil +} + func (s *testStrategy) DoRequest(ctx context.Context, start, end int64, url string) (*http.Response, error) { s.mut.Lock() s.doRequestCalledCount++ diff --git a/pkg/download/strategy.go b/pkg/download/strategy.go index a430dd2..e716363 100644 --- a/pkg/download/strategy.go +++ b/pkg/download/strategy.go @@ -15,6 +15,9 @@ type Strategy interface { // This is the primary method that should be called to initiate a download of a file. Fetch(ctx context.Context, url string) (result io.Reader, fileSize int64, err error) + // Wait waits until all requests have completed, and returns the first error encountered, if any. + Wait() error + // DoRequest sends an HTTP GET request with a specified range of bytes to the given URL using the provided context. // It returns the HTTP response and any error encountered during the request. It is intended that Fetch calls DoRequest // and that each chunk is downloaded with a call to DoRequest. DoRequest is exposed so that consistent-hashing can diff --git a/pkg/pget.go b/pkg/pget.go index 3267174..84b0e03 100644 --- a/pkg/pget.go +++ b/pkg/pget.go @@ -53,6 +53,11 @@ func (g *Getter) DownloadFile(ctx context.Context, url string, dest string) (int if err != nil { return fileSize, 0, fmt.Errorf("error writing file: %w", err) } + err = g.Downloader.Wait() + if err != nil { + return fileSize, 0, err + } + // writeElapsed := time.Since(writeStartTime) totalElapsed := time.Since(downloadStartTime)