Skip to content

Commit

Permalink
Use common chunk reader (#1596)
Browse files Browse the repository at this point in the history
* Add common chunker.

* add comment.

* use better config name.

* Add common chunk reader to s3.

* Add common chunk reader to git, gcs, circleci.

* revert gcs.

* revert gcs.

* fix chunker.

* revert gcs.

* update cancellablewrite.

* revert impl.

* update to remove totalsize.

* Fix my goof.

* Use unified struct in chunkreader.

* return err instead of logging and returning.

* rename error to err.

* only send single ChunkResult even if there is an error and chunkBytes.

* fix logic.
  • Loading branch information
ahrav authored Aug 7, 2023
1 parent 18b3d3d commit 1399922
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 67 deletions.
45 changes: 33 additions & 12 deletions pkg/sources/chunker.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,28 @@ func WithPeekSize(size int) ConfigOption {
}
}

// ChunkResult bundles what a ChunkReader emits for a single read:
// the bytes of the chunk and the error, if any, recorded alongside it.
type ChunkResult struct {
	// data holds the chunk's bytes.
	data []byte
	// err holds the error associated with this chunk; nil when there is none.
	err error
}

// Bytes returns the data carried by the ChunkResult.
func (r ChunkResult) Bytes() []byte { return r.data }

// Error returns the error carried by the ChunkResult.
func (r ChunkResult) Error() error { return r.err }

// ChunkReader reads chunks from a reader and returns a channel of ChunkResult values, each carrying a chunk's data and any error hit while reading it.
// The channel is closed once reading finishes, either at EOF or after a read error.
// This should be used whenever a large amount of data is read from a reader.
// Ex: reading attachments, archives, etc.
type ChunkReader func(ctx context.Context, reader io.Reader) (<-chan []byte, <-chan error)
type ChunkReader func(ctx context.Context, reader io.Reader) <-chan ChunkResult

