reimplement bufferedReader on top of bufio.Reader (#180)
* reimplement bufferedReader on top of bufio.Reader

This reimplements bufferedReader on top of bufio.Reader.

This doesn't make the code shorter, but it delegates more complexity to standard
library functions.  In particular, bufio.Reader:

- handles storing errors and passing them out via Read() where appropriate
- handles errors where a read would exceed the buffer size
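
As an illustration (not part of this commit), a minimal sketch of those two
behaviours; errReader here is a made-up reader that returns a few bytes and
then a non-EOF error:

    package main

    import (
    	"bufio"
    	"errors"
    	"fmt"
    	"io"
    	"strings"
    )

    // errReader returns a few bytes, then a permanent non-EOF error.
    type errReader struct{ done bool }

    func (e *errReader) Read(p []byte) (int, error) {
    	if e.done {
    		return 0, errors.New("connection reset")
    	}
    	e.done = true
    	return copy(p, "abc"), nil
    }

    func main() {
    	// The upstream error is stored and handed out via Read() once the
    	// data read so far has been consumed.
    	br := bufio.NewReaderSize(&errReader{}, 16)
    	data, err := io.ReadAll(br)
    	fmt.Printf("%q %v\n", data, err) // "abc" connection reset

    	// A Peek() larger than the buffer fails with bufio.ErrBufferFull
    	// rather than silently truncating.
    	br = bufio.NewReaderSize(strings.NewReader("0123456789"), 16)
    	_, err = br.Peek(32)
    	fmt.Println(errors.Is(err, bufio.ErrBufferFull)) // true
    }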

Furthermore, this commit removes any HTTP concerns from bufferedReader.  The
downloadBody() function has been changed to a generic ReadFrom().
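
Concretely, bufferedReader now implements the standard library's io.ReaderFrom
interface:

    // From the io package:
    type ReaderFrom interface {
    	ReadFrom(r io.Reader) (n int64, err error)
    }

so callers hand it resp.Body (or any other io.Reader) and perform the
Content-Length check themselves, as the hunks below show.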

* test bufferedReader

* Ensure bufferedReader doesn't stall

bufio.Reader can stall if 100 consecutive Read() calls make no progress.  The
previous bytes.Buffer implementation used io.Copy, which doesn't have this
failure mode.
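
As an illustration (not part of this commit), a sketch of that failure mode
using a made-up reader whose Read() always returns (0, nil):

    package main

    import (
    	"bufio"
    	"errors"
    	"fmt"
    	"io"
    )

    // noProgress never returns data and never returns an error.
    type noProgress struct{}

    func (noProgress) Read(p []byte) (int, error) { return 0, nil }

    func main() {
    	br := bufio.NewReaderSize(noProgress{}, 16)
    	// After 100 consecutive empty reads, bufio gives up with
    	// io.ErrNoProgress rather than spinning forever.
    	_, err := br.Peek(1)
    	fmt.Println(errors.Is(err, io.ErrNoProgress)) // true
    }

io.Copy just keeps calling Read() on an empty read, so the old implementation
never hit this; the new ReadFrom() handles it by retrying Peek() for as long
as it returns io.ErrNoProgress.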
philandstuff authored Mar 7, 2024
1 parent c72c384 commit 4ebad73
Showing 5 changed files with 179 additions and 41 deletions.
23 changes: 19 additions & 4 deletions pkg/download/buffer.go
@@ -70,7 +70,7 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, error) {
 	m.queue.submit(func() {
 		m.sem.Go(func() error {
 			defer close(firstReqResultCh)
-			defer br.done()
+			defer br.Done()
 			firstChunkResp, err := m.DoRequest(ctx, 0, m.chunkSize()-1, url)
 			if err != nil {
 				firstReqResultCh <- firstReqResult{err: err}
@@ -91,7 +91,14 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, error) {
 			}
 			firstReqResultCh <- firstReqResult{fileSize: fileSize, trueURL: trueURL}
 
-			return br.downloadBody(firstChunkResp)
+			contentLength := firstChunkResp.ContentLength
+			n, err := br.ReadFrom(firstChunkResp.Body)
+			if err != nil {
+				return err
+			} else if n != contentLength {
+				return ErrContentLengthMismatch{downloadedBytes: n, contentLength: contentLength}
+			}
+			return nil
 		})
 	})
 
@@ -141,13 +148,21 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, error) {
 			readersCh <- br
 
 			m.sem.Go(func() error {
-				defer br.done()
+				defer br.Done()
 				resp, err := m.DoRequest(ctx, start, end, trueURL)
 				if err != nil {
 					return err
 				}
 				defer resp.Body.Close()
-				return br.downloadBody(resp)
+
+				contentLength := resp.ContentLength
+				n, err := br.ReadFrom(resp.Body)
+				if err != nil {
+					return err
+				} else if n != contentLength {
+					return ErrContentLengthMismatch{downloadedBytes: n, contentLength: contentLength}
+				}
+				return nil
 			})
 		}
 	})
79 changes: 46 additions & 33 deletions pkg/download/buffered_reader.go
@@ -1,26 +1,30 @@
 package download
 
 import (
-	"bytes"
-	"fmt"
+	"bufio"
 	"io"
-	"net/http"
+	"strings"
 	"sync"
 )
 
-// A bufferedReader wraps an http.Response.Body so that it can be eagerly
-// downloaded to a buffer before the actual io.Reader consumer can read it.
-// It implements io.Reader.
+// A bufferedReader wraps a bufio.Reader so that it can be shared between
+// goroutines, with one fetching data from an upstream reader and another
+// reading the data. It implements io.ReaderFrom and io.Reader. Read() will
+// block until Done() is called.
+//
+// The intended use is: one goroutine calls Read(), which blocks until data is
+// ready. Another calls ReadFrom() and then Done(). The call to Done()
+// unblocks the Read() call and allows it to read the data that was fetched by
+// ReadFrom().
 type bufferedReader struct {
 	// ready channel is closed when we're ready to read
 	ready chan struct{}
-	buf   *bytes.Buffer
+	buf   *bufio.Reader
 	pool  *bufferPool
 }
 
 var _ io.Reader = &bufferedReader{}
-
-var emptyBuffer = bytes.NewBuffer(nil)
+var _ io.ReaderFrom = &bufferedReader{}
 
 func newBufferedReader(pool *bufferPool) *bufferedReader {
 	return &bufferedReader{
@@ -35,57 +39,66 @@ func newBufferedReader(pool *bufferPool) *bufferedReader {
 // pool.
 func (b *bufferedReader) Read(buf []byte) (int, error) {
 	<-b.ready
+	if b.buf == nil {
+		return 0, io.EOF
+	}
 	n, err := b.buf.Read(buf)
 	// If we've read all the data,
-	if b.buf.Len() == 0 && b.buf != emptyBuffer {
+	if b.buf.Buffered() == 0 {
 		// return the buffer to the pool
 		b.pool.Put(b.buf)
 		// and replace our buffer with something that will always return EOF on
 		// future reads
-		b.buf = emptyBuffer
+		b.buf = nil
 	}
 	return n, err
 }
 
-func (b *bufferedReader) done() {
-	close(b.ready)
-}
-
-func (b *bufferedReader) downloadBody(resp *http.Response) error {
-	expectedBytes := resp.ContentLength
-
-	if expectedBytes > int64(b.buf.Cap()) {
-		return fmt.Errorf("Tried to download 0x%x bytes to a 0x%x-sized buffer", expectedBytes, b.buf.Cap())
+func (b *bufferedReader) ReadFrom(r io.Reader) (int64, error) {
+	b.buf.Reset(r)
+	var bytes []byte
+	var err error
+	for {
+		bytes, err = b.buf.Peek(b.buf.Size())
+		if err != io.ErrNoProgress {
+			// keep trying until we make progress
+			break
+		}
 	}
-	n, err := b.buf.ReadFrom(resp.Body)
-	if err != nil && err != io.EOF {
-		return fmt.Errorf("error reading response for %s: %w", resp.Request.URL.String(), err)
+	if err == io.EOF {
+		// ReadFrom does not return io.EOF
+		err = nil
 	}
-	if n != expectedBytes {
-		return fmt.Errorf("downloaded %d bytes instead of %d for %s", n, expectedBytes, resp.Request.URL.String())
-	}
-	return nil
+	return int64(len(bytes)), err
 }
 
+func (b *bufferedReader) Done() {
+	close(b.ready)
+}
+
 type bufferPool struct {
 	pool sync.Pool
 }
 
-func newBufferPool(capacity int64) *bufferPool {
+func newBufferPool(size int64) *bufferPool {
 	return &bufferPool{
 		pool: sync.Pool{
 			New: func() any {
-				return bytes.NewBuffer(make([]byte, 0, capacity))
+				return bufio.NewReaderSize(nil, int(size))
 			},
 		},
 	}
 }
 
-func (p *bufferPool) Get() *bytes.Buffer {
-	return p.pool.Get().(*bytes.Buffer)
+var emptyReader = strings.NewReader("")
+
+// Get returns a bufio.Reader with the correct size, with a blank underlying io.Reader.
+func (p *bufferPool) Get() *bufio.Reader {
+	br := p.pool.Get().(*bufio.Reader)
+	br.Reset(emptyReader)
+	return br
 }
 
-func (p *bufferPool) Put(buf *bytes.Buffer) {
-	buf.Reset()
+func (p *bufferPool) Put(buf *bufio.Reader) {
 	p.pool.Put(buf)
 }
84 changes: 84 additions & 0 deletions pkg/download/buffered_reader_test.go
@@ -0,0 +1,84 @@
+package download
+
+import (
+	"bytes"
+	"io"
+	"strings"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestBufferedReaderSerial(t *testing.T) {
+	pool := newBufferPool(10)
+	br := newBufferedReader(pool)
+	n, err := br.ReadFrom(strings.NewReader("foobar"))
+	assert.NoError(t, err)
+	assert.Equal(t, int64(6), n)
+	br.Done()
+	buf, err := io.ReadAll(br)
+	assert.NoError(t, err)
+	assert.Equal(t, "foobar", string(buf))
+}
+
+func TestBufferedReaderParallel(t *testing.T) {
+	pool := newBufferPool(10)
+	br := newBufferedReader(pool)
+	wg := new(sync.WaitGroup)
+	wg.Add(1)
+	go func() {
+		defer br.Done()
+		defer wg.Done()
+		n, err := br.ReadFrom(strings.NewReader("foobar"))
+		assert.NoError(t, err)
+		assert.Equal(t, int64(6), n)
+	}()
+	buf, err := io.ReadAll(br)
+	assert.NoError(t, err)
+	assert.Equal(t, "foobar", string(buf))
+	wg.Wait()
+}
+
+func TestBufferedReaderReadsWholeChunk(t *testing.T) {
+	chunkSize := int64(1024 * 1024)
+	pool := newBufferPool(chunkSize)
+	br := newBufferedReader(pool)
+	data := bytes.Repeat([]byte("x"), int(chunkSize))
+	n64, err := br.ReadFrom(bytes.NewReader(data))
+	assert.NoError(t, err)
+	assert.Equal(t, chunkSize, n64)
+	br.Done()
+	buf := make([]byte, chunkSize)
+	// We should only require a single Read() call because all the data should
+	// be buffered
+	n, err := br.Read(buf)
+	assert.NoError(t, err)
+	assert.Equal(t, data, buf)
+	assert.Equal(t, int(chunkSize), n)
+}
+
+func TestBufferedReaderSubsequentReadsReturnEOF(t *testing.T) {
+	pool := newBufferPool(10)
+	br := newBufferedReader(pool)
+	n64, err := br.ReadFrom(strings.NewReader("foobar"))
+	assert.NoError(t, err)
+	assert.Equal(t, int64(6), n64)
+	br.Done()
+	buf, err := io.ReadAll(br)
+	assert.NoError(t, err)
+	assert.Equal(t, "foobar", string(buf))
+
+	n, err := br.Read(buf)
+	assert.Equal(t, 0, n)
+	assert.ErrorIs(t, err, io.EOF)
+}
+
+func TestBufferedReaderDoneWithoutReadFrom(t *testing.T) {
+	pool := newBufferPool(10)
+	br := newBufferedReader(pool)
+	br.Done()
+	buf, err := io.ReadAll(br)
+	assert.NoError(t, err)
+	assert.Equal(t, 0, len(buf))
+}
12 changes: 12 additions & 0 deletions pkg/download/common.go
@@ -1,6 +1,7 @@
 package download
 
 import (
+	"fmt"
 	"regexp"
 
 	"github.com/dustin/go-humanize"
@@ -9,3 +10,14 @@ import (
 const defaultChunkSize = 125 * humanize.MiByte
 
 var contentRangeRegexp = regexp.MustCompile(`^bytes .*/([0-9]+)$`)
+
+type ErrContentLengthMismatch struct {
+	contentLength   int64
+	downloadedBytes int64
+}
+
+var _ error = ErrContentLengthMismatch{}
+
+func (err ErrContentLengthMismatch) Error() string {
+	return fmt.Sprintf("Downloaded %d bytes but Content-Length was %d", err.downloadedBytes, err.contentLength)
+}
22 changes: 18 additions & 4 deletions pkg/download/consistent_hashing.go
@@ -113,7 +113,7 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io.Reader, int64, error) {
 	m.queue.submit(func() {
 		m.sem.Go(func() error {
 			defer close(firstReqResultCh)
-			defer br.done()
+			defer br.Done()
 			firstChunkResp, err := m.DoRequest(ctx, 0, m.chunkSize()-1, urlString)
 			if err != nil {
 				firstReqResultCh <- firstReqResult{err: err}
@@ -130,7 +130,14 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io.Reader, int64, error) {
 			}
 			firstReqResultCh <- firstReqResult{fileSize: fileSize}
 
-			return br.downloadBody(firstChunkResp)
+			contentLength := firstChunkResp.ContentLength
+			n, err := br.ReadFrom(firstChunkResp.Body)
+			if err != nil {
+				return err
+			} else if n != contentLength {
+				return ErrContentLengthMismatch{downloadedBytes: n, contentLength: contentLength}
+			}
+			return nil
 		})
 	})
 	firstReqResult, ok := <-firstReqResultCh
@@ -201,7 +208,7 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io.Reader, int64, error) {
 			br := newBufferedReader(m.pool)
 			readersCh <- br
 			m.sem.Go(func() error {
-				defer br.done()
+				defer br.Done()
 				logger.Debug().Int64("start", chunkStart).Int64("end", chunkEnd).Msg("starting request")
 				resp, err := m.DoRequest(ctx, chunkStart, chunkEnd, urlString)
 				if err != nil {
@@ -222,7 +229,14 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io.Reader, int64, error) {
 					}
 				}
 				defer resp.Body.Close()
-				return br.downloadBody(resp)
+				contentLength := resp.ContentLength
+				n, err := br.ReadFrom(resp.Body)
+				if err != nil {
+					return err
+				} else if n != contentLength {
+					return ErrContentLengthMismatch{downloadedBytes: n, contentLength: contentLength}
+				}
+				return nil
 			})
 
 		}
