Skip to content

Commit

Permalink
implement Wait() on Strategy to fetch remaining errors (#179)
Browse files Browse the repository at this point in the history
Currently, we use an errgroup as a semaphore in buffer and consistent hash
modes, but we never inspect the results of sem.Wait() to see if there were any
errors. This results in errors being swallowed.

This implements Wait(), and uses it everywhere we call Fetch().  Now that we
have it, we don't need to store errors on bufferedReader so we can get rid of
that.
  • Loading branch information
philandstuff authored Mar 7, 2024
1 parent b6d4d5a commit c72c384
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 17 deletions.
6 changes: 4 additions & 2 deletions pkg/download/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e
defer br.done()
firstChunkResp, err := m.DoRequest(ctx, 0, m.chunkSize()-1, url)
if err != nil {
br.err = err
firstReqResultCh <- firstReqResult{err: err}
return err
}
Expand Down Expand Up @@ -145,7 +144,6 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e
defer br.done()
resp, err := m.DoRequest(ctx, start, end, trueURL)
if err != nil {
br.err = err
return err
}
defer resp.Body.Close()
Expand All @@ -157,6 +155,10 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e
return newChanMultiReader(readersCh), fileSize, nil
}

// Wait blocks until every chunk-download goroutine started by Fetch has
// completed, and returns the first error any of them encountered, if any.
// Callers should invoke this after draining the reader returned by Fetch,
// since per-chunk errors are otherwise swallowed by the errgroup semaphore.
func (m *BufferMode) Wait() error {
	return m.sem.Wait()
}

func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL string) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, "GET", trueURL, nil)
if err != nil {
Expand Down
47 changes: 45 additions & 2 deletions pkg/download/buffer_unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@ package download

import (
"context"
"fmt"
"io"
"math/rand"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"testing/fstest"

"github.com/dustin/go-humanize"
"github.com/jarcoal/httpmock"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -42,8 +45,6 @@ func newTestServer(t *testing.T, content []byte) *httptest.Server {
return server
}

// TODO: Implement the test
// func TestGetFileSizeFromContentRange(t *testing.T) {}
func TestFileToBufferChunkCountExceedsMaxChunks(t *testing.T) {
contentSize := int64(humanize.KiByte)
content := generateTestContent(contentSize)
Expand Down Expand Up @@ -116,9 +117,51 @@ func TestFileToBufferChunkCountExceedsMaxChunks(t *testing.T) {
require.NoError(t, err)
data, err := io.ReadAll(download)
assert.NoError(t, err)
err = bufferMode.Wait()
assert.NoError(t, err)
assert.Equal(t, contentSize, size)
assert.Equal(t, len(content), len(data))
assert.Equal(t, content, data)
})
}
}

// TestWaitReturnsErrorWhenRequestFails verifies that an error from a
// background chunk request is surfaced by Wait() rather than swallowed:
// Fetch itself succeeds (the first chunk downloads fine), a later chunk
// fails, and that failure must be reported when the caller calls Wait().
func TestWaitReturnsErrorWhenRequestFails(t *testing.T) {
	mockTransport := httpmock.NewMockTransport()
	opts := Options{
		Client:    client.Options{Transport: mockTransport},
		ChunkSize: 2,
	}
	expectedErr := fmt.Errorf("Expected error in chunk 3")
	mockTransport.RegisterResponder("GET", "http://test.example/hello.txt",
		func(req *http.Request) (*http.Response, error) {
			rangeHeader := req.Header.Get("Range")
			var body string
			switch rangeHeader {
			case "bytes=0-1":
				body = "he"
			case "bytes=2-3":
				body = "ll"
			case "bytes=4-5":
				body = "o "
			case "bytes=6-7":
				// The final chunk fails; this is the error Wait() must report.
				return nil, expectedErr
			default:
				return nil, fmt.Errorf("shouldn't see this error")
			}
			resp := httpmock.NewStringResponse(http.StatusPartialContent, body)
			resp.Request = req
			// e.g. "bytes=0-1" -> "bytes 0-1/8" for the Content-Range header.
			resp.Header.Add("Content-Range", strings.Replace(rangeHeader, "=", " ", 1)+"/8")
			resp.ContentLength = 2
			resp.Header.Add("Content-Length", "2")
			return resp, nil
		})
	bufferMode := GetBufferMode(opts)
	download, _, err := bufferMode.Fetch(context.Background(), "http://test.example/hello.txt")
	// No error here, because the first chunk was fetched successfully
	require.NoError(t, err)
	// the read might or might not return an error
	_, _ = io.ReadAll(download)
	err = bufferMode.Wait()
	assert.ErrorIs(t, err, expectedErr)
}
13 changes: 3 additions & 10 deletions pkg/download/buffered_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ type bufferedReader struct {
// ready channel is closed when we're ready to read
ready chan struct{}
buf *bytes.Buffer
err error
pool *bufferPool
}

