diff --git a/pkg/sources/chunker.go b/pkg/sources/chunker.go index 084c6a6982b2..78a26aede71b 100644 --- a/pkg/sources/chunker.go +++ b/pkg/sources/chunker.go @@ -72,11 +72,28 @@ func WithPeekSize(size int) ConfigOption { } } +// ChunkResult is the output unit of a ChunkReader, +// it contains the data and error of a chunk. +type ChunkResult struct { + data []byte + err error +} + +// Bytes for a ChunkResult. +func (cr ChunkResult) Bytes() []byte { + return cr.data +} + +// Error for a ChunkResult. +func (cr ChunkResult) Error() error { + return cr.err +} + // ChunkReader reads chunks from a reader and returns a channel of chunks and a channel of errors. // The channel of chunks is closed when the reader is closed. // This should be used whenever a large amount of data is read from a reader. // Ex: reading attachments, archives, etc. -type ChunkReader func(ctx context.Context, reader io.Reader) (<-chan []byte, <-chan error) +type ChunkReader func(ctx context.Context, reader io.Reader) <-chan ChunkResult // NewChunkReader returns a ChunkReader with the given options. func NewChunkReader(opts ...ConfigOption) ChunkReader { @@ -101,39 +118,43 @@ func applyOptions(opts []ConfigOption) *chunkReaderConfig { } func createReaderFn(config *chunkReaderConfig) ChunkReader { - return func(ctx context.Context, reader io.Reader) (<-chan []byte, <-chan error) { + return func(ctx context.Context, reader io.Reader) <-chan ChunkResult { return readInChunks(ctx, reader, config) } } -func readInChunks(ctx context.Context, reader io.Reader, config *chunkReaderConfig) (<-chan []byte, <-chan error) { +func readInChunks(ctx context.Context, reader io.Reader, config *chunkReaderConfig) <-chan ChunkResult { const channelSize = 1 chunkReader := bufio.NewReaderSize(reader, config.chunkSize) - dataChan := make(chan []byte, channelSize) - errChan := make(chan error, channelSize) + chunkResultChan := make(chan ChunkResult, channelSize) go func() { - defer close(dataChan) - defer close(errChan) + defer close(chunkResultChan) for { + chunkRes := ChunkResult{} chunkBytes := make([]byte, config.totalSize) chunkBytes = chunkBytes[:config.chunkSize] n, err := chunkReader.Read(chunkBytes) if n > 0 { peekData, _ := chunkReader.Peek(config.totalSize - n) chunkBytes = append(chunkBytes[:n], peekData...) - dataChan <- chunkBytes + chunkRes.data = chunkBytes } - if err != nil { - if !errors.Is(err, io.EOF) { + // If there is an error other than EOF, or if we have read some bytes, send the chunk. + if err != nil && !errors.Is(err, io.EOF) || n > 0 { + if err != nil && !errors.Is(err, io.EOF) { ctx.Logger().Error(err, "error reading chunk") - errChan <- err + chunkRes.err = err } + chunkResultChan <- chunkRes + } + + if err != nil { return } } }() - return dataChan, errChan + return chunkResultChan } diff --git a/pkg/sources/chunker_test.go b/pkg/sources/chunker_test.go index ab3fd0ad3293..921aefa4f3cf 100644 --- a/pkg/sources/chunker_test.go +++ b/pkg/sources/chunker_test.go @@ -155,27 +155,44 @@ func TestNewChunkedReader(t *testing.T) { readerFunc := NewChunkReader(WithChunkSize(tt.chunkSize), WithPeekSize(tt.peekSize)) reader := strings.NewReader(tt.input) ctx := context.Background() - dataChan, errChan := readerFunc(ctx, reader) + chunkResChan := readerFunc(ctx, reader) + var err error chunks := make([]string, 0) - for data := range dataChan { - chunks = append(chunks, string(data)) + for data := range chunkResChan { + chunks = append(chunks, string(data.Bytes())) + err = data.Error() } assert.Equal(t, tt.wantChunks, chunks, "Chunks do not match") - - select { - case err := <-errChan: - if tt.wantErr { - assert.Error(t, err, "Expected an error") - } else { - assert.NoError(t, err, "Unexpected error") - } - default: - if tt.wantErr { - assert.Fail(t, "Expected error but got none") - } + if tt.wantErr { + assert.Error(t, err, "Expected an error") + } else { + assert.NoError(t, err, "Unexpected error") } }) } } + +func BenchmarkChunkReader(b *testing.B) { + var bigChunk = make([]byte, 1<<24) // 16MB + + reader := bytes.NewReader(bigChunk) + chunkReader := NewChunkReader(WithChunkSize(ChunkSize), WithPeekSize(PeekSize)) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + b.StartTimer() + chunkResChan := chunkReader(context.Background(), reader) + + // Drain the channel. + for range chunkResChan { + } + + b.StopTimer() + _, err := reader.Seek(0, 0) + assert.Nil(b, err) + } +} diff --git a/pkg/sources/circleci/circleci.go b/pkg/sources/circleci/circleci.go index ebf7b6cf212d..1fb94c1791ce 100644 --- a/pkg/sources/circleci/circleci.go +++ b/pkg/sources/circleci/circleci.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "fmt" - "io" "net/http" "sync/atomic" @@ -219,7 +218,7 @@ func (s *Source) stepsForBuild(_ context.Context, proj project, bld build) ([]bu return bldRes.Steps, nil } -func (s *Source) chunkAction(_ context.Context, proj project, bld build, act action, stepName string, chunksChan chan *sources.Chunk) error { +func (s *Source) chunkAction(ctx context.Context, proj project, bld build, act action, stepName string, chunksChan chan *sources.Chunk) error { req, err := http.NewRequest("GET", act.OutputURL, nil) if err != nil { return err @@ -229,35 +228,40 @@ func (s *Source) chunkAction(_ context.Context, proj project, bld build, act act return err } defer res.Body.Close() - logOutput, err := io.ReadAll(res.Body) - if err != nil { - return err - } linkURL := fmt.Sprintf("https://app.circleci.com/pipelines/%s/%s/%s/%d", proj.VCS, proj.Username, proj.RepoName, bld.BuildNum) - chunk := &sources.Chunk{ - SourceType: s.Type(), - SourceName: s.name, - SourceID: s.SourceID(), - Data: removeCircleSha1Line(logOutput), - SourceMetadata: &source_metadatapb.MetaData{ - Data: &source_metadatapb.MetaData_Circleci{ - Circleci: &source_metadatapb.CircleCI{ - VcsType: proj.VCS, - Username: proj.Username, - Repository: proj.RepoName, - BuildNumber: int64(bld.BuildNum), - BuildStep: stepName, - Link: linkURL, + chunkReader := sources.NewChunkReader() + chunkResChan := chunkReader(ctx, res.Body) + for data := range chunkResChan { + chunk := &sources.Chunk{ + SourceType: s.Type(), + SourceName: s.name, + SourceID: s.SourceID(), + Data: removeCircleSha1Line(data.Bytes()), + SourceMetadata: &source_metadatapb.MetaData{ + Data: &source_metadatapb.MetaData_Circleci{ + Circleci: &source_metadatapb.CircleCI{ + VcsType: proj.VCS, + Username: proj.Username, + Repository: proj.RepoName, + BuildNumber: int64(bld.BuildNum), + BuildStep: stepName, + Link: linkURL, + }, }, }, - }, - Verify: s.verify, + Verify: s.verify, + } + chunk.Data = data.Bytes() + if err := data.Error(); err != nil { + return err + } + if err := common.CancellableWrite(ctx, chunksChan, chunk); err != nil { + return err + } } - chunksChan <- chunk - return nil } diff --git a/pkg/sources/git/git.go b/pkg/sources/git/git.go index f224ef5f72bd..e6ce07376002 100644 --- a/pkg/sources/git/git.go +++ b/pkg/sources/git/git.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "fmt" - "io" "net/url" "os" "os/exec" @@ -26,6 +25,7 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/gitparse" "github.com/trufflesecurity/trufflehog/v3/pkg/handlers" @@ -930,15 +930,19 @@ func handleBinary(ctx context.Context, repo *git.Repository, chunksChan chan *so } reader.Stop() - chunkData, err := io.ReadAll(reader) - if err != nil { - return err + chunkReader := sources.NewChunkReader() + chunkResChan := chunkReader(ctx, reader) + for data := range chunkResChan { + chunk := *chunkSkel + chunk.Data = data.Bytes() + if err := data.Error(); err != nil { + return err + } + if err := common.CancellableWrite(ctx, chunksChan, &chunk); err != nil { + return err + } } - chunk := *chunkSkel - chunk.Data = chunkData - chunksChan <- &chunk - return nil } diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index b7971ae4ea00..f87bd067a8c4 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go @@ -2,7 +2,6 @@ package s3 import ( "fmt" - "io" "strings" "sync" "sync/atomic" @@ -336,16 +335,21 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan reader.Stop() chunk := *chunkSkel - chunkData, err := io.ReadAll(reader) - if err != nil { - s.log.Error(err, "Could not read file data.") - return nil + chunkReader := sources.NewChunkReader() + chunkResChan := chunkReader(ctx, reader) + for data := range chunkResChan { + chunk.Data = data.Bytes() + if err := data.Error(); err != nil { + s.log.Error(err, "error reading chunk.") + continue + } + if err := common.CancellableWrite(ctx, chunksChan, &chunk); err != nil { + return err + } } + atomic.AddUint64(objectCount, 1) s.log.V(5).Info("S3 object scanned.", "object_count", objectCount, "page_number", pageNumber) - chunk.Data = chunkData - chunksChan <- &chunk - nErr, ok = errorCount.Load(prefix) if !ok { nErr = 0