Skip to content

Commit

Permalink
[chore] Remove parent manipulation in context package
Browse files Browse the repository at this point in the history
The ability to set the parent allowed creating context cycles which
shouldn't be allowed, or at the very least have unintuitive behavior.
  • Loading branch information
mcastorina committed Jul 21, 2023
1 parent 06a5626 commit b1738d2
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 25 deletions.
17 changes: 3 additions & 14 deletions pkg/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,11 @@ var (
type Context interface {
context.Context
Logger() logr.Logger
Parent() context.Context
SetParent(ctx context.Context) Context
}

// Parent returns the parent context.
func (l logCtx) Parent() context.Context {
return l.Context
}

// SetParent sets the parent context on the context.
func (l logCtx) SetParent(ctx context.Context) Context {
l.Context = ctx
return l
}

type CancelFunc context.CancelFunc
// CancelFunc is a type alias to context.CancelFunc to allow use as if they are
// the same types.
type CancelFunc = context.CancelFunc

// logCtx implements Context.
type logCtx struct {
Expand Down
7 changes: 3 additions & 4 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ func Start(ctx context.Context, options ...EngineOption) *Engine {
chunks: make(chan *sources.Chunk),
results: make(chan detectors.ResultWithMetadata),
detectorAvgTime: sync.Map{},
sourcesWg: &errgroup.Group{},
}

for _, option := range options {
Expand All @@ -127,10 +128,8 @@ func Start(ctx context.Context, options ...EngineOption) *Engine {
}
ctx.Logger().V(2).Info("engine started", "workers", e.concurrency)

sourcesWg, egCtx := errgroup.WithContext(ctx)
sourcesWg.SetLimit(e.concurrency)
e.sourcesWg = sourcesWg
ctx.SetParent(egCtx)
// Limit number of concurrent goroutines dedicated to chunking a source.
e.sourcesWg.SetLimit(e.concurrency)

if len(e.decoders) == 0 {
e.decoders = decoders.DefaultDecoders()
Expand Down
24 changes: 17 additions & 7 deletions pkg/engine/git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/decoders"
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/git"
Expand All @@ -15,6 +16,7 @@ import (
type expResult struct {
B string
LineNumber int64
Verified bool
}

func TestGitEngine(t *testing.T) {
Expand All @@ -38,16 +40,16 @@ func TestGitEngine(t *testing.T) {
for tName, tTest := range map[string]testProfile{
"all_secrets": {
expected: map[string]expResult{
"70001020fab32b1fcf2f1f0e5c66424eae649826": {"AKIAXYZDQCEN4B6JSJQI", 2},
"84e9c75e388ae3e866e121087ea2dd45a71068f2": {"AKIAILE3JG6KMS3HZGCA", 4},
"8afb0ecd4998b1179e428db5ebbcdc8221214432": {"369963c1434c377428ca8531fbc46c0c43d037a0", 3},
"27fbead3bf883cdb7de9d7825ed401f28f9398f1": {"ffc7e0f9400fb6300167009e42d2f842cd7956e2", 7},
"70001020fab32b1fcf2f1f0e5c66424eae649826": {"AKIAXYZDQCEN4B6JSJQI", 2, true},
"84e9c75e388ae3e866e121087ea2dd45a71068f2": {"AKIAILE3JG6KMS3HZGCA", 4, true},
"8afb0ecd4998b1179e428db5ebbcdc8221214432": {"369963c1434c377428ca8531fbc46c0c43d037a0", 3, false},
"27fbead3bf883cdb7de9d7825ed401f28f9398f1": {"ffc7e0f9400fb6300167009e42d2f842cd7956e2", 7, false},
},
filter: common.FilterEmpty(),
},
"base_commit": {
expected: map[string]expResult{
"70001020fab32b1fcf2f1f0e5c66424eae649826": {"AKIAXYZDQCEN4B6JSJQI", 2},
"70001020fab32b1fcf2f1f0e5c66424eae649826": {"AKIAXYZDQCEN4B6JSJQI", 2, true},
},
filter: common.FilterEmpty(),
base: "2f251b8c1e72135a375b659951097ec7749d4af9",
Expand All @@ -57,8 +59,12 @@ func TestGitEngine(t *testing.T) {
e := Start(ctx,
WithConcurrency(1),
WithDecoders(decoders.DefaultDecoders()...),
WithDetectors(false, DefaultDetectors()...),
WithDetectors(true, DefaultDetectors()...),
)
// Make the channels buffered so Finish returns.
e.chunks = make(chan *sources.Chunk, 10)
e.results = make(chan detectors.ResultWithMetadata, 10)

cfg := sources.GitConfig{
RepoPath: path,
HeadRef: tTest.branch,
Expand All @@ -73,7 +79,8 @@ func TestGitEngine(t *testing.T) {
logFatalFunc := func(_ error, _ string, _ ...any) {
t.Fatalf("error logging function should not have been called")
}
go e.Finish(ctx, logFatalFunc)
// Wait for all the chunks to be processed.
e.Finish(ctx, logFatalFunc)
resultCount := 0
for result := range e.ResultsChan() {
switch meta := result.SourceMetadata.GetData().(type) {
Expand All @@ -84,6 +91,9 @@ func TestGitEngine(t *testing.T) {
if tTest.expected[meta.Git.Commit].LineNumber != result.SourceMetadata.GetGit().Line {
t.Errorf("%s: unexpected line number. Got: %d, Expected: %d", tName, result.SourceMetadata.GetGit().Line, tTest.expected[meta.Git.Commit].LineNumber)
}
if tTest.expected[meta.Git.Commit].Verified != result.Verified {
t.Errorf("%s: unexpected verification. Got: %v, Expected: %v", tName, result.Verified, tTest.expected[meta.Git.Commit].Verified)
}
}
resultCount++

Expand Down

0 comments on commit b1738d2

Please sign in to comment.