From f9b008ec88b8df31deb43a1dc1871b29a39c83de Mon Sep 17 00:00:00 2001
From: Philip Potter
Date: Thu, 14 Dec 2023 13:53:44 +0000
Subject: [PATCH] Better concurrency (#115)

* introduce bufferedReader

The idea of bufferedReader is to be a staging area into which we can
download request bodies. It blocks Read() calls until the whole request
body has been downloaded.

The key thing this enables is getting rid of the errGroup.Wait() call:
because each reader we return blocks until its own chunk is ready, we no
longer need to wait for the whole file to download before returning from
Fetch().

* respect MaxConcurrency within a single BufferMode object

This means that a single BufferMode instance can never exceed
MaxConcurrency concurrent requests.

* introduce workQueue and chanMultiReader

We want Fetch() to return early so that a consumer can start consuming
as soon as possible, even if we block on making new requests before we
get to the last chunk. workQueue and chanMultiReader let us do that.

* add a per-file concurrency limit

* make max-concurrent-files configurable
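
The MaxConcurrency bullet above relies on errgroup's limiter. As a minimal
sketch of the pattern (illustrative only, not code from this patch): once
SetLimit(n) is reached, each further Go() call blocks until a running
goroutine finishes, which is exactly the semaphore behaviour BufferMode needs.

    package main

    import (
        "fmt"

        "golang.org/x/sync/errgroup"
    )

    func main() {
        sem := new(errgroup.Group)
        sem.SetLimit(4) // at most 4 goroutines in flight, like MaxConcurrency

        for i := 0; i < 10; i++ {
            i := i
            sem.Go(func() error { // blocks here once 4 are already running
                fmt.Println("chunk", i) // stand-in for one ranged GET
                return nil
            })
        }
        _ = sem.Wait()
    }
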
---
 cmd/multifile/multifile.go              |  12 ++
 cmd/root/root.go                        |   2 -
 pkg/config/optnames.go                  |  29 ++--
 pkg/download/buffer.go                  | 173 ++++++++++---------
 pkg/download/buffer_test.go             |   7 +-
 pkg/download/buffer_unit_test.go        |  11 +-
 pkg/download/buffered_reader.go         |  55 ++++++
 pkg/download/chan_multi_reader.go       |  42 +++++
 pkg/download/consistent_hashing.go      | 219 +++++++++++------------
 pkg/download/consistent_hashing_test.go |  48 +++---
 pkg/download/options.go                 |  12 +-
 pkg/download/work_queue.go              |  26 +++
 pkg/pget.go                             |  28 ++-
 pkg/pget_test.go                        |   4 +-
 14 files changed, 404 insertions(+), 264 deletions(-)
 create mode 100644 pkg/download/buffered_reader.go
 create mode 100644 pkg/download/chan_multi_reader.go
 create mode 100644 pkg/download/work_queue.go

diff --git a/cmd/multifile/multifile.go b/cmd/multifile/multifile.go
index 0531e66..054dc44 100644
--- a/cmd/multifile/multifile.go
+++ b/cmd/multifile/multifile.go
@@ -88,6 +88,14 @@ func runMultifileCMD(cmd *cobra.Command, args []string) error {
     return multifileExecute(cmd.Context(), manifest)
 }
 
+func maxConcurrentFiles() int {
+    maxConcurrentFiles := viper.GetInt(config.OptMaxConcurrentFiles)
+    if maxConcurrentFiles == 0 {
+        maxConcurrentFiles = 20
+    }
+    return maxConcurrentFiles
+}
+
 func multifileExecute(ctx context.Context, manifest pget.Manifest) error {
     minChunkSize, err := humanize.ParseBytes(viper.GetString(config.OptMinimumChunkSize))
     if err != nil {
@@ -112,6 +120,9 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error {
         MinChunkSize: int64(minChunkSize),
         Client:       clientOpts,
     }
+    pgetOpts := pget.Options{
+        MaxConcurrentFiles: maxConcurrentFiles(),
+    }
 
     consumer, err := config.GetConsumer()
     if err != nil {
@@ -121,6 +132,7 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error {
     getter := &pget.Getter{
         Downloader: download.GetBufferMode(downloadOpts),
         Consumer:   consumer,
+        Options:    pgetOpts,
     }
 
     // TODO DRY this
diff --git a/cmd/root/root.go b/cmd/root/root.go
index f1b611b..593f0f9 100644
--- a/cmd/root/root.go
+++ b/cmd/root/root.go
@@ -11,7 +11,6 @@ import (
     "github.com/rs/zerolog/log"
     "github.com/spf13/cobra"
     "github.com/spf13/viper"
-    "golang.org/x/sync/semaphore"
 
     pget "github.com/replicate/pget/pkg"
     "github.com/replicate/pget/pkg/cli"
@@ -174,7 +173,6 @@ func rootExecute(ctx context.Context, urlString, dest string) error {
         MaxConcurrency: viper.GetInt(config.OptConcurrency),
         MinChunkSize:   int64(minChunkSize),
         Client:         clientOpts,
-        Semaphore:      semaphore.NewWeighted(int64(viper.GetInt(config.OptConcurrency))),
     }
 
     consumer, err := config.GetConsumer()
diff --git a/pkg/config/optnames.go b/pkg/config/optnames.go
index 0e3da99..a1756ba 100644
--- a/pkg/config/optnames.go
+++ b/pkg/config/optnames.go
@@ -6,18 +6,19 @@ const (
     OptCacheNodesSRVNameByHostCIDR = "cache-nodes-srv-name-by-host-cidr"
     OptHostIP                      = "host-ip"
 
-    OptCacheNodesSRVName = "cache-nodes-srv-name"
-    OptConcurrency       = "concurrency"
-    OptConnTimeout       = "connect-timeout"
-    OptExtract           = "extract"
-    OptForce             = "force"
-    OptForceHTTP2        = "force-http2"
-    OptLoggingLevel      = "log-level"
-    OptMaxChunks         = "max-chunks"
-    OptMaxConnPerHost    = "max-conn-per-host"
-    OptMinimumChunkSize  = "minimum-chunk-size"
-    OptOutputConsumer    = "output"
-    OptResolve           = "resolve"
-    OptRetries           = "retries"
-    OptVerbose           = "verbose"
+    OptCacheNodesSRVName  = "cache-nodes-srv-name"
+    OptConcurrency        = "concurrency"
+    OptConnTimeout        = "connect-timeout"
+    OptExtract            = "extract"
+    OptForce              = "force"
+    OptForceHTTP2         = "force-http2"
+    OptLoggingLevel       = "log-level"
+    OptMaxChunks          = "max-chunks"
+    OptMaxConnPerHost     = "max-conn-per-host"
+    OptMaxConcurrentFiles = "max-concurrent-files"
+    OptMinimumChunkSize   = "minimum-chunk-size"
+    OptOutputConsumer     = "output"
+    OptResolve            = "resolve"
+    OptRetries            = "retries"
+    OptVerbose            = "verbose"
 )
diff --git a/pkg/download/buffer.go b/pkg/download/buffer.go
index 936f14c..97459ac 100644
--- a/pkg/download/buffer.go
+++ b/pkg/download/buffer.go
@@ -1,13 +1,11 @@
 package download
 
 import (
-    "bytes"
     "context"
     "fmt"
     "io"
     "net/http"
     "regexp"
-    "runtime"
     "strconv"
 
     "golang.org/x/sync/errgroup"
@@ -25,24 +23,26 @@ var contentRangeRegexp = regexp.MustCompile(`^bytes .*/([0-9]+)$`)
 type BufferMode struct {
     Client *client.HTTPClient
     Options
+
+    // we use this errgroup as a semaphore (via sem.SetLimit())
+    sem   *errgroup.Group
+    queue *workQueue
 }
 
 func GetBufferMode(opts Options) *BufferMode {
     client := client.NewHTTPClient(opts.Client)
+    sem := new(errgroup.Group)
+    sem.SetLimit(opts.maxConcurrency())
+    queue := newWorkQueue(opts.maxConcurrency())
+    queue.start()
     return &BufferMode{
         Client:  client,
         Options: opts,
+        sem:     sem,
+        queue:   queue,
     }
 }
 
-func (m *BufferMode) maxChunks() int {
-    maxChunks := m.MaxConcurrency
-    if maxChunks == 0 {
-        return runtime.NumCPU() * 4
-    }
-    return maxChunks
-}
-
 func (m *BufferMode) minChunkSize() int64 {
     minChunkSize := m.MinChunkSize
     if minChunkSize == 0 {
@@ -59,32 +59,61 @@ func (m *BufferMode) getFileSizeFromContentRange(contentRange string) (int64, er
     return strconv.ParseInt(groups[1], 10, 64)
 }
 
-func (m *BufferMode) fileToBuffer(ctx context.Context, url string) (*bytes.Buffer, int64, error) {
+type firstReqResult struct {
+    fileSize int64
+    trueURL  string
+    err      error
+}
+
+func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, error) {
     logger := logging.GetLogger()
 
-    firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, url)
-    if err != nil {
-        return nil, -1, err
-    }
-
-    trueURL := firstChunkResp.Request.URL.String()
-    if trueURL != url {
-        logger.Info().Str("url", url).Str("redirect_url", trueURL).Msg("Redirect")
+    br := newBufferedReader(m.minChunkSize())
+
+    firstReqResultCh := make(chan firstReqResult)
+    m.queue.submit(func() {
+        m.sem.Go(func() error {
+            defer close(firstReqResultCh)
+            defer br.done()
+            firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, url)
+            if err != nil {
+                firstReqResultCh <- firstReqResult{err: err}
+                return err
+            }
+
+            defer firstChunkResp.Body.Close()
+
+            trueURL := firstChunkResp.Request.URL.String()
+            if trueURL != url {
+                logger.Info().Str("url", url).Str("redirect_url", trueURL).Msg("Redirect")
+            }
+
+            fileSize, err := m.getFileSizeFromContentRange(firstChunkResp.Header.Get("Content-Range"))
+            if err != nil {
+                firstReqResultCh <- firstReqResult{err: err}
+                return err
+            }
+            firstReqResultCh <- firstReqResult{fileSize: fileSize, trueURL: trueURL}
+
+            return br.downloadBody(firstChunkResp)
+        })
+    })
+
+    firstReqResult, ok := <-firstReqResultCh
+    if !ok {
+        panic("logic error in BufferMode: first request didn't return any output")
     }
 
-    fileSize, err := m.getFileSizeFromContentRange(firstChunkResp.Header.Get("Content-Range"))
-    if err != nil {
-        firstChunkResp.Body.Close()
-        return nil, -1, err
+    if firstReqResult.err != nil {
+        return nil, -1, firstReqResult.err
     }
 
-    data := make([]byte, fileSize)
+    fileSize := firstReqResult.fileSize
+    trueURL := firstReqResult.trueURL
+
     if fileSize <= m.minChunkSize() {
         // we only need a single chunk: just download it and finish
-        err = m.downloadChunk(firstChunkResp, data[0:fileSize])
-        if err != nil {
-            return nil, -1, err
-        }
-        return bytes.NewBuffer(data), fileSize, nil
+        return br, fileSize, nil
     }
 
     remainingBytes := fileSize - m.minChunkSize()
@@ -93,51 +122,52 @@ func (m *BufferMode) fileToBuffer(ctx context.Context, url string) (*bytes.Buffe
     if numChunks <= 0 {
         numChunks = 1
     }
-    if numChunks > m.maxChunks() {
-        numChunks = m.maxChunks()
+    if numChunks > m.maxConcurrency() {
+        numChunks = m.maxConcurrency()
     }
 
+    readersCh := make(chan io.Reader, m.maxConcurrency()+1)
+    readersCh <- br
+
+    startOffset := m.minChunkSize()
+
     chunkSize := remainingBytes / int64(numChunks)
     if chunkSize < 0 {
-        firstChunkResp.Body.Close()
         return nil, -1, fmt.Errorf("error: chunksize incorrect - result is negative, %d", chunkSize)
     }
 
-    logger.Debug().Str("url", url).
-        Int64("size", fileSize).
-        Int("connections", numChunks).
-        Int64("chunkSize", chunkSize).
-        Msg("Downloading")
-
-    errGroup, ctx := errgroup.WithContext(ctx)
-    errGroup.Go(func() error {
-        return m.downloadChunk(firstChunkResp, data[0:m.minChunkSize()])
-    })
-
-    startOffset := m.minChunkSize()
+    m.queue.submit(func() {
+        defer close(readersCh)
+        logger.Debug().Str("url", url).
+            Int64("size", fileSize).
+            Int("connections", numChunks).
+            Int64("chunkSize", chunkSize).
+            Msg("Downloading")
 
-    for i := 0; i < numChunks; i++ {
-        start := startOffset + chunkSize*int64(i)
-        end := start + chunkSize - 1
+        for i := 0; i < numChunks; i++ {
+            start := startOffset + chunkSize*int64(i)
+            end := start + chunkSize - 1
 
-        if i == numChunks-1 {
-            end = fileSize - 1
-        }
-        errGroup.Go(func() error {
-            resp, err := m.DoRequest(ctx, start, end, trueURL)
-            if err != nil {
-                return err
+            if i == numChunks-1 {
+                end = fileSize - 1
             }
-            return m.downloadChunk(resp, data[start:end+1])
-        })
-    }
 
-    if err := errGroup.Wait(); err != nil {
-        return nil, -1, err // return the first error we encounter
-    }
+            br := newBufferedReader(end - start + 1)
+            readersCh <- br
+
+            m.sem.Go(func() error {
+                defer br.done()
+                resp, err := m.DoRequest(ctx, start, end, trueURL)
+                if err != nil {
+                    return err
+                }
+                defer resp.Body.Close()
+                return br.downloadBody(resp)
+            })
+        }
+    })
 
-    buffer := bytes.NewBuffer(data)
-    return buffer, fileSize, nil
+    return newChanMultiReader(readersCh), fileSize, nil
 }
 
 func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL string) (*http.Response, error) {
@@ -156,24 +186,3 @@ func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL st
 
     return resp, nil
 }
-
-func (m *BufferMode) downloadChunk(resp *http.Response, dataSlice []byte) error {
-    defer resp.Body.Close()
-    expectedBytes := len(dataSlice)
-    n, err := io.ReadFull(resp.Body, dataSlice)
-    if err != nil && err != io.EOF {
-        return fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err)
-    }
-    if n != expectedBytes {
-        return fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String())
-    }
-    return nil
-}
-
-func (m *BufferMode) Fetch(ctx context.Context, url string) (result io.Reader, fileSize int64, err error) {
-    buffer, fileSize, err := m.fileToBuffer(ctx, url)
-    if err != nil {
-        return nil, 0, err
-    }
-    return buffer, fileSize, nil
-}
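
Because Fetch now hands back a reader as soon as the first request is in
flight, a consumer can overlap downloading with writing. A minimal sketch of
a caller (hypothetical program, not part of this patch; the URL is a
placeholder), using only the GetBufferMode/Fetch API shown above:

    package main

    import (
        "context"
        "io"
        "os"

        "github.com/replicate/pget/pkg/client"
        "github.com/replicate/pget/pkg/download"
    )

    func main() {
        mode := download.GetBufferMode(download.Options{Client: client.Options{}})

        // Fetch returns once the first chunk request is queued, not when the
        // whole file has been downloaded.
        reader, fileSize, err := mode.Fetch(context.Background(), "https://example.com/weights.bin")
        if err != nil {
            panic(err)
        }
        _ = fileSize

        // io.Copy consumes each chunk as its bufferedReader completes, so
        // writing overlaps with the remaining downloads.
        if _, err := io.Copy(os.Stdout, reader); err != nil {
            panic(err)
        }
    }
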
+ Msg("Downloading") - for i := 0; i < numChunks; i++ { - start := startOffset + chunkSize*int64(i) - end := start + chunkSize - 1 + for i := 0; i < numChunks; i++ { + start := startOffset + chunkSize*int64(i) + end := start + chunkSize - 1 - if i == numChunks-1 { - end = fileSize - 1 - } - errGroup.Go(func() error { - resp, err := m.DoRequest(ctx, start, end, trueURL) - if err != nil { - return err + if i == numChunks-1 { + end = fileSize - 1 } - return m.downloadChunk(resp, data[start:end+1]) - }) - } - if err := errGroup.Wait(); err != nil { - return nil, -1, err // return the first error we encounter - } + br := newBufferedReader(end - start + 1) + readersCh <- br + + m.sem.Go(func() error { + defer br.done() + resp, err := m.DoRequest(ctx, start, end, trueURL) + if err != nil { + return err + } + defer resp.Body.Close() + return br.downloadBody(resp) + }) + } + }) - buffer := bytes.NewBuffer(data) - return buffer, fileSize, nil + return newChanMultiReader(readersCh), fileSize, nil } func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL string) (*http.Response, error) { @@ -156,24 +186,3 @@ func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL st return resp, nil } - -func (m *BufferMode) downloadChunk(resp *http.Response, dataSlice []byte) error { - defer resp.Body.Close() - expectedBytes := len(dataSlice) - n, err := io.ReadFull(resp.Body, dataSlice) - if err != nil && err != io.EOF { - return fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err) - } - if n != expectedBytes { - return fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String()) - } - return nil -} - -func (m *BufferMode) Fetch(ctx context.Context, url string) (result io.Reader, fileSize int64, err error) { - buffer, fileSize, err := m.fileToBuffer(ctx, url) - if err != nil { - return nil, 0, err - } - return buffer, fileSize, nil -} diff --git a/pkg/download/buffer_test.go b/pkg/download/buffer_test.go index 56b2947..da05e4b 100644 --- a/pkg/download/buffer_test.go +++ b/pkg/download/buffer_test.go @@ -18,13 +18,8 @@ func init() { var defaultOpts = download.Options{Client: client.Options{}} var http2Opts = download.Options{Client: client.Options{ForceHTTP2: true}} -func makeBufferMode(opts download.Options) *download.BufferMode { - client := client.NewHTTPClient(opts.Client) - - return &download.BufferMode{Client: client, Options: opts} -} func benchmarkDownloadURL(opts download.Options, url string, b *testing.B) { - bufferMode := makeBufferMode(opts) + bufferMode := download.GetBufferMode(opts) for n := 0; n < b.N; n++ { ctx, cancel := context.WithCancel(context.Background()) diff --git a/pkg/download/buffer_unit_test.go b/pkg/download/buffer_unit_test.go index dde7ac0..e9b7d00 100644 --- a/pkg/download/buffer_unit_test.go +++ b/pkg/download/buffer_unit_test.go @@ -13,6 +13,7 @@ import ( "github.com/dustin/go-humanize" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/replicate/pget/pkg/client" ) @@ -109,10 +110,10 @@ func TestFileToBufferChunkCountExceedsMaxChunks(t *testing.T) { t.Run(tc.name, func(t *testing.T) { opts.MaxConcurrency = tc.maxConcurrency opts.MinChunkSize = tc.minChunkSize - bufferMode := makeBufferMode(opts) + bufferMode := GetBufferMode(opts) path, _ := url.JoinPath(server.URL, testFilePath) download, size, err := bufferMode.Fetch(context.Background(), path) - assert.NoError(t, err) + require.NoError(t, err) data, err := 
diff --git a/pkg/download/buffered_reader.go b/pkg/download/buffered_reader.go
new file mode 100644
index 0000000..973eeea
--- /dev/null
+++ b/pkg/download/buffered_reader.go
@@ -0,0 +1,55 @@
+package download
+
+import (
+    "bytes"
+    "fmt"
+    "io"
+    "net/http"
+)
+
+// A bufferedReader wraps an http.Response.Body so that it can be eagerly
+// downloaded to a buffer before the actual io.Reader consumer can read it.
+// It implements io.Reader.
+type bufferedReader struct {
+    // ready channel is closed when we're ready to read
+    ready chan struct{}
+    buf   *bytes.Buffer
+    err   error
+}
+
+var _ io.Reader = &bufferedReader{}
+
+func newBufferedReader(capacity int64) *bufferedReader {
+    return &bufferedReader{
+        ready: make(chan struct{}),
+        buf:   bytes.NewBuffer(make([]byte, 0, capacity)),
+    }
+}
+
+// Read implements io.Reader. It will block until the full body is available for
+// reading.
+func (b *bufferedReader) Read(buf []byte) (int, error) {
+    <-b.ready
+    if b.err != nil {
+        return 0, b.err
+    }
+    return b.buf.Read(buf)
+}
+
+func (b *bufferedReader) done() {
+    close(b.ready)
+}
+
+func (b *bufferedReader) downloadBody(resp *http.Response) error {
+    expectedBytes := resp.ContentLength
+    n, err := b.buf.ReadFrom(resp.Body)
+    if err != nil && err != io.EOF {
+        b.err = fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err)
+        return b.err
+    }
+    if n != expectedBytes {
+        b.err = fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String())
+        return b.err
+    }
+    return nil
+}
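
The intended lifecycle, as a small in-package sketch (spawnBufferedDownload
is a hypothetical helper, not part of this patch): a producer goroutine fills
the buffer and then signals done(); Read() blocks on the ready channel until
then, and an error recorded by downloadBody() is surfaced on the first Read():

    func spawnBufferedDownload(resp *http.Response) io.Reader {
        br := newBufferedReader(resp.ContentLength)
        go func() {
            defer br.done() // unblocks any waiting Read, success or failure
            defer resp.Body.Close()
            _ = br.downloadBody(resp) // a failure here is returned by Read()
        }()
        return br // safe to hand to a consumer immediately
    }
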
diff --git a/pkg/download/chan_multi_reader.go b/pkg/download/chan_multi_reader.go
new file mode 100644
index 0000000..56efc8f
--- /dev/null
+++ b/pkg/download/chan_multi_reader.go
@@ -0,0 +1,42 @@
+package download
+
+import "io"
+
+type chanMultiReader struct {
+    ch  <-chan io.Reader
+    cur io.Reader
+}
+
+var _ io.Reader = &chanMultiReader{}
+
+func newChanMultiReader(ch <-chan io.Reader) *chanMultiReader {
+    return &chanMultiReader{ch: ch}
+}
+
+func (c *chanMultiReader) Read(p []byte) (n int, err error) {
+    for {
+        if c.cur == nil {
+            var ok bool
+            c.cur, ok = <-c.ch
+            if !ok {
+                // no more readers; return EOF
+                return 0, io.EOF
+            }
+        }
+        n, err = c.cur.Read(p)
+        if err == io.EOF {
+            c.cur = nil
+        }
+        if n > 0 || err != io.EOF {
+            // we either made progress or hit an error, return to the caller
+            if err == io.EOF {
+                // TODO: we could eagerly check to see if the channel is closed
+                // and return EOF one call early
+                err = nil
+            }
+            return
+        }
+        // n == 0, err == EOF; this reader is done and we need to start the next
+        c.cur = nil
+    }
+}
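
A test-style sketch of the semantics (hypothetical test in the download
package, assuming io, strings, testing and testify's require are imported):
readers are drained in the order they arrive on the channel, and EOF is
reported only after the channel is closed:

    func TestChanMultiReaderConcatenates(t *testing.T) {
        ch := make(chan io.Reader, 2)
        ch <- strings.NewReader("hello ")
        ch <- strings.NewReader("world")
        close(ch) // without this, Read would block waiting for more readers

        out, err := io.ReadAll(newChanMultiReader(ch))
        require.NoError(t, err)
        require.Equal(t, "hello world", string(out))
    }
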
diff --git a/pkg/download/consistent_hashing.go b/pkg/download/consistent_hashing.go
index b17ae89..4b649ae 100644
--- a/pkg/download/consistent_hashing.go
+++ b/pkg/download/consistent_hashing.go
@@ -1,14 +1,12 @@
 package download
 
 import (
-    "bytes"
     "context"
     "errors"
     "fmt"
     "io"
     "net/http"
     "net/url"
-    "runtime"
     "strconv"
 
     jump "github.com/dgryski/go-jump"
@@ -25,6 +23,10 @@ type ConsistentHashingMode struct {
     Options
     // TODO: allow this to be configured and not just "BufferMode"
     FallbackStrategy Strategy
+
+    // we use this errgroup as a semaphore (via sem.SetLimit())
+    sem   *errgroup.Group
+    queue *workQueue
 }
 
 type CacheKey struct {
@@ -32,32 +34,32 @@ type CacheKey struct {
     URL   *url.URL
     Slice int64
 }
 
-func GetConsistentHashingMode(opts Options) (Strategy, error) {
+func GetConsistentHashingMode(opts Options) (*ConsistentHashingMode, error) {
     if opts.SliceSize == 0 {
         return nil, fmt.Errorf("must specify slice size in consistent hashing mode")
     }
-    if opts.Semaphore != nil && opts.MaxConcurrency == 0 {
-        return nil, fmt.Errorf("if you provide a semaphore you must specify MaxConcurrency")
-    }
     client := client.NewHTTPClient(opts.Client)
-    fallbackStrategy := GetBufferMode(opts)
-    fallbackStrategy.Client = client
+    sem := new(errgroup.Group)
+    sem.SetLimit(opts.maxConcurrency())
+    queue := newWorkQueue(opts.maxConcurrency())
+    queue.start()
+
+    fallbackStrategy := &BufferMode{
+        Client:  client,
+        Options: opts,
+        sem:     sem,
+        queue:   queue,
+    }
 
     return &ConsistentHashingMode{
         Client:           client,
         Options:          opts,
         FallbackStrategy: fallbackStrategy,
+        sem:              sem,
+        queue:            queue,
     }, nil
 }
 
-func (m *ConsistentHashingMode) maxConcurrency() int {
-    maxChunks := m.MaxConcurrency
-    if maxChunks == 0 {
-        return runtime.NumCPU() * 4
-    }
-    return maxChunks
-}
-
 func (m *ConsistentHashingMode) minChunkSize() int64 {
     minChunkSize := m.MinChunkSize
     if minChunkSize == 0 {
@@ -100,34 +102,47 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
         return m.FallbackStrategy.Fetch(ctx, urlString)
     }
 
-    firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, urlString)
-    if err != nil {
+    br := newBufferedReader(m.minChunkSize())
+    firstReqResultCh := make(chan firstReqResult)
+    m.queue.submit(func() {
+        m.sem.Go(func() error {
+            defer close(firstReqResultCh)
+            defer br.done()
+            firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, urlString)
+            if err != nil {
+                firstReqResultCh <- firstReqResult{err: err}
+                return err
+            }
+            defer firstChunkResp.Body.Close()
+
+            fileSize, err := m.getFileSizeFromContentRange(firstChunkResp.Header.Get("Content-Range"))
+            if err != nil {
+                firstReqResultCh <- firstReqResult{err: err}
+                return err
+            }
+            firstReqResultCh <- firstReqResult{fileSize: fileSize}
+
+            return br.downloadBody(firstChunkResp)
+        })
+    })
+    firstReqResult, ok := <-firstReqResultCh
+    if !ok {
+        panic("logic error in ConsistentHashingMode: first request didn't return any output")
+    }
+    if firstReqResult.err != nil {
         // In the case that an error indicating an issue with the cache server, networking, etc is returned,
         // this will use the fallback strategy. This is a case where the whole file will use the fallback
         // strategy.
-        if errors.Is(err, client.ErrStrategyFallback) {
+        if errors.Is(firstReqResult.err, client.ErrStrategyFallback) {
             return m.FallbackStrategy.Fetch(ctx, urlString)
         }
-        return nil, -1, err
-    }
-
-    fileSize, err := m.getFileSizeFromContentRange(firstChunkResp.Header.Get("Content-Range"))
-    if err != nil {
-        firstChunkResp.Body.Close()
-        return nil, -1, err
+        return nil, -1, firstReqResult.err
     }
+    fileSize := firstReqResult.fileSize
 
-    data := make([]byte, fileSize)
     if fileSize <= m.minChunkSize() {
         // we only need a single chunk: just download it and finish
-        err = m.downloadChunk(firstChunkResp, data)
-        if err != nil {
-            return nil, -1, err
-        }
-        // TODO: rather than eagerly downloading here, we could return
-        // an io.ReadCloser that downloads the file and releases the
-        // semaphore when closed
-        return bytes.NewBuffer(data), fileSize, nil
+        return br, fileSize, nil
    }
 
     totalSlices := fileSize / m.SliceSize
@@ -135,11 +150,6 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
         totalSlices++
     }
 
-    errGroup, ctx := errgroup.WithContext(ctx)
-    errGroup.Go(func() error {
-        return m.downloadChunk(firstChunkResp, data[0:m.minChunkSize()])
-    })
-
     // we subtract one because we've already got firstChunkResp in flight
     concurrency := m.maxConcurrency() - 1
     if concurrency <= int(totalSlices) {
@@ -154,80 +164,77 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
         chunksPerSlice = append([]int64{0}, EqualSplit(int64(concurrency), totalSlices-1)...)
     }
 
+    readersCh := make(chan io.Reader, m.maxConcurrency()+1)
+    readersCh <- br
+
     logger.Debug().Str("url", urlString).
         Int64("size", fileSize).
         Int("concurrency", m.maxConcurrency()).
         Ints64("chunks_per_slice", chunksPerSlice).
         Msg("Downloading")
 
-    for slice, numChunks := range chunksPerSlice {
-        if numChunks == 0 {
-            // this happens if we've already downloaded the whole first slice
-            continue
-        }
-        startFrom := m.SliceSize * int64(slice)
-        sliceSize := m.SliceSize
-        if slice == 0 {
-            startFrom = firstChunkResp.ContentLength
-            sliceSize = sliceSize - firstChunkResp.ContentLength
-        }
-        if slice == int(totalSlices)-1 {
-            sliceSize = (fileSize-1)%m.SliceSize + 1
-        }
-        if sliceSize/numChunks < m.minChunkSize() {
-            // reset numChunks to respect minChunkSize
-            numChunks = sliceSize / m.minChunkSize()
-            // although we must always have at least one chunk
+    m.queue.submit(func() {
+        defer close(readersCh)
+        for slice, numChunks := range chunksPerSlice {
             if numChunks == 0 {
-                numChunks = 1
+                // this happens if we've already downloaded the whole first slice
+                continue
             }
-        }
-        chunkSizes := EqualSplit(sliceSize, numChunks)
-        for _, chunkSize := range chunkSizes {
-            // startFrom changes each time round the loop
-            // we create chunkStart to be a stable variable for the goroutine to capture
-            chunkStart := startFrom
-            chunkEnd := startFrom + chunkSize - 1
-
-            dataSlice := data[chunkStart : chunkEnd+1]
-            errGroup.Go(func() error {
-                logger.Debug().Int64("start", chunkStart).Int64("end", chunkEnd).Msg("starting request")
-                resp, err := m.DoRequest(ctx, chunkStart, chunkEnd, urlString)
-                if err != nil {
-                    // In the case that an error indicating an issue with the cache server, networking, etc is returned,
-                    // this will use the fallback strategy. This is a case where the whole file will perform the fall-back
-                    // for the specified chunk instead of the whole file.
-                    if errors.Is(err, client.ErrStrategyFallback) {
-                        resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, urlString)
-                    }
+            startFrom := m.SliceSize * int64(slice)
+            sliceSize := m.SliceSize
+            if slice == 0 {
+                startFrom = m.minChunkSize()
+                sliceSize = sliceSize - m.minChunkSize()
+            }
+            if slice == int(totalSlices)-1 {
+                sliceSize = (fileSize-1)%m.SliceSize + 1
+            }
+            if sliceSize/numChunks < m.minChunkSize() {
+                // reset numChunks to respect minChunkSize
+                numChunks = sliceSize / m.minChunkSize()
+                // although we must always have at least one chunk
+                if numChunks == 0 {
+                    numChunks = 1
+                }
+            }
+            chunkSizes := EqualSplit(sliceSize, numChunks)
+            for _, chunkSize := range chunkSizes {
+                // startFrom changes each time round the loop
+                // we create chunkStart to be a stable variable for the goroutine to capture
+                chunkStart := startFrom
+                chunkEnd := startFrom + chunkSize - 1
+
+                br := newBufferedReader(m.minChunkSize())
+                readersCh <- br
+                m.sem.Go(func() error {
+                    defer br.done()
+                    logger.Debug().Int64("start", chunkStart).Int64("end", chunkEnd).Msg("starting request")
+                    resp, err := m.DoRequest(ctx, chunkStart, chunkEnd, urlString)
                     if err != nil {
-                        return err
+                        // in the case that an error indicating an issue with the cache server, networking, etc is returned,
+                        // this will use the fallback strategy. This is a case where the whole file will perform the fall-back
+                        // for the specified chunk instead of the whole file.
+                        if errors.Is(err, client.ErrStrategyFallback) {
+                            resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, urlString)
+                        }
+                        if err != nil {
+                            return err
+                        }
                     }
-                }
+                    defer resp.Body.Close()
+                    return br.downloadBody(resp)
+                })
 
-                return m.downloadChunk(resp, dataSlice)
-            })
-
-            startFrom = startFrom + chunkSize
+                startFrom = startFrom + chunkSize
+            }
         }
-    }
-
-    if err := errGroup.Wait(); err != nil {
-        return nil, -1, err // return the first error we encounter
-    }
+    })
 
-    buffer := bytes.NewBuffer(data)
-    return buffer, fileSize, nil
+    return newChanMultiReader(readersCh), fileSize, nil
 }
 
 func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, urlString string) (*http.Response, error) {
     logger := logging.GetLogger()
-    if m.Semaphore != nil {
-        err := m.Semaphore.Acquire(ctx, 1)
-        if err != nil {
-            return nil, err
-        }
-    }
     chContext := context.WithValue(ctx, config.ConsistentHashingStrategyKey, true)
     req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil)
     if err != nil {
@@ -288,21 +295,3 @@ func (m *ConsistentHashingMode) consistentHashIfNeeded(req *http.Request, start
     }
     return nil
 }
-
-func (m *ConsistentHashingMode) downloadChunk(resp *http.Response, dataSlice []byte) error {
-    logger := logging.GetLogger()
-    defer resp.Body.Close()
-    if m.Semaphore != nil {
-        defer m.Semaphore.Release(1)
-    }
-    expectedBytes := len(dataSlice)
-    n, err := io.ReadFull(resp.Body, dataSlice)
-    if err != nil && err != io.EOF {
-        return fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err)
-    }
-    if n != expectedBytes {
-        return fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String())
-    }
-    logger.Debug().Int("size", len(dataSlice)).Int("downloaded", n).Msg("downloaded chunk")
-    return nil
-}
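
To make the slice arithmetic above easier to follow, here is a small
self-contained sketch (hypothetical numbers, not code from this patch) of how
a file is carved into slices; the final slice's length falls out of the
(fileSize-1)%SliceSize + 1 formula used in Fetch:

    package main

    import "fmt"

    func main() {
        var fileSize, sliceSize int64 = 10, 4

        // same computation as ConsistentHashingMode.Fetch
        totalSlices := fileSize / sliceSize
        if fileSize%sliceSize != 0 {
            totalSlices++
        }

        for slice := int64(0); slice < totalSlices; slice++ {
            start := slice * sliceSize
            size := sliceSize
            if slice == totalSlices-1 {
                size = (fileSize-1)%sliceSize + 1 // last slice may be short
            }
            fmt.Printf("slice %d: bytes %d-%d\n", slice, start, start+size-1)
        }
        // prints: slice 0: bytes 0-3, slice 1: bytes 4-7, slice 2: bytes 8-9
    }
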
"github.com/stretchr/testify/require" - "golang.org/x/sync/semaphore" "github.com/replicate/pget/pkg/client" "github.com/replicate/pget/pkg/download" @@ -30,13 +29,6 @@ var testFSes = []fstest.MapFS{ {"hello.txt": {Data: []byte("7777777777777777")}}, } -func makeConsistentHashingMode(opts download.Options) *download.ConsistentHashingMode { - client := client.NewHTTPClient(opts.Client) - fallbackMode := download.BufferMode{Options: opts, Client: client} - - return &download.ConsistentHashingMode{Client: client, Options: opts, FallbackStrategy: &fallbackMode} -} - type chTestCase struct { name string concurrency int @@ -185,7 +177,6 @@ func TestConsistentHashing(t *testing.T) { Client: client.Options{}, MaxConcurrency: tc.concurrency, MinChunkSize: tc.minChunkSize, - Semaphore: semaphore.NewWeighted(int64(tc.concurrency)), CacheHosts: hostnames[0:tc.numCacheHosts], DomainsToCache: []string{"test.replicate.delivery"}, SliceSize: tc.sliceSize, @@ -194,7 +185,8 @@ func TestConsistentHashing(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - strategy := makeConsistentHashingMode(opts) + strategy, err := download.GetConsistentHashingMode(opts) + assert.NoError(t, err) reader, _, err := strategy.Fetch(ctx, "http://test.replicate.delivery/hello.txt") assert.NoError(t, err) @@ -214,7 +206,6 @@ func TestConsistentHashingHasFallback(t *testing.T) { Client: client.Options{}, MaxConcurrency: 8, MinChunkSize: 2, - Semaphore: semaphore.NewWeighted(8), CacheHosts: []string{}, DomainsToCache: []string{"fake.replicate.delivery"}, SliceSize: 3, @@ -223,12 +214,13 @@ func TestConsistentHashingHasFallback(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - strategy := makeConsistentHashingMode(opts) + strategy, err := download.GetConsistentHashingMode(opts) + require.NoError(t, err) urlString, err := url.JoinPath(server.URL, "hello.txt") - assert.NoError(t, err) + require.NoError(t, err) reader, _, err := strategy.Fetch(ctx, urlString) - assert.NoError(t, err) + require.NoError(t, err) bytes, err := io.ReadAll(reader) assert.NoError(t, err) @@ -263,7 +255,14 @@ func (s *testStrategy) DoRequest(ctx context.Context, start, end int64, url stri s.mut.Lock() s.doRequestCalledCount++ s.mut.Unlock() - resp := &http.Response{Body: io.NopCloser(strings.NewReader("00"))} + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + resp := &http.Response{ + Request: req, + Body: io.NopCloser(strings.NewReader("00")), + } return resp, nil } @@ -302,7 +301,6 @@ func TestConsistentHashingFileFallback(t *testing.T) { Client: client.Options{}, MaxConcurrency: 8, MinChunkSize: 2, - Semaphore: semaphore.NewWeighted(8), CacheHosts: []string{url.Host}, DomainsToCache: []string{"fake.replicate.delivery"}, SliceSize: 3, @@ -311,12 +309,14 @@ func TestConsistentHashingFileFallback(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - strategy := makeConsistentHashingMode(opts) + strategy, err := download.GetConsistentHashingMode(opts) + assert.NoError(t, err) + fallbackStrategy := &testStrategy{} strategy.FallbackStrategy = fallbackStrategy urlString := "http://fake.replicate.delivery/hello.txt" - _, _, err := strategy.Fetch(ctx, urlString) + _, _, err = strategy.Fetch(ctx, urlString) if tc.expectedError != nil { assert.ErrorIs(t, err, tc.expectedError) } @@ -363,7 +363,6 @@ func TestConsistentHashingChunkFallback(t *testing.T) { Client: client.Options{}, MaxConcurrency: 8, MinChunkSize: 
-        Semaphore:      semaphore.NewWeighted(8),
         CacheHosts:     []string{url.Host},
         DomainsToCache: []string{"fake.replicate.delivery"},
         SliceSize:      3,
@@ -372,13 +371,20 @@ func TestConsistentHashingChunkFallback(t *testing.T) {
     ctx, cancel := context.WithCancel(context.Background())
     defer cancel()
 
-    strategy := makeConsistentHashingMode(opts)
+    strategy, err := download.GetConsistentHashingMode(opts)
+    assert.NoError(t, err)
+
     fallbackStrategy := &testStrategy{}
     strategy.FallbackStrategy = fallbackStrategy
 
     urlString := "http://fake.replicate.delivery/hello.txt"
-    _, _, err := strategy.Fetch(ctx, urlString)
+    out, _, err := strategy.Fetch(ctx, urlString)
     assert.ErrorIs(t, err, tc.expectedError)
+    if err == nil {
+        // eagerly read the whole output reader to force all the
+        // requests to be completed
+        _, _ = io.Copy(io.Discard, out)
+    }
     assert.Equal(t, tc.fetchCalledCount, fallbackStrategy.fetchCalledCount)
     assert.Equal(t, tc.doRequestCalledCount, fallbackStrategy.doRequestCalledCount)
     })
diff --git a/pkg/download/options.go b/pkg/download/options.go
index a9b29ce..a3baa66 100644
--- a/pkg/download/options.go
+++ b/pkg/download/options.go
@@ -1,7 +1,7 @@
 package download
 
 import (
-    "golang.org/x/sync/semaphore"
+    "runtime"
 
     "github.com/replicate/pget/pkg/client"
 )
@@ -29,8 +29,12 @@ type Options struct {
     // hashing algorithm. The slice may contain empty entries which
     // correspond to a cache host which is currently unavailable.
     CacheHosts []string
+}
 
-    // Semaphore is used to manage maximum concurrency. If nil, concurrency
-    // is unlimited.
-    Semaphore *semaphore.Weighted
+func (o *Options) maxConcurrency() int {
+    maxChunks := o.MaxConcurrency
+    if maxChunks == 0 {
+        return runtime.NumCPU() * 4
+    }
+    return maxChunks
 }
diff --git a/pkg/download/work_queue.go b/pkg/download/work_queue.go
new file mode 100644
index 0000000..b181f2c
--- /dev/null
+++ b/pkg/download/work_queue.go
@@ -0,0 +1,26 @@
+package download
+
+// workQueue takes work items and executes them serially, in strict FIFO order
+type workQueue struct {
+    queue chan work
+}
+
+type work func()
+
+func newWorkQueue(depth int) *workQueue {
+    return &workQueue{queue: make(chan work, depth)}
+}
+
+func (q *workQueue) submit(w work) {
+    q.queue <- w
+}
+
+func (q *workQueue) start() {
+    go q.run()
+}
+
+func (q *workQueue) run() {
+    for item := range q.queue {
+        item()
+    }
+}
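
A sketch of the FIFO guarantee (hypothetical example function in the download
package, assuming fmt is imported): submitted items run one at a time in
submission order, which is what lets Fetch hand out readers in the same order
the requests were queued:

    func ExampleWorkQueue() {
        q := newWorkQueue(2) // queue depth; submit blocks when the buffer is full
        q.start()

        done := make(chan struct{})
        q.submit(func() { fmt.Println("first") })
        q.submit(func() {
            fmt.Println("second") // strict FIFO: always runs after "first"
            close(done)
        })
        <-done

        // Output:
        // first
        // second
    }
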
diff --git a/pkg/pget.go b/pkg/pget.go
index 061d4fc..14fc1df 100644
--- a/pkg/pget.go
+++ b/pkg/pget.go
@@ -17,6 +17,11 @@ import (
 type Getter struct {
     Downloader download.Strategy
     Consumer   consumer.Consumer
+    Options    Options
+}
+
+type Options struct {
+    MaxConcurrentFiles int
 }
 
 type ManifestEntry struct {
@@ -39,26 +44,27 @@ func (g *Getter) DownloadFile(ctx context.Context, url string, dest string) (int
     if err != nil {
         return fileSize, 0, err
     }
-    downloadElapsed := time.Since(downloadStartTime)
-    writeStartTime := time.Now()
+    // downloadElapsed := time.Since(downloadStartTime)
+    // writeStartTime := time.Now()
 
     err = g.Consumer.Consume(buffer, dest)
     if err != nil {
         return fileSize, 0, fmt.Errorf("error writing file: %w", err)
     }
-    writeElapsed := time.Since(writeStartTime)
+    // writeElapsed := time.Since(writeStartTime)
     totalElapsed := time.Since(downloadStartTime)
 
     size := humanize.Bytes(uint64(fileSize))
-    downloadThroughput := humanize.Bytes(uint64(float64(fileSize) / downloadElapsed.Seconds()))
-    writeThroughput := humanize.Bytes(uint64(float64(fileSize) / writeElapsed.Seconds()))
+    // downloadThroughput := humanize.Bytes(uint64(float64(fileSize) / downloadElapsed.Seconds()))
+    // writeThroughput := humanize.Bytes(uint64(float64(fileSize) / writeElapsed.Seconds()))
     logger.Info().
         Str("dest", dest).
+        Str("url", url).
         Str("size", size).
-        Str("download_throughput", fmt.Sprintf("%s/s", downloadThroughput)).
-        Str("download_elapsed", fmt.Sprintf("%.3fs", downloadElapsed.Seconds())).
-        Str("write_throughput", fmt.Sprintf("%s/s", writeThroughput)).
-        Str("write_elapsed", fmt.Sprintf("%.3fs", writeElapsed.Seconds())).
+        // Str("download_throughput", fmt.Sprintf("%s/s", downloadThroughput)).
+        // Str("download_elapsed", fmt.Sprintf("%.3fs", downloadElapsed.Seconds())).
+        // Str("write_throughput", fmt.Sprintf("%s/s", writeThroughput)).
+        // Str("write_elapsed", fmt.Sprintf("%.3fs", writeElapsed.Seconds())).
         Str("total_elapsed", fmt.Sprintf("%.3fs", totalElapsed.Seconds())).
         Msg("Complete")
     return fileSize, totalElapsed, nil
@@ -71,6 +77,10 @@ func (g *Getter) DownloadFiles(ctx context.Context, manifest Manifest) (int64, t
 
     errGroup, ctx := errgroup.WithContext(ctx)
 
+    if g.Options.MaxConcurrentFiles != 0 {
+        errGroup.SetLimit(g.Options.MaxConcurrentFiles)
+    }
+
     totalSize := new(atomic.Int64)
     multifileDownloadStart := time.Now()
diff --git a/pkg/pget_test.go b/pkg/pget_test.go
index f33ac0c..59d3913 100644
--- a/pkg/pget_test.go
+++ b/pkg/pget_test.go
@@ -36,10 +36,8 @@ var defaultOpts = download.Options{Client: client.Options{}}
 var http2Opts = download.Options{Client: client.Options{ForceHTTP2: true}}
 
 func makeGetter(opts download.Options) *pget.Getter {
-    client := client.NewHTTPClient(opts.Client)
-
     return &pget.Getter{
-        Downloader: &download.BufferMode{Client: client, Options: opts},
+        Downloader: download.GetBufferMode(opts),
     }
 }
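
Finally, a minimal sketch of what the per-file limit achieves (hypothetical
helper; the URL and Dest fields on ManifestEntry are assumptions, since the
patch does not show them): errgroup's SetLimit caps how many files are in
flight at once, while each file still fans out into chunk requests under the
download package's own MaxConcurrency semaphore.

    func downloadAll(ctx context.Context, getter *pget.Getter, entries []pget.ManifestEntry) error {
        eg, ctx := errgroup.WithContext(ctx)
        eg.SetLimit(20) // the default applied by maxConcurrentFiles()
        for _, entry := range entries {
            entry := entry // capture the loop variable for the goroutine
            eg.Go(func() error {
                _, _, err := getter.DownloadFile(ctx, entry.URL, entry.Dest)
                return err
            })
        }
        return eg.Wait()
    }
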