Expand All @@ -36,9 +35,6 @@ func newBufferedReader(pool *bufferPool) *bufferedReader {
// pool.
func (b *bufferedReader) Read(buf []byte) (int, error) {
<-b.ready
if b.err != nil {
return 0, b.err
}
n, err := b.buf.Read(buf)
// If we've read all the data,
if b.buf.Len() == 0 && b.buf != emptyBuffer {
Expand All @@ -59,17 +55,14 @@ func (b *bufferedReader) downloadBody(resp *http.Response) error {
expectedBytes := resp.ContentLength

if expectedBytes > int64(b.buf.Cap()) {
b.err = fmt.Errorf("Tried to download 0x%x bytes to a 0x%x-sized buffer", expectedBytes, b.buf.Cap())
return b.err
return fmt.Errorf("Tried to download 0x%x bytes to a 0x%x-sized buffer", expectedBytes, b.buf.Cap())
}
n, err := b.buf.ReadFrom(resp.Body)
if err != nil && err != io.EOF {
b.err = fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err)
return b.err
return fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err)
}
if n != expectedBytes {
b.err = fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String())
return b.err
return fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String())
}
return nil
}
Expand Down
10 changes: 7 additions & 3 deletions pkg/download/consistent_hashing.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,10 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
defer br.done()
firstChunkResp, err := m.DoRequest(ctx, 0, m.chunkSize()-1, urlString)
if err != nil {
br.err = err
firstReqResultCh <- firstReqResult{err: err}
return err
// The error will be handled by the firstReqResultCh consumer,
// and may be recoverable; so we return nil to the errGroup
return nil
}
defer firstChunkResp.Body.Close()

Expand Down Expand Up @@ -217,7 +218,6 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, urlString)
}
if err != nil {
br.err = err
return err
}
}
Expand All @@ -232,6 +232,10 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
return newChanMultiReader(readersCh), fileSize, nil
}

// Wait blocks until all in-flight chunk requests issued by Fetch have
// finished, returning the first error encountered, if any. Without this,
// errors from the errgroup used as a semaphore would be silently dropped.
func (m *ConsistentHashingMode) Wait() error {
	return m.sem.Wait()
}

func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, urlString string) (*http.Response, error) {
chContext := context.WithValue(ctx, config.ConsistentHashingStrategyKey, true)
req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil)
Expand Down
16 changes: 16 additions & 0 deletions pkg/download/consistent_hashing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ func TestConsistentHashing(t *testing.T) {
require.NoError(t, err)
bytes, err := io.ReadAll(reader)
require.NoError(t, err)
err = strategy.Wait()
require.NoError(t, err)

assert.Equal(t, tc.expectedOutput, string(bytes))
})
Expand Down Expand Up @@ -321,6 +323,8 @@ func TestConsistentHashingPathBased(t *testing.T) {
require.NoError(t, err)
bytes, err := io.ReadAll(reader)
require.NoError(t, err)
err = strategy.Wait()
require.NoError(t, err)

assert.Equal(t, tc.expectedOutput, string(bytes))
})
Expand Down Expand Up @@ -352,6 +356,8 @@ func TestConsistentHashRetries(t *testing.T) {
require.NoError(t, err)
bytes, err := io.ReadAll(reader)
require.NoError(t, err)
err = strategy.Wait()
require.NoError(t, err)

// with a functional hostnames[0], we'd see 0344760706165500, but instead we
// should fall back to this. Note that each 0 value has been changed to a
Expand Down Expand Up @@ -387,6 +393,8 @@ func TestConsistentHashRetriesMissingHostname(t *testing.T) {
require.NoError(t, err)
bytes, err := io.ReadAll(reader)
require.NoError(t, err)
err = strategy.Wait()
require.NoError(t, err)

// with a functional hostnames[0], we'd see 0344760706165500, but instead we
// should fall back to this. Note that each 0 value has been changed to a
Expand Down Expand Up @@ -421,6 +429,8 @@ func TestConsistentHashRetriesTwoHosts(t *testing.T) {
require.NoError(t, err)
bytes, err := io.ReadAll(reader)
require.NoError(t, err)
err = strategy.Wait()
require.NoError(t, err)

assert.Equal(t, "0000000000000000", string(bytes))
}
Expand Down Expand Up @@ -448,6 +458,8 @@ func TestConsistentHashingHasFallback(t *testing.T) {
require.NoError(t, err)
bytes, err := io.ReadAll(reader)
require.NoError(t, err)
err = strategy.Wait()
require.NoError(t, err)

assert.Equal(t, "0000000000000000", string(bytes))
}
Expand Down Expand Up @@ -476,6 +488,10 @@ func (s *testStrategy) Fetch(ctx context.Context, url string) (io.Reader, int64,
return io.NopCloser(strings.NewReader("00")), -1, nil
}

// Wait satisfies the Strategy interface; the test strategy performs no
// background work, so there is never a deferred error to report.
func (s *testStrategy) Wait() error {
	return nil
}

func (s *testStrategy) DoRequest(ctx context.Context, start, end int64, url string) (*http.Response, error) {
s.mut.Lock()
s.doRequestCalledCount++
Expand Down
3 changes: 3 additions & 0 deletions pkg/download/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ type Strategy interface {
// This is the primary method that should be called to initiate a download of a file.
Fetch(ctx context.Context, url string) (result io.Reader, fileSize int64, err error)

// Wait waits until all requests have completed, and returns the first error encountered, if any.
Wait() error

// DoRequest sends an HTTP GET request with a specified range of bytes to the given URL using the provided context.
// It returns the HTTP response and any error encountered during the request. It is intended that Fetch calls DoRequest
// and that each chunk is downloaded with a call to DoRequest. DoRequest is exposed so that consistent-hashing can
Expand Down
5 changes: 5 additions & 0 deletions pkg/pget.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ func (g *Getter) DownloadFile(ctx context.Context, url string, dest string) (int
if err != nil {
return fileSize, 0, fmt.Errorf("error writing file: %w", err)
}
err = g.Downloader.Wait()
if err != nil {
return fileSize, 0, err
}

// writeElapsed := time.Since(writeStartTime)
totalElapsed := time.Since(downloadStartTime)

Expand Down

0 comments on commit c72c384

Please sign in to comment.