Common chunk reader #1594

Merged
merged 11 commits on Aug 3, 2023
Changes from 6 commits
116 changes: 105 additions & 11 deletions pkg/sources/chunker.go
@@ -5,6 +5,8 @@ import (
"bytes"
"errors"
"io"

"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)

const (
@@ -25,28 +27,120 @@ func Chunker(originalChunk *Chunk) chan *Chunk {
chunkChan <- originalChunk
return
}

r := bytes.NewReader(originalChunk.Data)
reader := bufio.NewReaderSize(bufio.NewReader(r), ChunkSize)
for {
chunkBytes := make([]byte, TotalChunkSize)
chunk := *originalChunk
chunkBytes = chunkBytes[:ChunkSize]
n, err := reader.Read(chunkBytes)
if n > 0 {
peekData, _ := reader.Peek(TotalChunkSize - n)
chunkBytes = append(chunkBytes[:n], peekData...)
chunk.Data = chunkBytes
chunkChan <- &chunk
}
if err != nil {
break
}
}
}()
return chunkChan
}

type chunkReaderConfig struct {
chunkSize int
totalSize int
peekSize int
}

// ConfigOption is a function that configures a chunker.
type ConfigOption func(*chunkReaderConfig)

// WithChunkSize sets the chunk size.
func WithChunkSize(size int) ConfigOption {
return func(c *chunkReaderConfig) {
c.chunkSize = size
}
}

// WithTotalChunkSize sets the total chunk size.
// This is the chunk size plus the peek size.
func WithTotalChunkSize(size int) ConfigOption {
return func(c *chunkReaderConfig) {
c.totalSize = size
}
}

// WithPeekSize sets the peek size.
func WithPeekSize(size int) ConfigOption {
return func(c *chunkReaderConfig) {
c.peekSize = size
}
}

// ChunkReader reads data from a reader and returns a channel of chunks and a channel of errors.
// The channel of chunks is closed once reading completes (EOF or a read error).
// This should be used whenever a large amount of data is read from a reader.
// Ex: reading attachments, archives, etc.
type ChunkReader func(ctx context.Context, reader io.Reader) (<-chan []byte, <-chan error)

// NewChunkReader returns a ChunkReader with the given options.
func NewChunkReader(opts ...ConfigOption) ChunkReader {
config := applyOptions(opts)
return createReaderFn(config)
}

func applyOptions(opts []ConfigOption) *chunkReaderConfig {
// Set defaults.
config := &chunkReaderConfig{
chunkSize: ChunkSize, // default
totalSize: TotalChunkSize, // default
peekSize: PeekSize, // default
}

for _, opt := range opts {
opt(config)
}

return config
}

func createReaderFn(config *chunkReaderConfig) ChunkReader {
return func(ctx context.Context, reader io.Reader) (<-chan []byte, <-chan error) {
return readInChunks(ctx, reader, config)
}
}

func readInChunks(ctx context.Context, reader io.Reader, config *chunkReaderConfig) (<-chan []byte, <-chan error) {
const channelSize = 1
chunkReader := bufio.NewReaderSize(reader, config.chunkSize)
dataChan := make(chan []byte, channelSize)
errChan := make(chan error, channelSize)

go func() {
defer close(dataChan)
defer close(errChan)

for {
chunkBytes := make([]byte, config.totalSize)
chunkBytes = chunkBytes[:config.chunkSize]
n, err := chunkReader.Read(chunkBytes)
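// If any data was read, peek ahead and append the extra bytes so content that
// spans the chunk boundary stays together in this chunk. The peeked bytes are
// not consumed, so they will also be returned at the start of the next read.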
if n > 0 {
peekData, _ := chunkReader.Peek(config.totalSize - n)
chunkBytes = append(chunkBytes[:n], peekData...)
dataChan <- chunkBytes
}
if err != nil {
if !errors.Is(err, io.EOF) {
ctx.Logger().Error(err, "error reading chunk")
errChan <- err
}
return
}
}
}()
return dataChan, errChan
}
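
Not part of the diff: the sketch below shows how a caller inside pkg/sources might consume the new ChunkReader. The scanReader function and its error handling are hypothetical; only NewChunkReader, the With* options, and the two returned channels come from the code above.

package sources

import (
	"io"

	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)

// scanReader drains a reader chunk by chunk using the ChunkReader added in
// this PR. Each received slice is at most totalSize bytes.
func scanReader(ctx context.Context, r io.Reader) error {
	readerFn := NewChunkReader(
		WithChunkSize(ChunkSize),
		WithPeekSize(PeekSize),
		WithTotalChunkSize(TotalChunkSize),
	)
	dataChan, errChan := readerFn(ctx, r)

	for data := range dataChan {
		// Hand each chunk to whatever processing the source needs.
		_ = data
	}

	// The reader goroutine closes errChan (buffered, size 1) when it finishes,
	// so a single receive yields either the stored read error or nil.
	if err := <-errChan; err != nil {
		return err
	}
	return nil
}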
195 changes: 150 additions & 45 deletions pkg/sources/chunker_test.go
@@ -1,13 +1,15 @@
package sources

import (
"bufio"
"bytes"
"errors"
"io"
"strings"
"testing"

diskbufferreader "github.com/bill-rich/disk-buffer-reader"
"github.com/stretchr/testify/assert"

"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)

func TestChunker(t *testing.T) {
@@ -18,59 +20,39 @@ func TestChunker(t *testing.T) {
}
defer reReader.Close()

baseChunks := make([]*Chunk, 0, 9)

chunkData, _ := io.ReadAll(reReader)
originalChunk := &Chunk{
Data: chunkData,
}
_ = reReader.Reset()

for chunk := range Chunker(originalChunk) {
baseChunks = append(baseChunks, chunk)
}

// Reset the reader to the beginning and use ChunkReader.
_ = reReader.Reset()

testChunks := make([]*Chunk, 0, 9)

testData, _ := io.ReadAll(reReader)
testOriginalChunk := &Chunk{
Data: testData,
}

for chunk := range Chunker(testOriginalChunk) {
testChunks = append(testChunks, chunk)
}

if len(testChunks) != len(baseChunks) {
t.Errorf("Wrong number of chunks received. Got %d, expected: %d.", len(testChunks), len(baseChunks))
}

for i, baseChunk := range baseChunks {
if !bytes.Equal(baseChunk.Data, testChunks[i].Data) {
t.Errorf("Chunk %d did not match expected. Got: %d bytes, expected: %d bytes", i+1, len(testChunks[i].Data), len(baseChunk.Data))
}
}
}

func BenchmarkChunker(b *testing.B) {
@@ -84,3 +66,126 @@
}
}
}

func TestNewChunkedReader(t *testing.T) {
tests := []struct {
name string
input string
chunkSize int
peekSize int
totalSize int
wantChunks []string
wantErr bool
}{
{
name: "Smaller data than default chunkSize and peekSize",
input: "example input",
chunkSize: ChunkSize,
peekSize: PeekSize,
totalSize: TotalChunkSize,
wantChunks: []string{"example input"},
wantErr: false,
},
{
name: "Reader with no data",
input: "",
chunkSize: 1024,
peekSize: 512,
totalSize: 1024 + 512,
wantChunks: []string{},
wantErr: false,
},
{
name: "Smaller data than chunkSize and peekSize",
input: "small data",
chunkSize: 1024,
peekSize: 512,
totalSize: 1024 + 512,
wantChunks: []string{"small data"},
wantErr: false,
},
{
name: "Equal to chunkSize",
input: strings.Repeat("a", 1024),
chunkSize: 1024,
peekSize: 512,
totalSize: 1024 + 512,
wantChunks: []string{strings.Repeat("a", 1024)},
wantErr: false,
},
{
name: "Equal to chunkSize + peekSize",
input: strings.Repeat("a", 1536),
chunkSize: 1024,
peekSize: 512,
totalSize: 1024 + 512,
wantChunks: []string{strings.Repeat("a", 1024), strings.Repeat("a", 512)},
wantErr: false,
},
{
name: "EOF during peeking",
input: strings.Repeat("a", 1300),
chunkSize: 1024,
peekSize: 512,
totalSize: 1024 + 512,
wantChunks: []string{strings.Repeat("a", 1024), strings.Repeat("a", 276)},
wantErr: false,
},
{
name: "EOF during reading",
input: strings.Repeat("a", 512),
chunkSize: 1024,
peekSize: 512,
totalSize: 1024 + 512,
wantChunks: []string{strings.Repeat("a", 512)},
wantErr: false,
},
{
name: "Equal to totalSize",
input: strings.Repeat("a", 2048),
chunkSize: 1024,
peekSize: 1024,
totalSize: 1024 + 1024,
wantChunks: []string{strings.Repeat("a", 2048), strings.Repeat("a", 1024)},
wantErr: false,
},
{
name: "Larger than totalSize",
input: strings.Repeat("a", 4096),
chunkSize: 1024,
peekSize: 1024,
totalSize: 1024 + 1024,
wantChunks: []string{strings.Repeat("a", 2048), strings.Repeat("a", 2048), strings.Repeat("a", 2048), strings.Repeat("a", 1024)},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
readerFunc := NewChunkReader(WithChunkSize(tt.chunkSize), WithPeekSize(tt.peekSize), WithTotalChunkSize(tt.totalSize))
reader := strings.NewReader(tt.input)
ctx := context.Background()
dataChan, errChan := readerFunc(ctx, reader)

chunks := make([]string, 0)
for data := range dataChan {
chunks = append(chunks, string(data))
}

assert.Equal(t, tt.wantChunks, chunks, "Chunks do not match")

select {
case err := <-errChan:
if tt.wantErr {
assert.Error(t, err, "Expected an error")
} else {
assert.NoError(t, err, "Unexpected error")
}
default:
if tt.wantErr {
assert.Fail(t, "Expected error but got none")
}
}
})
}
}
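
As a hypothetical aside (not part of the PR), the "Equal to totalSize" case above can be traced through readInChunks by hand: the first read returns 1024 bytes and the peek pulls in the next 1024, so the first chunk is 2048 bytes; the second read then returns those same 1024 peeked bytes with nothing left to peek, giving a final 1024-byte chunk. A small example in the same test package would reproduce that (assuming "fmt" is added to the existing imports):

// ExampleNewChunkReader_overlap demonstrates the chunk overlap produced by the
// peek: 2048 bytes in, chunks of 2048 and 1024 bytes out.
func ExampleNewChunkReader_overlap() {
	readerFn := NewChunkReader(WithChunkSize(1024), WithPeekSize(1024), WithTotalChunkSize(2048))
	dataChan, _ := readerFn(context.Background(), strings.NewReader(strings.Repeat("a", 2048)))
	for data := range dataChan {
		fmt.Println(len(data))
	}
	// Output:
	// 2048
	// 1024
}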