diff --git a/pkg/context/context.go b/pkg/context/context.go index ef0df1086825..330290808ae6 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -18,22 +18,11 @@ var ( type Context interface { context.Context Logger() logr.Logger - Parent() context.Context - SetParent(ctx context.Context) Context } -// Parent returns the parent context. -func (l logCtx) Parent() context.Context { - return l.Context -} - -// SetParent sets the parent context on the context. -func (l logCtx) SetParent(ctx context.Context) Context { - l.Context = ctx - return l -} - -type CancelFunc context.CancelFunc +// CancelFunc is a type alias to context.CancelFunc to allow use as if they are +// the same types. +type CancelFunc = context.CancelFunc // logCtx implements Context. type logCtx struct { diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 9c95ba745d64..627787c22d29 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -113,6 +113,7 @@ func Start(ctx context.Context, options ...EngineOption) *Engine { chunks: make(chan *sources.Chunk), results: make(chan detectors.ResultWithMetadata), detectorAvgTime: sync.Map{}, + sourcesWg: &errgroup.Group{}, } for _, option := range options { @@ -127,10 +128,8 @@ func Start(ctx context.Context, options ...EngineOption) *Engine { } ctx.Logger().V(2).Info("engine started", "workers", e.concurrency) - sourcesWg, egCtx := errgroup.WithContext(ctx) - sourcesWg.SetLimit(e.concurrency) - e.sourcesWg = sourcesWg - ctx.SetParent(egCtx) + // Limit number of concurrent goroutines dedicated to chunking a source. + e.sourcesWg.SetLimit(e.concurrency) if len(e.decoders) == 0 { e.decoders = decoders.DefaultDecoders() diff --git a/pkg/engine/git_test.go b/pkg/engine/git_test.go index 036172930fa5..48ed385f18f7 100644 --- a/pkg/engine/git_test.go +++ b/pkg/engine/git_test.go @@ -7,6 +7,7 @@ import ( "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/decoders" + "github.com/trufflesecurity/trufflehog/v3/pkg/detectors" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" "github.com/trufflesecurity/trufflehog/v3/pkg/sources/git" @@ -15,6 +16,7 @@ import ( type expResult struct { B string LineNumber int64 + Verified bool } func TestGitEngine(t *testing.T) { @@ -38,16 +40,16 @@ func TestGitEngine(t *testing.T) { for tName, tTest := range map[string]testProfile{ "all_secrets": { expected: map[string]expResult{ - "70001020fab32b1fcf2f1f0e5c66424eae649826": {"AKIAXYZDQCEN4B6JSJQI", 2}, - "84e9c75e388ae3e866e121087ea2dd45a71068f2": {"AKIAILE3JG6KMS3HZGCA", 4}, - "8afb0ecd4998b1179e428db5ebbcdc8221214432": {"369963c1434c377428ca8531fbc46c0c43d037a0", 3}, - "27fbead3bf883cdb7de9d7825ed401f28f9398f1": {"ffc7e0f9400fb6300167009e42d2f842cd7956e2", 7}, + "70001020fab32b1fcf2f1f0e5c66424eae649826": {"AKIAXYZDQCEN4B6JSJQI", 2, true}, + "84e9c75e388ae3e866e121087ea2dd45a71068f2": {"AKIAILE3JG6KMS3HZGCA", 4, true}, + "8afb0ecd4998b1179e428db5ebbcdc8221214432": {"369963c1434c377428ca8531fbc46c0c43d037a0", 3, false}, + "27fbead3bf883cdb7de9d7825ed401f28f9398f1": {"ffc7e0f9400fb6300167009e42d2f842cd7956e2", 7, false}, }, filter: common.FilterEmpty(), }, "base_commit": { expected: map[string]expResult{ - "70001020fab32b1fcf2f1f0e5c66424eae649826": {"AKIAXYZDQCEN4B6JSJQI", 2}, + "70001020fab32b1fcf2f1f0e5c66424eae649826": {"AKIAXYZDQCEN4B6JSJQI", 2, true}, }, filter: common.FilterEmpty(), base: "2f251b8c1e72135a375b659951097ec7749d4af9", @@ -57,8 +59,12 @@ func TestGitEngine(t *testing.T) { e := Start(ctx, WithConcurrency(1), WithDecoders(decoders.DefaultDecoders()...), - WithDetectors(false, DefaultDetectors()...), + WithDetectors(true, DefaultDetectors()...), ) + // Make the channels buffered so Finish returns. + e.chunks = make(chan *sources.Chunk, 10) + e.results = make(chan detectors.ResultWithMetadata, 10) + cfg := sources.GitConfig{ RepoPath: path, HeadRef: tTest.branch, @@ -73,7 +79,8 @@ func TestGitEngine(t *testing.T) { logFatalFunc := func(_ error, _ string, _ ...any) { t.Fatalf("error logging function should not have been called") } - go e.Finish(ctx, logFatalFunc) + // Wait for all the chunks to be processed. + e.Finish(ctx, logFatalFunc) resultCount := 0 for result := range e.ResultsChan() { switch meta := result.SourceMetadata.GetData().(type) { @@ -84,6 +91,9 @@ func TestGitEngine(t *testing.T) { if tTest.expected[meta.Git.Commit].LineNumber != result.SourceMetadata.GetGit().Line { t.Errorf("%s: unexpected line number. Got: %d, Expected: %d", tName, result.SourceMetadata.GetGit().Line, tTest.expected[meta.Git.Commit].LineNumber) } + if tTest.expected[meta.Git.Commit].Verified != result.Verified { + t.Errorf("%s: unexpected verification. Got: %v, Expected: %v", tName, result.Verified, tTest.expected[meta.Git.Commit].Verified) + } } resultCount++