Better concurrency (#115)
* introduce bufferedReader

The idea of bufferedReader is to be a staging area into which we can download
request bodies. It blocks Read() calls until the whole request body has been
downloaded.

The key thing this enables is getting rid of the errGroup.Wait() call: the
readers we return each wait for their own chunk to be ready, so we no longer
need to wait until the whole file has been downloaded before returning from
Fetch().
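
The bufferedReader implementation itself is in a file this page doesn't show,
so the following is only a sketch of the idea: the names newBufferedReader,
downloadBody, and done are taken from their call sites in pkg/download/buffer.go
below, but the bodies here are assumptions, not the actual code.

    package download

    import (
        "bytes"
        "net/http"
    )

    // bufferedReader sketch: Read() blocks until done() signals that the
    // whole response body has been staged into the buffer.
    type bufferedReader struct {
        ready chan struct{} // closed once the body is fully downloaded
        buf   *bytes.Buffer
        err   error
    }

    func newBufferedReader(capacity int64) *bufferedReader {
        return &bufferedReader{
            ready: make(chan struct{}),
            buf:   bytes.NewBuffer(make([]byte, 0, capacity)),
        }
    }

    // Read blocks until the chunk has finished downloading, then drains
    // the staging buffer.
    func (b *bufferedReader) Read(p []byte) (int, error) {
        <-b.ready
        if b.err != nil {
            return 0, b.err
        }
        return b.buf.Read(p)
    }

    // downloadBody stages the whole response body into the buffer.
    func (b *bufferedReader) downloadBody(resp *http.Response) error {
        _, err := b.buf.ReadFrom(resp.Body)
        if err != nil {
            b.err = err
        }
        return err
    }

    // done unblocks all pending and future Read() calls.
    func (b *bufferedReader) done() {
        close(b.ready)
    }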

* respect MaxConcurrency within a single BufferMode object

This means that a single BufferMode instance cannot ever exceed
MaxConcurrency concurrent requests.
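
The mechanism, visible in the buffer.go diff below, is an errgroup.Group used
as a semaphore via sem.SetLimit(): once MaxConcurrency goroutines are active,
further sem.Go() calls block until one finishes. A toy illustration (the limit
of 2, the loop, and the sleep are stand-ins):

    package main

    import (
        "fmt"
        "time"

        "golang.org/x/sync/errgroup"
    )

    func main() {
        sem := new(errgroup.Group)
        sem.SetLimit(2) // at most 2 concurrent "requests"

        for i := 0; i < 5; i++ {
            i := i
            sem.Go(func() error { // blocks here once 2 goroutines are active
                fmt.Println("request", i)
                time.Sleep(100 * time.Millisecond)
                return nil
            })
        }
        _ = sem.Wait()
    }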

* introduce workQueue and chanMultiReader

We want to be able to return early so that a consumer can start consuming as
soon as possible, even if we block on making new requests before we get to the
last chunk. These two types let us do that; they are sketched below.
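
Neither implementation appears on this page, so this is a sketch under stated
assumptions: the signatures (newWorkQueue, start, submit, newChanMultiReader)
come from their call sites in buffer.go below, and the single-worker queue is
a guess at the intended semantics, not the actual code.

    package download

    import "io"

    // workQueue sketch: submit() hands a task to a background goroutine,
    // so the caller can return immediately even when the task itself will
    // block (e.g. on the concurrency semaphore).
    type workQueue struct {
        tasks chan func()
    }

    func newWorkQueue(buffer int) *workQueue {
        return &workQueue{tasks: make(chan func(), buffer)}
    }

    func (q *workQueue) start() {
        go func() {
            for task := range q.tasks {
                task()
            }
        }()
    }

    func (q *workQueue) submit(task func()) {
        q.tasks <- task
    }

    // chanMultiReader sketch: like io.MultiReader, but the readers arrive
    // on a channel, so a consumer can read chunk 0 while later chunks are
    // still being requested.
    type chanMultiReader struct {
        ch  <-chan io.Reader
        cur io.Reader
    }

    func newChanMultiReader(ch <-chan io.Reader) *chanMultiReader {
        return &chanMultiReader{ch: ch}
    }

    func (c *chanMultiReader) Read(p []byte) (int, error) {
        for {
            if c.cur == nil {
                r, ok := <-c.ch
                if !ok {
                    return 0, io.EOF // channel closed: all chunks consumed
                }
                c.cur = r
            }
            n, err := c.cur.Read(p)
            if err == io.EOF {
                c.cur = nil // chunk exhausted; move on to the next reader
                if n > 0 {
                    return n, nil
                }
                continue
            }
            return n, err
        }
    }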

* add a per-file concurrency limit

* make max-concurrent-files configurable
philandstuff authored Dec 14, 2023
1 parent c336304 commit f9b008e
Showing 14 changed files with 404 additions and 264 deletions.
12 changes: 12 additions & 0 deletions cmd/multifile/multifile.go
@@ -88,6 +88,14 @@ func runMultifileCMD(cmd *cobra.Command, args []string) error {
return multifileExecute(cmd.Context(), manifest)
}

func maxConcurrentFiles() int {
maxConcurrentFiles := viper.GetInt(config.OptMaxConcurrentFiles)
if maxConcurrentFiles == 0 {
maxConcurrentFiles = 20
}
return maxConcurrentFiles
}

func multifileExecute(ctx context.Context, manifest pget.Manifest) error {
minChunkSize, err := humanize.ParseBytes(viper.GetString(config.OptMinimumChunkSize))
if err != nil {
@@ -112,6 +120,9 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error {
MinChunkSize: int64(minChunkSize),
Client: clientOpts,
}
pgetOpts := pget.Options{
MaxConcurrentFiles: maxConcurrentFiles(),
}

consumer, err := config.GetConsumer()
if err != nil {
@@ -121,6 +132,7 @@ func multifileExecute(ctx context.Context, manifest pget.Manifest) error {
getter := &pget.Getter{
Downloader: download.GetBufferMode(downloadOpts),
Consumer: consumer,
Options: pgetOpts,
}

// TODO DRY this
2 changes: 0 additions & 2 deletions cmd/root/root.go
@@ -11,7 +11,6 @@ import (
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"golang.org/x/sync/semaphore"

pget "github.com/replicate/pget/pkg"
"github.com/replicate/pget/pkg/cli"
@@ -174,7 +173,6 @@ func rootExecute(ctx context.Context, urlString, dest string) error {
MaxConcurrency: viper.GetInt(config.OptConcurrency),
MinChunkSize: int64(minChunkSize),
Client: clientOpts,
Semaphore: semaphore.NewWeighted(int64(viper.GetInt(config.OptConcurrency))),
}

consumer, err := config.GetConsumer()
29 changes: 15 additions & 14 deletions pkg/config/optnames.go
@@ -6,18 +6,19 @@ const (
OptCacheNodesSRVNameByHostCIDR = "cache-nodes-srv-name-by-host-cidr"
OptHostIP = "host-ip"

OptCacheNodesSRVName = "cache-nodes-srv-name"
OptConcurrency = "concurrency"
OptConnTimeout = "connect-timeout"
OptExtract = "extract"
OptForce = "force"
OptForceHTTP2 = "force-http2"
OptLoggingLevel = "log-level"
OptMaxChunks = "max-chunks"
OptMaxConnPerHost = "max-conn-per-host"
OptMinimumChunkSize = "minimum-chunk-size"
OptOutputConsumer = "output"
OptResolve = "resolve"
OptRetries = "retries"
OptVerbose = "verbose"
OptCacheNodesSRVName = "cache-nodes-srv-name"
OptConcurrency = "concurrency"
OptConnTimeout = "connect-timeout"
OptExtract = "extract"
OptForce = "force"
OptForceHTTP2 = "force-http2"
OptLoggingLevel = "log-level"
OptMaxChunks = "max-chunks"
OptMaxConnPerHost = "max-conn-per-host"
OptMaxConcurrentFiles = "max-concurrent-files"
OptMinimumChunkSize = "minimum-chunk-size"
OptOutputConsumer = "output"
OptResolve = "resolve"
OptRetries = "retries"
OptVerbose = "verbose"
)
173 changes: 91 additions & 82 deletions pkg/download/buffer.go
@@ -1,13 +1,11 @@
package download

import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"regexp"
"runtime"
"strconv"

"golang.org/x/sync/errgroup"
@@ -25,24 +23,26 @@ var contentRangeRegexp = regexp.MustCompile(`^bytes .*/([0-9]+)$`)
type BufferMode struct {
Client *client.HTTPClient
Options

// we use this errgroup as a semaphore (via sem.SetLimit())
sem *errgroup.Group
queue *workQueue
}

func GetBufferMode(opts Options) *BufferMode {
client := client.NewHTTPClient(opts.Client)
sem := new(errgroup.Group)
sem.SetLimit(opts.maxConcurrency())
queue := newWorkQueue(opts.maxConcurrency())
queue.start()
return &BufferMode{
Client: client,
Options: opts,
sem: sem,
queue: queue,
}
}

func (m *BufferMode) maxChunks() int {
maxChunks := m.MaxConcurrency
if maxChunks == 0 {
return runtime.NumCPU() * 4
}
return maxChunks
}

func (m *BufferMode) minChunkSize() int64 {
minChunkSize := m.MinChunkSize
if minChunkSize == 0 {
@@ -59,32 +59,61 @@ func (m *BufferMode) getFileSizeFromContentRange(contentRange string) (int64, er
return strconv.ParseInt(groups[1], 10, 64)
}

func (m *BufferMode) fileToBuffer(ctx context.Context, url string) (*bytes.Buffer, int64, error) {
type firstReqResult struct {
fileSize int64
trueURL string
err error
}

func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, error) {
logger := logging.GetLogger()
firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, url)
if err != nil {
return nil, -1, err
}

trueURL := firstChunkResp.Request.URL.String()
if trueURL != url {
logger.Info().Str("url", url).Str("redirect_url", trueURL).Msg("Redirect")
br := newBufferedReader(m.minChunkSize())

firstReqResultCh := make(chan firstReqResult)
m.queue.submit(func() {
m.sem.Go(func() error {
defer close(firstReqResultCh)
defer br.done()
firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, url)
if err != nil {
firstReqResultCh <- firstReqResult{err: err}
return err
}

defer firstChunkResp.Body.Close()

trueURL := firstChunkResp.Request.URL.String()
if trueURL != url {
logger.Info().Str("url", url).Str("redirect_url", trueURL).Msg("Redirect")
}

fileSize, err := m.getFileSizeFromContentRange(firstChunkResp.Header.Get("Content-Range"))
if err != nil {
firstReqResultCh <- firstReqResult{err: err}
return err
}
firstReqResultCh <- firstReqResult{fileSize: fileSize, trueURL: trueURL}

return br.downloadBody(firstChunkResp)
})
})

firstReqResult, ok := <-firstReqResultCh
if !ok {
panic("logic error in BufferMode: first request didn't return any output")
}

fileSize, err := m.getFileSizeFromContentRange(firstChunkResp.Header.Get("Content-Range"))
if err != nil {
firstChunkResp.Body.Close()
return nil, -1, err
if firstReqResult.err != nil {
return nil, -1, firstReqResult.err
}

data := make([]byte, fileSize)
fileSize := firstReqResult.fileSize
trueURL := firstReqResult.trueURL

if fileSize <= m.minChunkSize() {
// we only need a single chunk: just download it and finish
err = m.downloadChunk(firstChunkResp, data[0:fileSize])
if err != nil {
return nil, -1, err
}
return bytes.NewBuffer(data), fileSize, nil
return br, fileSize, nil
}

remainingBytes := fileSize - m.minChunkSize()
@@ -93,51 +122,52 @@ func (m *BufferMode) fileToBuffer(ctx context.Context, url string) (*bytes.Buffe
if numChunks <= 0 {
numChunks = 1
}
if numChunks > m.maxChunks() {
numChunks = m.maxChunks()
if numChunks > m.maxConcurrency() {
numChunks = m.maxConcurrency()
}

readersCh := make(chan io.Reader, m.maxConcurrency()+1)
readersCh <- br

startOffset := m.minChunkSize()

chunkSize := remainingBytes / int64(numChunks)
if chunkSize < 0 {
firstChunkResp.Body.Close()
return nil, -1, fmt.Errorf("error: chunksize incorrect - result is negative, %d", chunkSize)
}

logger.Debug().Str("url", url).
Int64("size", fileSize).
Int("connections", numChunks).
Int64("chunkSize", chunkSize).
Msg("Downloading")

errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
return m.downloadChunk(firstChunkResp, data[0:m.minChunkSize()])
})

startOffset := m.minChunkSize()
m.queue.submit(func() {
defer close(readersCh)
logger.Debug().Str("url", url).
Int64("size", fileSize).
Int("connections", numChunks).
Int64("chunkSize", chunkSize).
Msg("Downloading")

for i := 0; i < numChunks; i++ {
start := startOffset + chunkSize*int64(i)
end := start + chunkSize - 1
for i := 0; i < numChunks; i++ {
start := startOffset + chunkSize*int64(i)
end := start + chunkSize - 1

if i == numChunks-1 {
end = fileSize - 1
}
errGroup.Go(func() error {
resp, err := m.DoRequest(ctx, start, end, trueURL)
if err != nil {
return err
if i == numChunks-1 {
end = fileSize - 1
}
return m.downloadChunk(resp, data[start:end+1])
})
}

if err := errGroup.Wait(); err != nil {
return nil, -1, err // return the first error we encounter
}
br := newBufferedReader(end - start + 1)
readersCh <- br

m.sem.Go(func() error {
defer br.done()
resp, err := m.DoRequest(ctx, start, end, trueURL)
if err != nil {
return err
}
defer resp.Body.Close()
return br.downloadBody(resp)
})
}
})

buffer := bytes.NewBuffer(data)
return buffer, fileSize, nil
return newChanMultiReader(readersCh), fileSize, nil
}

func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL string) (*http.Response, error) {
@@ -156,24 +186,3 @@ func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL st

return resp, nil
}

func (m *BufferMode) downloadChunk(resp *http.Response, dataSlice []byte) error {
defer resp.Body.Close()
expectedBytes := len(dataSlice)
n, err := io.ReadFull(resp.Body, dataSlice)
if err != nil && err != io.EOF {
return fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err)
}
if n != expectedBytes {
return fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String())
}
return nil
}

func (m *BufferMode) Fetch(ctx context.Context, url string) (result io.Reader, fileSize int64, err error) {
buffer, fileSize, err := m.fileToBuffer(ctx, url)
if err != nil {
return nil, 0, err
}
return buffer, fileSize, nil
}
7 changes: 1 addition & 6 deletions pkg/download/buffer_test.go
@@ -18,13 +18,8 @@ func init() {
var defaultOpts = download.Options{Client: client.Options{}}
var http2Opts = download.Options{Client: client.Options{ForceHTTP2: true}}

func makeBufferMode(opts download.Options) *download.BufferMode {
client := client.NewHTTPClient(opts.Client)

return &download.BufferMode{Client: client, Options: opts}
}
func benchmarkDownloadURL(opts download.Options, url string, b *testing.B) {
bufferMode := makeBufferMode(opts)
bufferMode := download.GetBufferMode(opts)

for n := 0; n < b.N; n++ {
ctx, cancel := context.WithCancel(context.Background())
11 changes: 3 additions & 8 deletions pkg/download/buffer_unit_test.go
@@ -13,6 +13,7 @@ import (
"github.com/dustin/go-humanize"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/replicate/pget/pkg/client"
)
@@ -109,10 +110,10 @@ func TestFileToBufferChunkCountExceedsMaxChunks(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
opts.MaxConcurrency = tc.maxConcurrency
opts.MinChunkSize = tc.minChunkSize
bufferMode := makeBufferMode(opts)
bufferMode := GetBufferMode(opts)
path, _ := url.JoinPath(server.URL, testFilePath)
download, size, err := bufferMode.Fetch(context.Background(), path)
assert.NoError(t, err)
require.NoError(t, err)
data, err := io.ReadAll(download)
assert.NoError(t, err)
assert.Equal(t, contentSize, size)
@@ -121,9 +122,3 @@
})
}
}

func makeBufferMode(opts Options) *BufferMode {
client := client.NewHTTPClient(opts.Client)

return &BufferMode{Client: client, Options: opts}
}