// NewChunkReader returns a ChunkReader with the given options.
func NewChunkReader(opts ...ConfigOption) ChunkReader {
Expand All @@ -101,39 +118,43 @@ func applyOptions(opts []ConfigOption) *chunkReaderConfig {
}

func createReaderFn(config *chunkReaderConfig) ChunkReader {
return func(ctx context.Context, reader io.Reader) (<-chan []byte, <-chan error) {
return func(ctx context.Context, reader io.Reader) <-chan ChunkResult {
return readInChunks(ctx, reader, config)
}
}

func readInChunks(ctx context.Context, reader io.Reader, config *chunkReaderConfig) (<-chan []byte, <-chan error) {
func readInChunks(ctx context.Context, reader io.Reader, config *chunkReaderConfig) <-chan ChunkResult {
const channelSize = 1
chunkReader := bufio.NewReaderSize(reader, config.chunkSize)
dataChan := make(chan []byte, channelSize)
errChan := make(chan error, channelSize)
chunkResultChan := make(chan ChunkResult, channelSize)

go func() {
defer close(dataChan)
defer close(errChan)
defer close(chunkResultChan)

for {
chunkRes := ChunkResult{}
chunkBytes := make([]byte, config.totalSize)
chunkBytes = chunkBytes[:config.chunkSize]
n, err := chunkReader.Read(chunkBytes)
if n > 0 {
peekData, _ := chunkReader.Peek(config.totalSize - n)
chunkBytes = append(chunkBytes[:n], peekData...)
dataChan <- chunkBytes
chunkRes.data = chunkBytes
}

if err != nil {
if !errors.Is(err, io.EOF) {
// If there is an error other than EOF, or if we have read some bytes, send the chunk.
if err != nil && !errors.Is(err, io.EOF) || n > 0 {
if err != nil && !errors.Is(err, io.EOF) {
ctx.Logger().Error(err, "error reading chunk")
errChan <- err
chunkRes.err = err
}
chunkResultChan <- chunkRes
}

if err != nil {
return
}
}
}()
return dataChan, errChan
return chunkResultChan
}
47 changes: 32 additions & 15 deletions pkg/sources/chunker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,27 +155,44 @@ func TestNewChunkedReader(t *testing.T) {
readerFunc := NewChunkReader(WithChunkSize(tt.chunkSize), WithPeekSize(tt.peekSize))
reader := strings.NewReader(tt.input)
ctx := context.Background()
dataChan, errChan := readerFunc(ctx, reader)
chunkResChan := readerFunc(ctx, reader)

var err error
chunks := make([]string, 0)
for data := range dataChan {
chunks = append(chunks, string(data))
for data := range chunkResChan {
chunks = append(chunks, string(data.Bytes()))
err = data.Error()
}

assert.Equal(t, tt.wantChunks, chunks, "Chunks do not match")

select {
case err := <-errChan:
if tt.wantErr {
assert.Error(t, err, "Expected an error")
} else {
assert.NoError(t, err, "Unexpected error")
}
default:
if tt.wantErr {
assert.Fail(t, "Expected error but got none")
}
if tt.wantErr {
assert.Error(t, err, "Expected an error")
} else {
assert.NoError(t, err, "Unexpected error")
}
})
}
}

// BenchmarkChunkReader measures how long a ChunkReader configured with the
// package's ChunkSize and PeekSize takes to chunk a 16MB in-memory buffer.
func BenchmarkChunkReader(b *testing.B) {
	src := bytes.NewReader(make([]byte, 1<<24)) // 16MB
	readChunks := NewChunkReader(WithChunkSize(ChunkSize), WithPeekSize(PeekSize))

	b.ReportAllocs()
	b.ResetTimer()

	for iter := 0; iter < b.N; iter++ {
		b.StartTimer()
		results := readChunks(context.Background(), src)

		// Consume every result so the whole buffer is read.
		for range results {
		}

		// Stop the clock before rewinding so the seek is not part of the measurement.
		b.StopTimer()
		_, seekErr := src.Seek(0, 0)
		assert.Nil(b, seekErr)
	}
}
52 changes: 28 additions & 24 deletions pkg/sources/circleci/circleci.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"sync/atomic"

Expand Down Expand Up @@ -219,7 +218,7 @@ func (s *Source) stepsForBuild(_ context.Context, proj project, bld build) ([]bu
return bldRes.Steps, nil
}

func (s *Source) chunkAction(_ context.Context, proj project, bld build, act action, stepName string, chunksChan chan *sources.Chunk) error {
func (s *Source) chunkAction(ctx context.Context, proj project, bld build, act action, stepName string, chunksChan chan *sources.Chunk) error {
req, err := http.NewRequest("GET", act.OutputURL, nil)
if err != nil {
return err
Expand All @@ -229,35 +228,40 @@ func (s *Source) chunkAction(_ context.Context, proj project, bld build, act act
return err
}
defer res.Body.Close()
logOutput, err := io.ReadAll(res.Body)
if err != nil {
return err
}

linkURL := fmt.Sprintf("https://app.circleci.com/pipelines/%s/%s/%s/%d", proj.VCS, proj.Username, proj.RepoName, bld.BuildNum)

chunk := &sources.Chunk{
SourceType: s.Type(),
SourceName: s.name,
SourceID: s.SourceID(),
Data: removeCircleSha1Line(logOutput),
SourceMetadata: &source_metadatapb.MetaData{
Data: &source_metadatapb.MetaData_Circleci{
Circleci: &source_metadatapb.CircleCI{
VcsType: proj.VCS,
Username: proj.Username,
Repository: proj.RepoName,
BuildNumber: int64(bld.BuildNum),
BuildStep: stepName,
Link: linkURL,
chunkReader := sources.NewChunkReader()
chunkResChan := chunkReader(ctx, res.Body)
for data := range chunkResChan {
chunk := &sources.Chunk{
SourceType: s.Type(),
SourceName: s.name,
SourceID: s.SourceID(),
Data: removeCircleSha1Line(data.Bytes()),
SourceMetadata: &source_metadatapb.MetaData{
Data: &source_metadatapb.MetaData_Circleci{
Circleci: &source_metadatapb.CircleCI{
VcsType: proj.VCS,
Username: proj.Username,
Repository: proj.RepoName,
BuildNumber: int64(bld.BuildNum),
BuildStep: stepName,
Link: linkURL,
},
},
},
},
Verify: s.verify,
Verify: s.verify,
}
chunk.Data = data.Bytes()
if err := data.Error(); err != nil {
return err
}
if err := common.CancellableWrite(ctx, chunksChan, chunk); err != nil {
return err
}
}

chunksChan <- chunk

return nil
}

Expand Down
20 changes: 12 additions & 8 deletions pkg/sources/git/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bufio"
"bytes"
"fmt"
"io"
"net/url"
"os"
"os/exec"
Expand All @@ -26,6 +25,7 @@ import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"

"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/gitparse"
"github.com/trufflesecurity/trufflehog/v3/pkg/handlers"
Expand Down Expand Up @@ -930,15 +930,19 @@ func handleBinary(ctx context.Context, repo *git.Repository, chunksChan chan *so
}
reader.Stop()

chunkData, err := io.ReadAll(reader)
if err != nil {
return err
chunkReader := sources.NewChunkReader()
chunkResChan := chunkReader(ctx, reader)
for data := range chunkResChan {
chunk := *chunkSkel
chunk.Data = data.Bytes()
if err := data.Error(); err != nil {
return err
}
if err := common.CancellableWrite(ctx, chunksChan, &chunk); err != nil {
return err
}
}

chunk := *chunkSkel
chunk.Data = chunkData
chunksChan <- &chunk

return nil
}

Expand Down
20 changes: 12 additions & 8 deletions pkg/sources/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package s3

import (
"fmt"
"io"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -336,16 +335,21 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan
reader.Stop()

chunk := *chunkSkel
chunkData, err := io.ReadAll(reader)
if err != nil {
s.log.Error(err, "Could not read file data.")
return nil
chunkReader := sources.NewChunkReader()
chunkResChan := chunkReader(ctx, reader)
for data := range chunkResChan {
chunk.Data = data.Bytes()
if err := data.Error(); err != nil {
s.log.Error(err, "error reading chunk.")
continue
}
if err := common.CancellableWrite(ctx, chunksChan, &chunk); err != nil {
return err
}
}

atomic.AddUint64(objectCount, 1)
s.log.V(5).Info("S3 object scanned.", "object_count", objectCount, "page_number", pageNumber)
chunk.Data = chunkData
chunksChan <- &chunk

nErr, ok = errorCount.Load(prefix)
if !ok {
nErr = 0
Expand Down

0 comments on commit 1399922

Please sign in to comment.