diff --git a/cmd/multifile/multifile.go b/cmd/multifile/multifile.go
index 0531e66..054dc44 100644
--- a/cmd/multifile/multifile.go
+++ b/cmd/multifile/multifile.go
@@ -88,6 +88,14 @@ func runMultifileCMD(cmd *cobra.Command, args []string) error {
 	return multifileExecute(cmd.Context(), manifest)
 }
 
+func maxConcurrentFiles() int {
+	maxConcurrentFiles := viper.GetInt(config.OptMaxConcurrentFiles)
+	if maxConcurrentFiles == 0 {
+		maxConcurrentFiles = 20
+	}
+	return maxConcurrentFiles
+}
+
 func multifileExecute(ctx context.Context, manifest pget.Manifest) error {
 	minChunkSize, err := humanize.ParseBytes(viper.GetString(config.OptMinimumChunkSize))
 	if err != nil {
@@ -112,6 +120,9 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error {
 		MinChunkSize: int64(minChunkSize),
 		Client:       clientOpts,
 	}
+	pgetOpts := pget.Options{
+		MaxConcurrentFiles: maxConcurrentFiles(),
+	}
 
 	consumer, err := config.GetConsumer()
 	if err != nil {
@@ -121,6 +132,7 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error {
 	getter := &pget.Getter{
 		Downloader: download.GetBufferMode(downloadOpts),
 		Consumer:   consumer,
+		Options:    pgetOpts,
 	}
 
 	// TODO DRY this
diff --git a/cmd/root/root.go b/cmd/root/root.go
index f1b611b..593f0f9 100644
--- a/cmd/root/root.go
+++ b/cmd/root/root.go
@@ -11,7 +11,6 @@ import (
 	"github.com/rs/zerolog/log"
 	"github.com/spf13/cobra"
 	"github.com/spf13/viper"
-	"golang.org/x/sync/semaphore"
 
 	pget "github.com/replicate/pget/pkg"
 	"github.com/replicate/pget/pkg/cli"
@@ -174,7 +173,6 @@ func rootExecute(ctx context.Context, urlString, dest string) error {
 		MaxConcurrency: viper.GetInt(config.OptConcurrency),
 		MinChunkSize:   int64(minChunkSize),
 		Client:         clientOpts,
-		Semaphore:      semaphore.NewWeighted(int64(viper.GetInt(config.OptConcurrency))),
 	}
 
 	consumer, err := config.GetConsumer()
diff --git a/pkg/config/optnames.go b/pkg/config/optnames.go
index 0e3da99..a1756ba 100644
--- a/pkg/config/optnames.go
+++ b/pkg/config/optnames.go
@@ -6,18 +6,19 @@ const (
 	OptCacheNodesSRVNameByHostCIDR = "cache-nodes-srv-name-by-host-cidr"
 	OptHostIP                      = "host-ip"
 
-	OptCacheNodesSRVName = "cache-nodes-srv-name"
-	OptConcurrency       = "concurrency"
-	OptConnTimeout       = "connect-timeout"
-	OptExtract           = "extract"
-	OptForce             = "force"
-	OptForceHTTP2        = "force-http2"
-	OptLoggingLevel      = "log-level"
-	OptMaxChunks         = "max-chunks"
-	OptMaxConnPerHost    = "max-conn-per-host"
-	OptMinimumChunkSize  = "minimum-chunk-size"
-	OptOutputConsumer    = "output"
-	OptResolve           = "resolve"
-	OptRetries           = "retries"
-	OptVerbose           = "verbose"
+	OptCacheNodesSRVName  = "cache-nodes-srv-name"
+	OptConcurrency        = "concurrency"
+	OptConnTimeout        = "connect-timeout"
+	OptExtract            = "extract"
+	OptForce              = "force"
+	OptForceHTTP2         = "force-http2"
+	OptLoggingLevel       = "log-level"
+	OptMaxChunks          = "max-chunks"
+	OptMaxConnPerHost     = "max-conn-per-host"
+	OptMaxConcurrentFiles = "max-concurrent-files"
+	OptMinimumChunkSize   = "minimum-chunk-size"
+	OptOutputConsumer     = "output"
+	OptResolve            = "resolve"
+	OptRetries            = "retries"
+	OptVerbose            = "verbose"
 )
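The new `max-concurrent-files` option name ships in this changeset, but its cobra/viper flag registration does not. A hedged sketch of how the flag could be wired up, assuming a helper in the multifile command package (the function name and placement are illustrative only, not part of this PR):

```go
package multifile

import (
	"github.com/spf13/cobra"
	"github.com/spf13/viper"

	"github.com/replicate/pget/pkg/config"
)

// addMaxConcurrentFilesFlag is a hypothetical helper showing how the new
// option might be registered; pget's real flag wiring lives elsewhere.
func addMaxConcurrentFilesFlag(cmd *cobra.Command) error {
	cmd.PersistentFlags().Int(config.OptMaxConcurrentFiles, 20,
		"maximum number of files to download concurrently")
	// Bind the flag so viper.GetInt(config.OptMaxConcurrentFiles) sees it.
	return viper.BindPFlag(config.OptMaxConcurrentFiles,
		cmd.PersistentFlags().Lookup(config.OptMaxConcurrentFiles))
}
```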
diff --git a/pkg/download/buffer.go b/pkg/download/buffer.go
index 936f14c..97459ac 100644
--- a/pkg/download/buffer.go
+++ b/pkg/download/buffer.go
@@ -1,13 +1,11 @@
 package download
 
 import (
-	"bytes"
 	"context"
 	"fmt"
 	"io"
 	"net/http"
 	"regexp"
-	"runtime"
 	"strconv"
 
 	"golang.org/x/sync/errgroup"
@@ -25,24 +23,26 @@ var contentRangeRegexp = regexp.MustCompile(`^bytes .*/([0-9]+)$`)
 type BufferMode struct {
 	Client *client.HTTPClient
 	Options
+
+	// we use this errgroup as a semaphore (via sem.SetLimit())
+	sem   *errgroup.Group
+	queue *workQueue
 }
 
 func GetBufferMode(opts Options) *BufferMode {
 	client := client.NewHTTPClient(opts.Client)
+	sem := new(errgroup.Group)
+	sem.SetLimit(opts.maxConcurrency())
+	queue := newWorkQueue(opts.maxConcurrency())
+	queue.start()
 	return &BufferMode{
 		Client:  client,
 		Options: opts,
+		sem:     sem,
+		queue:   queue,
 	}
 }
 
-func (m *BufferMode) maxChunks() int {
-	maxChunks := m.MaxConcurrency
-	if maxChunks == 0 {
-		return runtime.NumCPU() * 4
-	}
-	return maxChunks
-}
-
 func (m *BufferMode) minChunkSize() int64 {
 	minChunkSize := m.MinChunkSize
 	if minChunkSize == 0 {
@@ -59,32 +59,61 @@ func (m *BufferMode) getFileSizeFromContentRange(contentRange string) (int64, er
 	return strconv.ParseInt(groups[1], 10, 64)
 }
 
-func (m *BufferMode) fileToBuffer(ctx context.Context, url string) (*bytes.Buffer, int64, error) {
+type firstReqResult struct {
+	fileSize int64
+	trueURL  string
+	err      error
+}
+
+func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, error) {
 	logger := logging.GetLogger()
 
-	firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, url)
-	if err != nil {
-		return nil, -1, err
-	}
-
-	trueURL := firstChunkResp.Request.URL.String()
-	if trueURL != url {
-		logger.Info().Str("url", url).Str("redirect_url", trueURL).Msg("Redirect")
+	br := newBufferedReader(m.minChunkSize())
+
+	firstReqResultCh := make(chan firstReqResult)
+	m.queue.submit(func() {
+		m.sem.Go(func() error {
+			defer close(firstReqResultCh)
+			defer br.done()
+			firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, url)
+			if err != nil {
+				firstReqResultCh <- firstReqResult{err: err}
+				return err
+			}
+
+			defer firstChunkResp.Body.Close()
+
+			trueURL := firstChunkResp.Request.URL.String()
+			if trueURL != url {
+				logger.Info().Str("url", url).Str("redirect_url", trueURL).Msg("Redirect")
+			}
+
+			fileSize, err := m.getFileSizeFromContentRange(firstChunkResp.Header.Get("Content-Range"))
+			if err != nil {
+				firstReqResultCh <- firstReqResult{err: err}
+				return err
+			}
+			firstReqResultCh <- firstReqResult{fileSize: fileSize, trueURL: trueURL}
+
+			return br.downloadBody(firstChunkResp)
+		})
+	})
+
+	firstReqResult, ok := <-firstReqResultCh
+	if !ok {
+		panic("logic error in BufferMode: first request didn't return any output")
 	}
 
-	fileSize, err := m.getFileSizeFromContentRange(firstChunkResp.Header.Get("Content-Range"))
-	if err != nil {
-		firstChunkResp.Body.Close()
-		return nil, -1, err
+	if firstReqResult.err != nil {
+		return nil, -1, firstReqResult.err
 	}
 
-	data := make([]byte, fileSize)
+	fileSize := firstReqResult.fileSize
+	trueURL := firstReqResult.trueURL
+
 	if fileSize <= m.minChunkSize() {
 		// we only need a single chunk: just download it and finish
-		err = m.downloadChunk(firstChunkResp, data[0:fileSize])
-		if err != nil {
-			return nil, -1, err
-		}
-		return bytes.NewBuffer(data), fileSize, nil
+		return br, fileSize, nil
 	}
 
 	remainingBytes := fileSize - m.minChunkSize()
@@ -93,51 +122,52 @@ func (m *BufferMode) fileToBuffer(ctx context.Context, url string) (*bytes.Buffe
 	if numChunks <= 0 {
 		numChunks = 1
 	}
-	if numChunks > m.maxChunks() {
-		numChunks = m.maxChunks()
+	if numChunks > m.maxConcurrency() {
+		numChunks = m.maxConcurrency()
 	}
 
+	readersCh := make(chan io.Reader, m.maxConcurrency()+1)
+	readersCh <- br
+
+	startOffset := m.minChunkSize()
+
 	chunkSize := remainingBytes / int64(numChunks)
 	if chunkSize < 0 {
-		firstChunkResp.Body.Close()
 		return nil, -1, fmt.Errorf("error: chunksize incorrect - result is negative, %d", chunkSize)
 	}
 
-	logger.Debug().Str("url", url).
-		Int64("size", fileSize).
-		Int("connections", numChunks).
-		Int64("chunkSize", chunkSize).
-		Msg("Downloading")
-
-	errGroup, ctx := errgroup.WithContext(ctx)
-	errGroup.Go(func() error {
-		return m.downloadChunk(firstChunkResp, data[0:m.minChunkSize()])
-	})
-
-	startOffset := m.minChunkSize()
+	m.queue.submit(func() {
+		defer close(readersCh)
+		logger.Debug().Str("url", url).
+			Int64("size", fileSize).
+			Int("connections", numChunks).
+			Int64("chunkSize", chunkSize).
+			Msg("Downloading")
 
-	for i := 0; i < numChunks; i++ {
-		start := startOffset + chunkSize*int64(i)
-		end := start + chunkSize - 1
+		for i := 0; i < numChunks; i++ {
+			start := startOffset + chunkSize*int64(i)
+			end := start + chunkSize - 1
 
-		if i == numChunks-1 {
-			end = fileSize - 1
-		}
-		errGroup.Go(func() error {
-			resp, err := m.DoRequest(ctx, start, end, trueURL)
-			if err != nil {
-				return err
+			if i == numChunks-1 {
+				end = fileSize - 1
 			}
-			return m.downloadChunk(resp, data[start:end+1])
-		})
-	}
 
-	if err := errGroup.Wait(); err != nil {
-		return nil, -1, err // return the first error we encounter
-	}
+			br := newBufferedReader(end - start + 1)
+			readersCh <- br
+
+			m.sem.Go(func() error {
+				defer br.done()
+				resp, err := m.DoRequest(ctx, start, end, trueURL)
+				if err != nil {
+					return err
+				}
+				defer resp.Body.Close()
+				return br.downloadBody(resp)
+			})
+		}
+	})
 
-	buffer := bytes.NewBuffer(data)
-	return buffer, fileSize, nil
+	return newChanMultiReader(readersCh), fileSize, nil
 }
 
 func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL string) (*http.Response, error) {
@@ -156,24 +186,3 @@ func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL st
 
 	return resp, nil
 }
-
-func (m *BufferMode) downloadChunk(resp *http.Response, dataSlice []byte) error {
-	defer resp.Body.Close()
-	expectedBytes := len(dataSlice)
-	n, err := io.ReadFull(resp.Body, dataSlice)
-	if err != nil && err != io.EOF {
-		return fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err)
-	}
-	if n != expectedBytes {
-		return fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String())
-	}
-	return nil
-}
-
-func (m *BufferMode) Fetch(ctx context.Context, url string) (result io.Reader, fileSize int64, err error) {
-	buffer, fileSize, err := m.fileToBuffer(ctx, url)
-	if err != nil {
-		return nil, 0, err
-	}
-	return buffer, fileSize, nil
-}
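Both download modes now throttle with an `errgroup.Group` used as a semaphore, as the new struct comment says: after `SetLimit(n)`, each `Go` call blocks until fewer than n functions are running. A standalone sketch of that behaviour, independent of pget:

```go
package main

import (
	"fmt"
	"time"

	"golang.org/x/sync/errgroup"
)

func main() {
	// With a limit set, an errgroup.Group acts as a semaphore: Go blocks
	// until the number of active goroutines drops below the limit.
	sem := new(errgroup.Group)
	sem.SetLimit(2)

	for i := 0; i < 5; i++ {
		i := i
		sem.Go(func() error {
			fmt.Println("start", i) // at most two of these run concurrently
			time.Sleep(100 * time.Millisecond)
			return nil
		})
	}
	// Wait is not needed for throttling alone, but it returns the first
	// non-nil error produced by any of the functions.
	_ = sem.Wait()
}
```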
logger.Debug().Str("url", url). - Int64("size", fileSize). - Int("connections", numChunks). - Int64("chunkSize", chunkSize). - Msg("Downloading") - - errGroup, ctx := errgroup.WithContext(ctx) - errGroup.Go(func() error { - return m.downloadChunk(firstChunkResp, data[0:m.minChunkSize()]) - }) - - startOffset := m.minChunkSize() + m.queue.submit(func() { + defer close(readersCh) + logger.Debug().Str("url", url). + Int64("size", fileSize). + Int("connections", numChunks). + Int64("chunkSize", chunkSize). + Msg("Downloading") - for i := 0; i < numChunks; i++ { - start := startOffset + chunkSize*int64(i) - end := start + chunkSize - 1 + for i := 0; i < numChunks; i++ { + start := startOffset + chunkSize*int64(i) + end := start + chunkSize - 1 - if i == numChunks-1 { - end = fileSize - 1 - } - errGroup.Go(func() error { - resp, err := m.DoRequest(ctx, start, end, trueURL) - if err != nil { - return err + if i == numChunks-1 { + end = fileSize - 1 } - return m.downloadChunk(resp, data[start:end+1]) - }) - } - if err := errGroup.Wait(); err != nil { - return nil, -1, err // return the first error we encounter - } + br := newBufferedReader(end - start + 1) + readersCh <- br + + m.sem.Go(func() error { + defer br.done() + resp, err := m.DoRequest(ctx, start, end, trueURL) + if err != nil { + return err + } + defer resp.Body.Close() + return br.downloadBody(resp) + }) + } + }) - buffer := bytes.NewBuffer(data) - return buffer, fileSize, nil + return newChanMultiReader(readersCh), fileSize, nil } func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL string) (*http.Response, error) { @@ -156,24 +186,3 @@ func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL st return resp, nil } - -func (m *BufferMode) downloadChunk(resp *http.Response, dataSlice []byte) error { - defer resp.Body.Close() - expectedBytes := len(dataSlice) - n, err := io.ReadFull(resp.Body, dataSlice) - if err != nil && err != io.EOF { - return fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err) - } - if n != expectedBytes { - return fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String()) - } - return nil -} - -func (m *BufferMode) Fetch(ctx context.Context, url string) (result io.Reader, fileSize int64, err error) { - buffer, fileSize, err := m.fileToBuffer(ctx, url) - if err != nil { - return nil, 0, err - } - return buffer, fileSize, nil -} diff --git a/pkg/download/buffer_test.go b/pkg/download/buffer_test.go index 56b2947..da05e4b 100644 --- a/pkg/download/buffer_test.go +++ b/pkg/download/buffer_test.go @@ -18,13 +18,8 @@ func init() { var defaultOpts = download.Options{Client: client.Options{}} var http2Opts = download.Options{Client: client.Options{ForceHTTP2: true}} -func makeBufferMode(opts download.Options) *download.BufferMode { - client := client.NewHTTPClient(opts.Client) - - return &download.BufferMode{Client: client, Options: opts} -} func benchmarkDownloadURL(opts download.Options, url string, b *testing.B) { - bufferMode := makeBufferMode(opts) + bufferMode := download.GetBufferMode(opts) for n := 0; n < b.N; n++ { ctx, cancel := context.WithCancel(context.Background()) diff --git a/pkg/download/buffer_unit_test.go b/pkg/download/buffer_unit_test.go index dde7ac0..e9b7d00 100644 --- a/pkg/download/buffer_unit_test.go +++ b/pkg/download/buffer_unit_test.go @@ -13,6 +13,7 @@ import ( "github.com/dustin/go-humanize" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" "github.com/replicate/pget/pkg/client" ) @@ -109,10 +110,10 @@ func TestFileToBufferChunkCountExceedsMaxChunks(t *testing.T) { t.Run(tc.name, func(t *testing.T) { opts.MaxConcurrency = tc.maxConcurrency opts.MinChunkSize = tc.minChunkSize - bufferMode := makeBufferMode(opts) + bufferMode := GetBufferMode(opts) path, _ := url.JoinPath(server.URL, testFilePath) download, size, err := bufferMode.Fetch(context.Background(), path) - assert.NoError(t, err) + require.NoError(t, err) data, err := io.ReadAll(download) assert.NoError(t, err) assert.Equal(t, contentSize, size) @@ -121,9 +122,3 @@ func TestFileToBufferChunkCountExceedsMaxChunks(t *testing.T) { }) } } - -func makeBufferMode(opts Options) *BufferMode { - client := client.NewHTTPClient(opts.Client) - - return &BufferMode{Client: client, Options: opts} -} diff --git a/pkg/download/buffered_reader.go b/pkg/download/buffered_reader.go new file mode 100644 index 0000000..973eeea --- /dev/null +++ b/pkg/download/buffered_reader.go @@ -0,0 +1,55 @@ +package download + +import ( + "bytes" + "fmt" + "io" + "net/http" +) + +// A bufferedReader wraps an http.Response.Body so that it can be eagerly +// downloaded to a buffer before the actual io.Reader consumer can read it. +// It implements io.Reader. +type bufferedReader struct { + // ready channel is closed when we're ready to read + ready chan struct{} + buf *bytes.Buffer + err error +} + +var _ io.Reader = &bufferedReader{} + +func newBufferedReader(capacity int64) *bufferedReader { + return &bufferedReader{ + ready: make(chan struct{}), + buf: bytes.NewBuffer(make([]byte, 0, capacity)), + } +} + +// Read implements io.Reader. It will block until the full body is available for +// reading. +func (b *bufferedReader) Read(buf []byte) (int, error) { + <-b.ready + if b.err != nil { + return 0, b.err + } + return b.buf.Read(buf) +} + +func (b *bufferedReader) done() { + close(b.ready) +} + +func (b *bufferedReader) downloadBody(resp *http.Response) error { + expectedBytes := resp.ContentLength + n, err := b.buf.ReadFrom(resp.Body) + if err != nil && err != io.EOF { + b.err = fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err) + return b.err + } + if n != expectedBytes { + b.err = fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String()) + return b.err + } + return nil +} diff --git a/pkg/download/chan_multi_reader.go b/pkg/download/chan_multi_reader.go new file mode 100644 index 0000000..56efc8f --- /dev/null +++ b/pkg/download/chan_multi_reader.go @@ -0,0 +1,42 @@ +package download + +import "io" + +type chanMultiReader struct { + ch <-chan io.Reader + cur io.Reader +} + +var _ io.Reader = &chanMultiReader{} + +func newChanMultiReader(ch <-chan io.Reader) *chanMultiReader { + return &chanMultiReader{ch: ch} +} + +func (c *chanMultiReader) Read(p []byte) (n int, err error) { + for { + if c.cur == nil { + var ok bool + c.cur, ok = <-c.ch + if !ok { + // no more readers; return EOF + return 0, io.EOF + } + } + n, err = c.cur.Read(p) + if err == io.EOF { + c.cur = nil + } + if n > 0 || err != io.EOF { + // we either made progress or hit an error, return to the caller + if err == io.EOF { + // TODO: we could eagerly check to see if the channel is closed + // and return EOF one call early + err = nil + } + return + } + // n == 0, err == EOF; this reader is done and we need to start the next + c.cur = nil + } +} diff --git a/pkg/download/consistent_hashing.go b/pkg/download/consistent_hashing.go 
diff --git a/pkg/download/consistent_hashing.go b/pkg/download/consistent_hashing.go
index b17ae89..4b649ae 100644
--- a/pkg/download/consistent_hashing.go
+++ b/pkg/download/consistent_hashing.go
@@ -1,14 +1,12 @@
 package download
 
 import (
-	"bytes"
 	"context"
 	"errors"
 	"fmt"
 	"io"
 	"net/http"
 	"net/url"
-	"runtime"
 	"strconv"
 
 	jump "github.com/dgryski/go-jump"
@@ -25,6 +23,10 @@ type ConsistentHashingMode struct {
 	Options
 	// TODO: allow this to be configured and not just "BufferMode"
 	FallbackStrategy Strategy
+
+	// we use this errgroup as a semaphore (via sem.SetLimit())
+	sem   *errgroup.Group
+	queue *workQueue
 }
 
 type CacheKey struct {
@@ -32,32 +34,32 @@ type CacheKey struct {
 	Slice int64
 }
 
-func GetConsistentHashingMode(opts Options) (Strategy, error) {
+func GetConsistentHashingMode(opts Options) (*ConsistentHashingMode, error) {
 	if opts.SliceSize == 0 {
 		return nil, fmt.Errorf("must specify slice size in consistent hashing mode")
 	}
-	if opts.Semaphore != nil && opts.MaxConcurrency == 0 {
-		return nil, fmt.Errorf("if you provide a semaphore you must specify MaxConcurrency")
-	}
 	client := client.NewHTTPClient(opts.Client)
-	fallbackStrategy := GetBufferMode(opts)
-	fallbackStrategy.Client = client
+	sem := new(errgroup.Group)
+	sem.SetLimit(opts.maxConcurrency())
+	queue := newWorkQueue(opts.maxConcurrency())
+	queue.start()
+
+	fallbackStrategy := &BufferMode{
+		Client:  client,
+		Options: opts,
+		sem:     sem,
+		queue:   queue,
+	}
 
 	return &ConsistentHashingMode{
 		Client:           client,
 		Options:          opts,
 		FallbackStrategy: fallbackStrategy,
+		sem:              sem,
+		queue:            queue,
 	}, nil
 }
 
-func (m *ConsistentHashingMode) maxConcurrency() int {
-	maxChunks := m.MaxConcurrency
-	if maxChunks == 0 {
-		return runtime.NumCPU() * 4
-	}
-	return maxChunks
-}
-
 func (m *ConsistentHashingMode) minChunkSize() int64 {
 	minChunkSize := m.MinChunkSize
 	if minChunkSize == 0 {
@@ -100,34 +102,47 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
 		return m.FallbackStrategy.Fetch(ctx, urlString)
 	}
 
-	firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, urlString)
-	if err != nil {
+	br := newBufferedReader(m.minChunkSize())
+	firstReqResultCh := make(chan firstReqResult)
+	m.queue.submit(func() {
+		m.sem.Go(func() error {
+			defer close(firstReqResultCh)
+			defer br.done()
+			firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, urlString)
+			if err != nil {
+				firstReqResultCh <- firstReqResult{err: err}
+				return err
+			}
+			defer firstChunkResp.Body.Close()
+
+			fileSize, err := m.getFileSizeFromContentRange(firstChunkResp.Header.Get("Content-Range"))
+			if err != nil {
+				firstReqResultCh <- firstReqResult{err: err}
+				return err
+			}
+			firstReqResultCh <- firstReqResult{fileSize: fileSize}
+
+			return br.downloadBody(firstChunkResp)
+		})
+	})
+	firstReqResult, ok := <-firstReqResultCh
+	if !ok {
+		panic("logic error in ConsistentHashingMode: first request didn't return any output")
+	}
+	if firstReqResult.err != nil {
 		// In the case that an error indicating an issue with the cache server, networking, etc is returned,
 		// this will use the fallback strategy. This is a case where the whole file will use the fallback
 		// strategy.
-		if errors.Is(err, client.ErrStrategyFallback) {
+		if errors.Is(firstReqResult.err, client.ErrStrategyFallback) {
 			return m.FallbackStrategy.Fetch(ctx, urlString)
 		}
-		return nil, -1, err
-	}
-
-	fileSize, err := m.getFileSizeFromContentRange(firstChunkResp.Header.Get("Content-Range"))
-	if err != nil {
-		firstChunkResp.Body.Close()
-		return nil, -1, err
+		return nil, -1, firstReqResult.err
 	}
+	fileSize := firstReqResult.fileSize
 
-	data := make([]byte, fileSize)
 	if fileSize <= m.minChunkSize() {
 		// we only need a single chunk: just download it and finish
-		err = m.downloadChunk(firstChunkResp, data)
-		if err != nil {
-			return nil, -1, err
-		}
-		// TODO: rather than eagerly downloading here, we could return
-		// an io.ReadCloser that downloads the file and releases the
-		// semaphore when closed
-		return bytes.NewBuffer(data), fileSize, nil
+		return br, fileSize, nil
 	}
 
 	totalSlices := fileSize / m.SliceSize
@@ -135,11 +150,6 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
 		totalSlices++
 	}
 
-	errGroup, ctx := errgroup.WithContext(ctx)
-	errGroup.Go(func() error {
-		return m.downloadChunk(firstChunkResp, data[0:m.minChunkSize()])
-	})
-
 	// we subtract one because we've already got firstChunkResp in flight
 	concurrency := m.maxConcurrency() - 1
 	if concurrency <= int(totalSlices) {
@@ -154,80 +164,77 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
 		chunksPerSlice = append([]int64{0}, EqualSplit(int64(concurrency), totalSlices-1)...)
 	}
 
+	readersCh := make(chan io.Reader, m.maxConcurrency()+1)
+	readersCh <- br
+
 	logger.Debug().Str("url", urlString).
 		Int64("size", fileSize).
 		Int("concurrency", m.maxConcurrency()).
 		Ints64("chunks_per_slice", chunksPerSlice).
 		Msg("Downloading")
 
-	for slice, numChunks := range chunksPerSlice {
-		if numChunks == 0 {
-			// this happens if we've already downloaded the whole first slice
-			continue
-		}
-		startFrom := m.SliceSize * int64(slice)
-		sliceSize := m.SliceSize
-		if slice == 0 {
-			startFrom = firstChunkResp.ContentLength
-			sliceSize = sliceSize - firstChunkResp.ContentLength
-		}
-		if slice == int(totalSlices)-1 {
-			sliceSize = (fileSize-1)%m.SliceSize + 1
-		}
-		if sliceSize/numChunks < m.minChunkSize() {
-			// reset numChunks to respect minChunkSize
-			numChunks = sliceSize / m.minChunkSize()
-			// although we must always have at least one chunk
+	m.queue.submit(func() {
+		defer close(readersCh)
+		for slice, numChunks := range chunksPerSlice {
 			if numChunks == 0 {
-				numChunks = 1
+				// this happens if we've already downloaded the whole first slice
+				continue
 			}
-		}
-		chunkSizes := EqualSplit(sliceSize, numChunks)
-		for _, chunkSize := range chunkSizes {
-			// startFrom changes each time round the loop
-			// we create chunkStart to be a stable variable for the goroutine to capture
-			chunkStart := startFrom
-			chunkEnd := startFrom + chunkSize - 1
-
-			dataSlice := data[chunkStart : chunkEnd+1]
-			errGroup.Go(func() error {
-				logger.Debug().Int64("start", chunkStart).Int64("end", chunkEnd).Msg("starting request")
-				resp, err := m.DoRequest(ctx, chunkStart, chunkEnd, urlString)
-				if err != nil {
-					// In the case that an error indicating an issue with the cache server, networking, etc is returned,
-					// this will use the fallback strategy. This is a case where the whole file will perform the fall-back
-					// for the specified chunk instead of the whole file.
-					if errors.Is(err, client.ErrStrategyFallback) {
-						resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, urlString)
-					}
+			startFrom := m.SliceSize * int64(slice)
+			sliceSize := m.SliceSize
+			if slice == 0 {
+				startFrom = m.minChunkSize()
+				sliceSize = sliceSize - m.minChunkSize()
+			}
+			if slice == int(totalSlices)-1 {
+				sliceSize = (fileSize-1)%m.SliceSize + 1
+			}
+			if sliceSize/numChunks < m.minChunkSize() {
+				// reset numChunks to respect minChunkSize
+				numChunks = sliceSize / m.minChunkSize()
+				// although we must always have at least one chunk
+				if numChunks == 0 {
+					numChunks = 1
+				}
+			}
+			chunkSizes := EqualSplit(sliceSize, numChunks)
+			for _, chunkSize := range chunkSizes {
+				// startFrom changes each time round the loop
+				// we create chunkStart to be a stable variable for the goroutine to capture
+				chunkStart := startFrom
+				chunkEnd := startFrom + chunkSize - 1
+
+				br := newBufferedReader(m.minChunkSize())
+				readersCh <- br
+				m.sem.Go(func() error {
+					defer br.done()
+					logger.Debug().Int64("start", chunkStart).Int64("end", chunkEnd).Msg("starting request")
+					resp, err := m.DoRequest(ctx, chunkStart, chunkEnd, urlString)
 					if err != nil {
-						return err
+						// in the case that an error indicating an issue with the cache server, networking, etc is returned,
+						// this will use the fallback strategy. This is a case where the whole file will perform the fall-back
+						// for the specified chunk instead of the whole file.
+						if errors.Is(err, client.ErrStrategyFallback) {
+							resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, urlString)
+						}
+						if err != nil {
+							return err
+						}
 					}
-				}
+					defer resp.Body.Close()
+					return br.downloadBody(resp)
+				})
 
-				return m.downloadChunk(resp, dataSlice)
-			})
-
-			startFrom = startFrom + chunkSize
+				startFrom = startFrom + chunkSize
+			}
 		}
-	}
-
-	if err := errGroup.Wait(); err != nil {
-		return nil, -1, err // return the first error we encounter
-	}
+	})
 
-	buffer := bytes.NewBuffer(data)
-	return buffer, fileSize, nil
+	return newChanMultiReader(readersCh), fileSize, nil
 }
 
 func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, urlString string) (*http.Response, error) {
 	logger := logging.GetLogger()
-	if m.Semaphore != nil {
-		err := m.Semaphore.Acquire(ctx, 1)
-		if err != nil {
-			return nil, err
-		}
-	}
 	chContext := context.WithValue(ctx, config.ConsistentHashingStrategyKey, true)
 	req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil)
 	if err != nil {
@@ -288,21 +295,3 @@ func (m *ConsistentHashingMode) consistentHashIfNeeded(req *http.Request, start
 	}
 	return nil
 }
-
-func (m *ConsistentHashingMode) downloadChunk(resp *http.Response, dataSlice []byte) error {
-	logger := logging.GetLogger()
-	defer resp.Body.Close()
-	if m.Semaphore != nil {
-		defer m.Semaphore.Release(1)
-	}
-	expectedBytes := len(dataSlice)
-	n, err := io.ReadFull(resp.Body, dataSlice)
-	if err != nil && err != io.EOF {
-		return fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err)
-	}
-	if n != expectedBytes {
-		return fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String())
-	}
-	logger.Debug().Int("size", len(dataSlice)).Int("downloaded", n).Msg("downloaded chunk")
-	return nil
-}
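The slice arithmetic above leans on `EqualSplit`, which lives elsewhere in the package and is unchanged by this diff. As used here, its contract is to split a total into n parts that sum exactly to the total and differ by at most one. A sketch consistent with that contract (an assumption for illustration, not the actual pget implementation):

```go
package download

// equalSplit sketches the behaviour this code relies on from EqualSplit:
// equalSplit(10, 4) -> [3 3 2 2], i.e. near-equal parts summing to total.
func equalSplit(total int64, n int64) []int64 {
	parts := make([]int64, n)
	base, rem := total/n, total%n
	for i := range parts {
		parts[i] = base
		if int64(i) < rem {
			parts[i]++ // distribute the remainder one byte at a time
		}
	}
	return parts
}
```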
"github.com/stretchr/testify/require" - "golang.org/x/sync/semaphore" "github.com/replicate/pget/pkg/client" "github.com/replicate/pget/pkg/download" @@ -30,13 +29,6 @@ var testFSes = []fstest.MapFS{ {"hello.txt": {Data: []byte("7777777777777777")}}, } -func makeConsistentHashingMode(opts download.Options) *download.ConsistentHashingMode { - client := client.NewHTTPClient(opts.Client) - fallbackMode := download.BufferMode{Options: opts, Client: client} - - return &download.ConsistentHashingMode{Client: client, Options: opts, FallbackStrategy: &fallbackMode} -} - type chTestCase struct { name string concurrency int @@ -185,7 +177,6 @@ func TestConsistentHashing(t *testing.T) { Client: client.Options{}, MaxConcurrency: tc.concurrency, MinChunkSize: tc.minChunkSize, - Semaphore: semaphore.NewWeighted(int64(tc.concurrency)), CacheHosts: hostnames[0:tc.numCacheHosts], DomainsToCache: []string{"test.replicate.delivery"}, SliceSize: tc.sliceSize, @@ -194,7 +185,8 @@ func TestConsistentHashing(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - strategy := makeConsistentHashingMode(opts) + strategy, err := download.GetConsistentHashingMode(opts) + assert.NoError(t, err) reader, _, err := strategy.Fetch(ctx, "http://test.replicate.delivery/hello.txt") assert.NoError(t, err) @@ -214,7 +206,6 @@ func TestConsistentHashingHasFallback(t *testing.T) { Client: client.Options{}, MaxConcurrency: 8, MinChunkSize: 2, - Semaphore: semaphore.NewWeighted(8), CacheHosts: []string{}, DomainsToCache: []string{"fake.replicate.delivery"}, SliceSize: 3, @@ -223,12 +214,13 @@ func TestConsistentHashingHasFallback(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - strategy := makeConsistentHashingMode(opts) + strategy, err := download.GetConsistentHashingMode(opts) + require.NoError(t, err) urlString, err := url.JoinPath(server.URL, "hello.txt") - assert.NoError(t, err) + require.NoError(t, err) reader, _, err := strategy.Fetch(ctx, urlString) - assert.NoError(t, err) + require.NoError(t, err) bytes, err := io.ReadAll(reader) assert.NoError(t, err) @@ -263,7 +255,14 @@ func (s *testStrategy) DoRequest(ctx context.Context, start, end int64, url stri s.mut.Lock() s.doRequestCalledCount++ s.mut.Unlock() - resp := &http.Response{Body: io.NopCloser(strings.NewReader("00"))} + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + resp := &http.Response{ + Request: req, + Body: io.NopCloser(strings.NewReader("00")), + } return resp, nil } @@ -302,7 +301,6 @@ func TestConsistentHashingFileFallback(t *testing.T) { Client: client.Options{}, MaxConcurrency: 8, MinChunkSize: 2, - Semaphore: semaphore.NewWeighted(8), CacheHosts: []string{url.Host}, DomainsToCache: []string{"fake.replicate.delivery"}, SliceSize: 3, @@ -311,12 +309,14 @@ func TestConsistentHashingFileFallback(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - strategy := makeConsistentHashingMode(opts) + strategy, err := download.GetConsistentHashingMode(opts) + assert.NoError(t, err) + fallbackStrategy := &testStrategy{} strategy.FallbackStrategy = fallbackStrategy urlString := "http://fake.replicate.delivery/hello.txt" - _, _, err := strategy.Fetch(ctx, urlString) + _, _, err = strategy.Fetch(ctx, urlString) if tc.expectedError != nil { assert.ErrorIs(t, err, tc.expectedError) } @@ -363,7 +363,6 @@ func TestConsistentHashingChunkFallback(t *testing.T) { Client: client.Options{}, MaxConcurrency: 8, MinChunkSize: 
diff --git a/pkg/download/options.go b/pkg/download/options.go
index a9b29ce..a3baa66 100644
--- a/pkg/download/options.go
+++ b/pkg/download/options.go
@@ -1,7 +1,7 @@
 package download
 
 import (
-	"golang.org/x/sync/semaphore"
+	"runtime"
 
 	"github.com/replicate/pget/pkg/client"
 )
@@ -29,8 +29,12 @@ type Options struct {
 	// hashing algorithm. The slice may contain empty entries which
 	// correspond to a cache host which is currently unavailable.
 	CacheHosts []string
+}
 
-	// Semaphore is used to manage maximum concurrency. If nil, concurrency
-	// is unlimited.
-	Semaphore *semaphore.Weighted
+func (o *Options) maxConcurrency() int {
+	maxChunks := o.MaxConcurrency
+	if maxChunks == 0 {
+		return runtime.NumCPU() * 4
+	}
+	return maxChunks
 }
diff --git a/pkg/download/work_queue.go b/pkg/download/work_queue.go
new file mode 100644
index 0000000..b181f2c
--- /dev/null
+++ b/pkg/download/work_queue.go
@@ -0,0 +1,26 @@
+package download
+
+// workQueue takes work items and executes them serially, in strict FIFO order
+type workQueue struct {
+	queue chan work
+}
+
+type work func()
+
+func newWorkQueue(depth int) *workQueue {
+	return &workQueue{queue: make(chan work, depth)}
+}
+
+func (q *workQueue) submit(w work) {
+	q.queue <- w
+}
+
+func (q *workQueue) start() {
+	go q.run()
+}
+
+func (q *workQueue) run() {
+	for item := range q.queue {
+		item()
+	}
+}
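`workQueue` runs submitted closures one at a time on a single goroutine, in strict FIFO order; both download modes funnel their request-spawning closures through it, so chunk requests are issued in the same order their readers were handed out. A hypothetical in-package snippet demonstrating the ordering guarantee (`exampleQueue` is illustrative only):

```go
package download

import "fmt"

// exampleQueue (hypothetical) shows the FIFO guarantee: items run one at a
// time, in submission order, on the queue's single goroutine.
func exampleQueue() {
	q := newWorkQueue(4) // up to 4 items may be buffered before submit blocks
	q.start()

	done := make(chan struct{})
	for i := 0; i < 3; i++ {
		i := i
		q.submit(func() { fmt.Println("item", i) }) // prints 0, 1, 2 in order
	}
	q.submit(func() { close(done) })
	<-done // once this unblocks, all earlier items have run
}
```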
diff --git a/pkg/pget.go b/pkg/pget.go
index 061d4fc..14fc1df 100644
--- a/pkg/pget.go
+++ b/pkg/pget.go
@@ -17,6 +17,11 @@ import (
 type Getter struct {
 	Downloader download.Strategy
 	Consumer   consumer.Consumer
+	Options    Options
+}
+
+type Options struct {
+	MaxConcurrentFiles int
 }
 
 type ManifestEntry struct {
@@ -39,26 +44,27 @@ func (g *Getter) DownloadFile(ctx context.Context, url string, dest string) (int
 	if err != nil {
 		return fileSize, 0, err
 	}
-	downloadElapsed := time.Since(downloadStartTime)
-	writeStartTime := time.Now()
+	// downloadElapsed := time.Since(downloadStartTime)
+	// writeStartTime := time.Now()
 
 	err = g.Consumer.Consume(buffer, dest)
 	if err != nil {
 		return fileSize, 0, fmt.Errorf("error writing file: %w", err)
 	}
 
-	writeElapsed := time.Since(writeStartTime)
+	// writeElapsed := time.Since(writeStartTime)
 	totalElapsed := time.Since(downloadStartTime)
 
 	size := humanize.Bytes(uint64(fileSize))
-	downloadThroughput := humanize.Bytes(uint64(float64(fileSize) / downloadElapsed.Seconds()))
-	writeThroughput := humanize.Bytes(uint64(float64(fileSize) / writeElapsed.Seconds()))
+	// downloadThroughput := humanize.Bytes(uint64(float64(fileSize) / downloadElapsed.Seconds()))
+	// writeThroughput := humanize.Bytes(uint64(float64(fileSize) / writeElapsed.Seconds()))
 	logger.Info().
 		Str("dest", dest).
+		Str("url", url).
 		Str("size", size).
-		Str("download_throughput", fmt.Sprintf("%s/s", downloadThroughput)).
-		Str("download_elapsed", fmt.Sprintf("%.3fs", downloadElapsed.Seconds())).
-		Str("write_throughput", fmt.Sprintf("%s/s", writeThroughput)).
-		Str("write_elapsed", fmt.Sprintf("%.3fs", writeElapsed.Seconds())).
+		// Str("download_throughput", fmt.Sprintf("%s/s", downloadThroughput)).
+		// Str("download_elapsed", fmt.Sprintf("%.3fs", downloadElapsed.Seconds())).
+		// Str("write_throughput", fmt.Sprintf("%s/s", writeThroughput)).
+		// Str("write_elapsed", fmt.Sprintf("%.3fs", writeElapsed.Seconds())).
 		Str("total_elapsed", fmt.Sprintf("%.3fs", totalElapsed.Seconds())).
 		Msg("Complete")
 	return fileSize, totalElapsed, nil
@@ -71,6 +77,10 @@ func (g *Getter) DownloadFiles(ctx context.Context, manifest Manifest) (int64, t
 
 	errGroup, ctx := errgroup.WithContext(ctx)
 
+	if g.Options.MaxConcurrentFiles != 0 {
+		errGroup.SetLimit(g.Options.MaxConcurrentFiles)
+	}
+
 	totalSize := new(atomic.Int64)
 	multifileDownloadStart := time.Now()
 
diff --git a/pkg/pget_test.go b/pkg/pget_test.go
index f33ac0c..59d3913 100644
--- a/pkg/pget_test.go
+++ b/pkg/pget_test.go
@@ -36,10 +36,8 @@ var defaultOpts = download.Options{Client: client.Options{}}
 var http2Opts = download.Options{Client: client.Options{ForceHTTP2: true}}
 
 func makeGetter(opts download.Options) *pget.Getter {
-	client := client.NewHTTPClient(opts.Client)
-
 	return &pget.Getter{
-		Downloader: &download.BufferMode{Client: client, Options: opts},
+		Downloader: download.GetBufferMode(opts),
 	}
 }
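Putting the new knobs together: chunk-level parallelism is capped by `download.Options.MaxConcurrency`, while the new `pget.Options.MaxConcurrentFiles` caps file-level parallelism in `DownloadFiles` (0 leaves it unlimited). A hedged wiring sketch — the `buildGetter` helper and the `pkg/consumer` import path are assumptions, not code from this PR:

```go
package main

import (
	pget "github.com/replicate/pget/pkg"
	"github.com/replicate/pget/pkg/consumer"
	"github.com/replicate/pget/pkg/download"
)

// buildGetter is hypothetical wiring: up to 64 chunk requests in flight per
// the shared semaphore, and at most 20 files downloading at once.
func buildGetter(c consumer.Consumer) *pget.Getter {
	return &pget.Getter{
		Downloader: download.GetBufferMode(download.Options{MaxConcurrency: 64}),
		Consumer:   c,
		Options:    pget.Options{MaxConcurrentFiles: 20},
	}
}

func main() {
	_ = buildGetter(nil) // a real caller would pass a consumer.Consumer
}
```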