implement Wait() on Strategy to fetch remaining errors #179

Merged · 1 commit · Mar 7, 2024
6 changes: 4 additions & 2 deletions pkg/download/buffer.go
@@ -73,7 +73,6 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e
defer br.done()
firstChunkResp, err := m.DoRequest(ctx, 0, m.chunkSize()-1, url)
if err != nil {
br.err = err
firstReqResultCh <- firstReqResult{err: err}
return err
}
@@ -145,7 +144,6 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e
defer br.done()
resp, err := m.DoRequest(ctx, start, end, trueURL)
if err != nil {
br.err = err
return err
}
defer resp.Body.Close()
@@ -157,6 +155,10 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e
return newChanMultiReader(readersCh), fileSize, nil
}

func (m *BufferMode) Wait() error {
return m.sem.Wait()
}

func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL string) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, "GET", trueURL, nil)
if err != nil {
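The `Wait()` method added above simply defers to `m.sem.Wait()`. Judging by the error return, `m.sem` behaves like an errgroup (an assumption, not confirmed by this diff): each chunk download runs in its own goroutine, and `Wait()` blocks until they all finish, returning the first error any of them reported. A minimal, self-contained sketch of that pattern, illustrating errgroup semantics rather than pget's actual worker code:

```go
package main

import (
	"errors"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	var g errgroup.Group
	chunkErr := errors.New("chunk 3 failed")

	// One goroutine per "chunk"; only the last one fails.
	for i := 0; i < 4; i++ {
		i := i
		g.Go(func() error {
			if i == 3 {
				return chunkErr // only visible to the caller via Wait
			}
			return nil
		})
	}

	// Wait blocks until every goroutine has returned and yields the
	// first non-nil error, which is what Strategy.Wait exposes.
	fmt.Println(errors.Is(g.Wait(), chunkErr)) // true
}
```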
47 changes: 45 additions & 2 deletions pkg/download/buffer_unit_test.go
@@ -2,15 +2,18 @@ package download

import (
"context"
"fmt"
"io"
"math/rand"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"testing/fstest"

"github.com/dustin/go-humanize"
"github.com/jarcoal/httpmock"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -42,8 +45,6 @@ func newTestServer(t *testing.T, content []byte) *httptest.Server {
return server
}

// TODO: Implement the test
// func TestGetFileSizeFromContentRange(t *testing.T) {}
func TestFileToBufferChunkCountExceedsMaxChunks(t *testing.T) {
contentSize := int64(humanize.KiByte)
content := generateTestContent(contentSize)
@@ -116,9 +117,51 @@ func TestFileToBufferChunkCountExceedsMaxChunks(t *testing.T) {
require.NoError(t, err)
data, err := io.ReadAll(download)
assert.NoError(t, err)
err = bufferMode.Wait()
assert.NoError(t, err)
assert.Equal(t, contentSize, size)
assert.Equal(t, len(content), len(data))
assert.Equal(t, content, data)
})
}
}

func TestWaitReturnsErrorWhenRequestFails(t *testing.T) {
mockTransport := httpmock.NewMockTransport()
opts := Options{
Client: client.Options{Transport: mockTransport},
ChunkSize: 2,
}
expectedErr := fmt.Errorf("Expected error in chunk 3")
mockTransport.RegisterResponder("GET", "http://test.example/hello.txt",
func(req *http.Request) (*http.Response, error) {
rangeHeader := req.Header.Get("Range")
var body string
switch rangeHeader {
case "bytes=0-1":
body = "he"
case "bytes=2-3":
body = "ll"
case "bytes=4-5":
body = "o "
case "bytes=6-7":
return nil, expectedErr
default:
return nil, fmt.Errorf("should't see this error")
}
resp := httpmock.NewStringResponse(http.StatusPartialContent, body)
resp.Request = req
resp.Header.Add("Content-Range", strings.Replace(rangeHeader, "=", " ", 1)+"/8")
resp.ContentLength = 2
resp.Header.Add("Content-Length", "2")
return resp, nil
})
bufferMode := GetBufferMode(opts)
download, _, err := bufferMode.Fetch(context.Background(), "http://test.example/hello.txt")
// No error here, because the first chunk was fetched successfully
require.NoError(t, err)
// the read might or might not return an error
_, _ = io.ReadAll(download)
err = bufferMode.Wait()
assert.ErrorIs(t, err, expectedErr)
}
13 changes: 3 additions & 10 deletions pkg/download/buffered_reader.go
@@ -15,7 +15,6 @@ type bufferedReader struct {
// ready channel is closed when we're ready to read
ready chan struct{}
buf *bytes.Buffer
err error
pool *bufferPool
}

@@ -36,9 +35,6 @@ func newBufferedReader(pool *bufferPool) *bufferedReader {
// pool.
func (b *bufferedReader) Read(buf []byte) (int, error) {
<-b.ready
if b.err != nil {
return 0, b.err
}
n, err := b.buf.Read(buf)
// If we've read all the data,
if b.buf.Len() == 0 && b.buf != emptyBuffer {
@@ -59,17 +55,14 @@ func (b *bufferedReader) downloadBody(resp *http.Response) error {
expectedBytes := resp.ContentLength

if expectedBytes > int64(b.buf.Cap()) {
b.err = fmt.Errorf("Tried to download 0x%x bytes to a 0x%x-sized buffer", expectedBytes, b.buf.Cap())
return b.err
return fmt.Errorf("Tried to download 0x%x bytes to a 0x%x-sized buffer", expectedBytes, b.buf.Cap())
}
n, err := b.buf.ReadFrom(resp.Body)
if err != nil && err != io.EOF {
b.err = fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err)
return b.err
return fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err)
}
if n != expectedBytes {
b.err = fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String())
return b.err
return fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String())
}
return nil
}
10 changes: 7 additions & 3 deletions pkg/download/consistent_hashing.go
@@ -116,9 +116,10 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
defer br.done()
firstChunkResp, err := m.DoRequest(ctx, 0, m.chunkSize()-1, urlString)
if err != nil {
br.err = err
firstReqResultCh <- firstReqResult{err: err}
return err
// The error will be handled by the firstReqResultCh consumer,
// and may be recoverable; so we return nil to the errGroup
return nil
}
defer firstChunkResp.Body.Close()

@@ -217,7 +218,6 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, urlString)
}
if err != nil {
br.err = err
return err
}
}
@@ -232,6 +232,10 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
return newChanMultiReader(readersCh), fileSize, nil
}

func (m *ConsistentHashingMode) Wait() error {
return m.sem.Wait()
}

func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, urlString string) (*http.Response, error) {
chContext := context.WithValue(ctx, config.ConsistentHashingStrategyKey, true)
req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil)
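The comment added in the first hunk is the behavioural crux of this file's change: when the first request fails, the error is pushed into `firstReqResultCh` and the goroutine returns nil, so the errgroup keeps running and the consumer of the channel decides whether the failure is recoverable (for example by falling back to another strategy). A stripped-down sketch of that report-and-continue pattern follows; `firstReqResult` here is a cut-down stand-in for the channel message type in the diff, and the rest is illustrative rather than pget's real control flow:

```go
package main

import (
	"errors"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// firstReqResult is a cut-down stand-in for the message sent on
// firstReqResultCh in the diff.
type firstReqResult struct {
	err error
}

func main() {
	var g errgroup.Group
	firstReqResultCh := make(chan firstReqResult, 1)

	g.Go(func() error {
		err := errors.New("first request failed") // simulate a failed request
		if err != nil {
			firstReqResultCh <- firstReqResult{err: err}
			// Returning nil keeps the errgroup alive: the consumer of the
			// channel can treat the error as recoverable instead of having
			// every other worker torn down.
			return nil
		}
		firstReqResultCh <- firstReqResult{}
		return nil
	})

	fmt.Println("consumer saw:", (<-firstReqResultCh).err) // the request error
	fmt.Println("errgroup saw:", g.Wait())                 // <nil>
}
```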
16 changes: 16 additions & 0 deletions pkg/download/consistent_hashing_test.go
@@ -270,6 +270,8 @@ func TestConsistentHashing(t *testing.T) {
require.NoError(t, err)
bytes, err := io.ReadAll(reader)
require.NoError(t, err)
err = strategy.Wait()
require.NoError(t, err)

assert.Equal(t, tc.expectedOutput, string(bytes))
})
@@ -321,6 +323,8 @@ func TestConsistentHashingPathBased(t *testing.T) {
require.NoError(t, err)
bytes, err := io.ReadAll(reader)
require.NoError(t, err)
err = strategy.Wait()
require.NoError(t, err)

assert.Equal(t, tc.expectedOutput, string(bytes))
})
@@ -352,6 +356,8 @@ func TestConsistentHashRetries(t *testing.T) {
require.NoError(t, err)
bytes, err := io.ReadAll(reader)
require.NoError(t, err)
err = strategy.Wait()
require.NoError(t, err)

// with a functional hostnames[0], we'd see 0344760706165500, but instead we
// should fall back to this. Note that each 0 value has been changed to a
@@ -387,6 +393,8 @@ func TestConsistentHashRetriesMissingHostname(t *testing.T) {
require.NoError(t, err)
bytes, err := io.ReadAll(reader)
require.NoError(t, err)
err = strategy.Wait()
require.NoError(t, err)

// with a functional hostnames[0], we'd see 0344760706165500, but instead we
// should fall back to this. Note that each 0 value has been changed to a
@@ -421,6 +429,8 @@ func TestConsistentHashRetriesTwoHosts(t *testing.T) {
require.NoError(t, err)
bytes, err := io.ReadAll(reader)
require.NoError(t, err)
err = strategy.Wait()
require.NoError(t, err)

assert.Equal(t, "0000000000000000", string(bytes))
}
@@ -448,6 +458,8 @@ func TestConsistentHashingHasFallback(t *testing.T) {
require.NoError(t, err)
bytes, err := io.ReadAll(reader)
require.NoError(t, err)
err = strategy.Wait()
require.NoError(t, err)

assert.Equal(t, "0000000000000000", string(bytes))
}
@@ -476,6 +488,10 @@ func (s *testStrategy) Fetch(ctx context.Context, url string) (io.Reader, int64,
return io.NopCloser(strings.NewReader("00")), -1, nil
}

func (s *testStrategy) Wait() error {
return nil
}

func (s *testStrategy) DoRequest(ctx context.Context, start, end int64, url string) (*http.Response, error) {
s.mut.Lock()
s.doRequestCalledCount++
3 changes: 3 additions & 0 deletions pkg/download/strategy.go
@@ -15,6 +15,9 @@ type Strategy interface {
// This is the primary method that should be called to initiate a download of a file.
Fetch(ctx context.Context, url string) (result io.Reader, fileSize int64, err error)

// Wait waits until all requests have completed, and returns the first error encountered, if any.
Wait() error

// DoRequest sends an HTTP GET request with a specified range of bytes to the given URL using the provided context.
// It returns the HTTP response and any error encountered during the request. It is intended that Fetch calls DoRequest
// and that each chunk is downloaded with a call to DoRequest. DoRequest is exposed so that consistent-hashing can
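With `Wait()` on the interface, the contract for callers becomes: call `Fetch` to start the download and get a reader, drain the reader, then call `Wait` to pick up any error from chunks requested after `Fetch` returned. A sketch of a caller written against the interface above; `downloadAll` is illustrative and not part of the package:

```go
package download

import (
	"context"
	"io"
)

// downloadAll shows the intended call order for a Strategy:
// Fetch, drain the reader, then Wait for late chunk errors.
func downloadAll(ctx context.Context, s Strategy, url string) ([]byte, error) {
	reader, _, err := s.Fetch(ctx, url)
	if err != nil {
		return nil, err // the first request itself failed
	}
	data, err := io.ReadAll(reader)
	if err != nil {
		return nil, err // reading a chunk body failed
	}
	// Errors from requests issued after Fetch returned only surface here.
	if err := s.Wait(); err != nil {
		return nil, err
	}
	return data, nil
}
```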
5 changes: 5 additions & 0 deletions pkg/pget.go
@@ -53,6 +53,11 @@ func (g *Getter) DownloadFile(ctx context.Context, url string, dest string) (int
if err != nil {
return fileSize, 0, fmt.Errorf("error writing file: %w", err)
}
err = g.Downloader.Wait()
if err != nil {
return fileSize, 0, err
}

// writeElapsed := time.Since(writeStartTime)
totalElapsed := time.Since(downloadStartTime)
