From 5e7a6ca11cca2074d1e5d8a130ad450f671989ed Mon Sep 17 00:00:00 2001 From: ahrav Date: Mon, 31 Jul 2023 11:12:08 -0700 Subject: [PATCH] Concurrent detection (#1580) * Run detection on each chunk concurrently. * Add printer functionality. * Add logic for dedupe. * cleanup. * Modify number of notifier workers. * Add comment. * move consts into fxn. * buffer results chan. * fix test. * address comments. * return an error from Finish. * fix test. * fix test. * linter. * check err. * address comments. --- go.mod | 2 + go.sum | 3 + main.go | 100 ++++---- pkg/common/utils.go | 15 ++ pkg/engine/engine.go | 462 ++++++++++++++++++++++++----------- pkg/engine/gcs_test.go | 13 +- pkg/engine/git_test.go | 37 +-- pkg/output/github_actions.go | 6 +- pkg/output/json.go | 6 +- pkg/output/legacy_json.go | 9 +- pkg/output/plain.go | 6 +- 11 files changed, 440 insertions(+), 219 deletions(-) diff --git a/go.mod b/go.mod index faff10cde342..3b1c07afeed2 100644 --- a/go.mod +++ b/go.mod @@ -40,6 +40,7 @@ require ( github.com/google/go-github/v42 v42.0.0 github.com/googleapis/gax-go/v2 v2.12.0 github.com/hashicorp/go-retryablehttp v0.7.4 + github.com/hashicorp/golang-lru v0.5.1 github.com/jlaffaye/ftp v0.2.0 github.com/joho/godotenv v1.5.1 github.com/jpillora/overseer v1.1.6 @@ -128,6 +129,7 @@ require ( github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.4 // indirect github.com/imdario/mergo v0.3.15 // indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect diff --git a/go.sum b/go.sum index 1185501099c6..c2c276ae8acc 100644 --- a/go.sum +++ b/go.sum @@ -301,7 +301,10 @@ github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9 github.com/hashicorp/go-retryablehttp v0.7.4 h1:ZQgVdpTdAL7WpMIwLzCfbalOcSUdkDZnpUv3/+BxzFA= 
github.com/hashicorp/go-retryablehttp v0.7.4/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.1 h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+dAcgU= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru/v2 v2.0.4 h1:7GHuZcgid37q8o5i3QI9KMT4nCWQQ3Kx3Ov6bb9MfK0= +github.com/hashicorp/golang-lru/v2 v2.0.4/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= diff --git a/main.go b/main.go index d89b977d466a..9a0394b450be 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,6 @@ import ( "strconv" "strings" "syscall" - "time" "github.com/felixge/fgprof" "github.com/go-logr/logr" @@ -316,8 +315,21 @@ func run(state overseer.State) { return true } - e := engine.Start(ctx, - engine.WithConcurrency(*concurrency), + // Set how the engine will print its results. 
+ var printer engine.Printer + switch { + case *jsonLegacy: + printer = new(output.LegacyJSONPrinter) + case *jsonOut: + printer = new(output.JSONPrinter) + case *gitHubActionsFormat: + printer = new(output.GitHubActionsPrinter) + default: + printer = new(output.PlainPrinter) + } + + e, err := engine.Start(ctx, + engine.WithConcurrency(uint8(*concurrency)), engine.WithDecoders(decoders.DefaultDecoders()...), engine.WithDetectors(!*noVerification, engine.DefaultDetectors()...), engine.WithDetectors(!*noVerification, conf.Detectors...), @@ -325,7 +337,13 @@ func run(state overseer.State) { engine.WithFilterDetectors(excludeFilter), engine.WithFilterDetectors(endpointCustomizer), engine.WithFilterUnverified(*filterUnverified), + engine.WithOnlyVerified(*onlyVerified), + engine.WithPrintAvgDetectorTime(*printAvgDetectorTime), + engine.WithPrinter(printer), ) + if err != nil { + logFatal(err, "error initializing engine") + } var repoPath string var remote bool @@ -475,52 +493,48 @@ func run(state overseer.State) { logFatal(err, "Failed to scan Docker.") } } - // asynchronously wait for scanning to finish and cleanup - go e.Finish(ctx, logFatal) if !*jsonLegacy && !*jsonOut { fmt.Fprintf(os.Stderr, "🐷🔑🐷 TruffleHog. Unearth your secrets. 🐷🔑🐷\n\n") } - // NOTE: this loop will terminate when the results channel is closed in - // e.Finish() - foundResults := false - for r := range e.ResultsChan() { - if *onlyVerified && !r.Verified { - continue - } - foundResults = true - - var err error - switch { - case *jsonLegacy: - err = output.PrintLegacyJSON(ctx, &r) - case *jsonOut: - err = output.PrintJSON(&r) - case *gitHubActionsFormat: - err = output.PrintGitHubActionsOutput(&r) - default: - err = output.PrintPlainOutput(&r) - } - if err != nil { - logFatal(err, "error printing results") - } + // Wait for all workers to finish. 
+ if err = e.Finish(ctx); err != nil { + logFatal(err, "engine failed to finish execution") } - logger.V(2).Info("finished scanning", - "chunks", e.ChunksScanned(), - "bytes", e.BytesScanned(), + + metrics := e.GetMetrics() + // Print results. + logger.Info("finished scanning", + "chunks", metrics.ChunksScanned, + "bytes", metrics.BytesScanned, + "verified_secrets", metrics.VerifiedSecretsFound, + "unverified_secrets", metrics.UnverifiedSecretsFound, ) if *printAvgDetectorTime { printAverageDetectorTime(e) } - if foundResults && *fail { + if e.HasFoundResults() && *fail { logger.V(2).Info("exiting with code 183 because results were found") os.Exit(183) } } +// logFatalFunc returns a log.Fatal style function. Calling the returned +// function will terminate the program without cleanup. +func logFatalFunc(logger logr.Logger) func(error, string, ...any) { + return func(err error, message string, keyAndVals ...any) { + logger.Error(err, message, keyAndVals...) + if err != nil { + os.Exit(1) + return + } + os.Exit(0) + } +} + func commaSeparatedToSlice(s []string) []string { var result []string for _, items := range s { @@ -537,26 +551,8 @@ func commaSeparatedToSlice(s []string) []string { func printAverageDetectorTime(e *engine.Engine) { fmt.Fprintln(os.Stderr, "Average detector time is the measurement of average time spent on each detector when results are returned.") - for detectorName, durations := range e.DetectorAvgTime() { - var total time.Duration - for _, d := range durations { - total += d - } - avgDuration := total / time.Duration(len(durations)) - fmt.Fprintf(os.Stderr, "%s: %s\n", detectorName, avgDuration) - } -} - -// logFatalFunc returns a log.Fatal style function. Calling the returned -// function will terminate the program without cleanup. -func logFatalFunc(logger logr.Logger) func(error, string, ...any) { - return func(err error, message string, keyAndVals ...any) { - logger.Error(err, message, keyAndVals...) 
- if err != nil { - os.Exit(1) - return - } - os.Exit(0) + for detectorName, duration := range e.GetDetectorsMetrics() { + fmt.Fprintf(os.Stderr, "%s: %s\n", detectorName, duration) } } diff --git a/pkg/common/utils.go b/pkg/common/utils.go index 7324b42c875b..314c2351f30e 100644 --- a/pkg/common/utils.go +++ b/pkg/common/utils.go @@ -3,7 +3,9 @@ package common import ( "bufio" "bytes" + "crypto/rand" "io" + "math/big" "strings" ) @@ -49,3 +51,16 @@ func ResponseContainsSubstring(reader io.ReadCloser, target string) (bool, error } return false, nil } + +var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + +// RandomID returns a random string of the given length. +func RandomID(length int) string { + b := make([]rune, length) + for i := range b { + randInt, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters)))) + b[i] = letters[randInt.Int64()] + } + + return string(b) +} diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 5119b4dc1557..b4fda8918a02 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -11,6 +11,7 @@ import ( "time" ahocorasick "github.com/BobuSumisu/aho-corasick" + lru "github.com/hashicorp/golang-lru" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" @@ -19,36 +20,76 @@ import ( "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/decoders" "github.com/trufflesecurity/trufflehog/v3/pkg/detectors" + "github.com/trufflesecurity/trufflehog/v3/pkg/output" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" ) -type Engine struct { - concurrency int - chunks chan *sources.Chunk - results chan detectors.ResultWithMetadata - decoders []decoders.Decoder - detectors map[bool][]detectors.Detector - chunksScanned uint64 - bytesScanned uint64 +// 
Metrics for the scan engine for external consumption. +type Metrics struct { + BytesScanned uint64 + ChunksScanned uint64 + VerifiedSecretsFound uint64 + UnverifiedSecretsFound uint64 + AvgDetectorTime map[string]time.Duration +} + +// runtimeMetrics for the scan engine for internal use by the engine. +type runtimeMetrics struct { + mu sync.RWMutex + Metrics detectorAvgTime sync.Map - sourcesWg *errgroup.Group - workersWg sync.WaitGroup +} + +// Printer is used to format found results and output them to the user. Ex JSON, plain text, etc. +type Printer interface { + Print(ctx context.Context, r *detectors.ResultWithMetadata) error +} + +type Engine struct { + concurrency uint8 + chunks chan *sources.Chunk + results chan detectors.ResultWithMetadata + decoders []decoders.Decoder + detectors map[bool][]detectors.Detector + sourcesWg *errgroup.Group + workersWg sync.WaitGroup // filterUnverified is used to reduce the number of unverified results. // If there are multiple unverified results for the same chunk for the same detector, // only the first one will be kept. - filterUnverified bool + filterUnverified bool + onlyVerified bool + printAvgDetectorTime bool // prefilter is a ahocorasick struct used for doing efficient string // matching given a set of words (keywords from the rules in the config) prefilter ahocorasick.Trie + + detectableChunksChan chan detectableChunk + wgDetectorWorkers sync.WaitGroup + WgNotifier sync.WaitGroup + + // Runtime metrics. + metrics runtimeMetrics + + // numFoundResults is used to keep track of the number of results found. + numFoundResults uint32 + + // printer provides a method for formatting and outputting search results. + // The specific implementation (e.g., JSON, plain text) + // should be set during initialization based on user preference or program requirements. 
+ printer Printer + + // dedupeCache is used to deduplicate results by comparing the + // detector type, raw result, and source metadata + dedupeCache *lru.Cache } type EngineOption func(*Engine) -func WithConcurrency(concurrency int) EngineOption { +func WithConcurrency(concurrency uint8) EngineOption { return func(e *Engine) { e.concurrency = concurrency } @@ -83,6 +124,26 @@ func WithFilterUnverified(filter bool) EngineOption { } } +// WithOnlyVerified sets the onlyVerified flag on the engine. If set to true, +// the engine will only print verified results. +func WithOnlyVerified(onlyVerified bool) EngineOption { + return func(e *Engine) { + e.onlyVerified = onlyVerified + } +} + +// WithPrintAvgDetectorTime sets the printAvgDetectorTime flag on the engine. If set to +// true, the engine will print the average time taken by each detector. +// This option allows us to measure the time taken for each detector ONLY if +// the engine is configured to print the results. +// Calculating the average time taken by each detector is an expensive operation +// and should be avoided unless specified by the user. +func WithPrintAvgDetectorTime(printAvgDetectorTime bool) EngineOption { + return func(e *Engine) { + e.printAvgDetectorTime = printAvgDetectorTime + } +} + // WithFilterDetectors applies a filter to the configured list of detectors. If // the filterFunc returns true, the detector will be included for scanning. // This option applies to the existing list of detectors configured, so the @@ -98,6 +159,13 @@ func WithFilterDetectors(filterFunc func(detectors.Detector) bool) EngineOption } } +// WithPrinter sets the Printer on the engine. 
+func WithPrinter(printer Printer) EngineOption { + return func(e *Engine) { + e.printer = printer + } +} + func filterDetectors(filterFunc func(detectors.Detector) bool, input []detectors.Detector) []detectors.Detector { var output []detectors.Detector for _, detector := range input { @@ -108,12 +176,96 @@ func filterDetectors(filterFunc func(detectors.Detector) bool, input []detectors return output } -func Start(ctx context.Context, options ...EngineOption) *Engine { +func (e *Engine) setFoundResults() { + atomic.StoreUint32(&e.numFoundResults, 1) +} + +// HasFoundResults returns true if any results are found. +func (e *Engine) HasFoundResults() bool { + return atomic.LoadUint32(&e.numFoundResults) > 0 +} + +// GetMetrics returns a copy of Metrics. +// It's safe for concurrent use, and the caller can't modify the original data. +func (e *Engine) GetMetrics() Metrics { + e.metrics.mu.RLock() + defer e.metrics.mu.RUnlock() + + result := e.metrics.Metrics + result.AvgDetectorTime = make(map[string]time.Duration, len(e.metrics.AvgDetectorTime)) + + for detectorName, durations := range e.DetectorAvgTime() { + var total time.Duration + for _, d := range durations { + total += d + } + avgDuration := total / time.Duration(len(durations)) + result.AvgDetectorTime[detectorName] = avgDuration + } + + return result +} + +// GetDetectorsMetrics returns a copy of the average time taken by each detector. +func (e *Engine) GetDetectorsMetrics() map[string]time.Duration { + e.metrics.mu.RLock() + defer e.metrics.mu.RUnlock() + + result := make(map[string]time.Duration, len(DefaultDetectors())) + for detectorName, durations := range e.DetectorAvgTime() { + var total time.Duration + for _, d := range durations { + total += d + } + avgDuration := total / time.Duration(len(durations)) + result[detectorName] = avgDuration + } + + return result +} + +// DetectorAvgTime returns the average time taken by each detector. 
+func (e *Engine) DetectorAvgTime() map[string][]time.Duration { + logger := context.Background().Logger() + avgTime := map[string][]time.Duration{} + e.metrics.detectorAvgTime.Range(func(k, v interface{}) bool { + key, ok := k.(string) + if !ok { + logger.Info("expected detectorAvgTime key to be a string") + return true + } + + value, ok := v.([]time.Duration) + if !ok { + logger.Info("expected detectorAvgTime value to be []time.Duration") + return true + } + avgTime[key] = value + return true + }) + return avgTime +} + +// Start the engine with options. +func Start(ctx context.Context, options ...EngineOption) (*Engine, error) { + const ( + defaultChannelBuffer = 1 + // TODO (ahrav): Determine the optimal cache size. + cacheSize = 512 // number of entries in the LRU cache + ) + + cache, err := lru.New(cacheSize) + if err != nil { + return nil, fmt.Errorf("failed to initialize LRU cache: %w", err) + } + e := &Engine{ - chunks: make(chan *sources.Chunk), - results: make(chan detectors.ResultWithMetadata), - detectorAvgTime: sync.Map{}, - sourcesWg: &errgroup.Group{}, + chunks: make(chan *sources.Chunk, defaultChannelBuffer), + detectableChunksChan: make(chan detectableChunk, defaultChannelBuffer), + results: make(chan detectors.ResultWithMetadata, defaultChannelBuffer), + sourcesWg: &errgroup.Group{}, + dedupeCache: cache, + printer: new(output.PlainPrinter), // default printer } for _, option := range options { @@ -124,12 +276,12 @@ func Start(ctx context.Context, options ...EngineOption) *Engine { if e.concurrency == 0 { numCPU := runtime.NumCPU() ctx.Logger().Info("No concurrency specified, defaulting to max", "cpu", numCPU) - e.concurrency = numCPU + e.concurrency = uint8(numCPU) } ctx.Logger().V(3).Info("engine started", "workers", e.concurrency) // Limit number of concurrent goroutines dedicated to chunking a source. 
- e.sourcesWg.SetLimit(e.concurrency) + e.sourcesWg.SetLimit(int(e.concurrency)) if len(e.decoders) == 0 { e.decoders = decoders.DefaultDecoders() @@ -178,41 +330,68 @@ func Start(ctx context.Context, options ...EngineOption) *Engine { } } - // Start the workers. - for i := 0; i < e.concurrency; i++ { + ctx.Logger().V(2).Info("starting scanner workers", "count", e.concurrency) + // Run the Secret scanner workers and Notifier pipelines. + for worker := uint64(0); worker < uint64(e.concurrency); worker++ { e.workersWg.Add(1) go func() { + ctx := context.WithValue(ctx, "secret_worker_id", common.RandomID(5)) defer common.Recover(ctx) defer e.workersWg.Done() e.detectorWorker(ctx) }() } - return e + const detectorWorkerMultiplier = 50 + ctx.Logger().V(2).Info("starting detector workers", "count", e.concurrency*detectorWorkerMultiplier) + for worker := uint64(0); worker < uint64(e.concurrency*detectorWorkerMultiplier); worker++ { + e.wgDetectorWorkers.Add(1) + go func() { + ctx := context.WithValue(ctx, "detector_worker_id", common.RandomID(5)) + defer common.Recover(ctx) + defer e.wgDetectorWorkers.Done() + e.detectChunks(ctx) + }() + } + + // We want 1/4th of the notifier workers as the number of scanner workers. + const notifierWorkerRatio = 4 + maxNotifierWorkers := 1 + if numWorkers := e.concurrency / notifierWorkerRatio; numWorkers > 0 { + maxNotifierWorkers = int(numWorkers) + } + ctx.Logger().V(2).Info("starting notifier workers", "count", maxNotifierWorkers) + for worker := 0; worker < maxNotifierWorkers; worker++ { + e.WgNotifier.Add(1) + go func() { + ctx := context.WithValue(ctx, "notifier_worker_id", common.RandomID(5)) + defer common.Recover(ctx) + defer e.WgNotifier.Done() + e.notifyResults(ctx) + }() + } + + return e, nil } // Finish waits for running sources to complete and workers to finish scanning // chunks before closing their respective channels. Once Finish is called, no // more sources may be scanned by the engine. 
-func (e *Engine) Finish(ctx context.Context, logFunc func(error, string, ...any)) { +func (e *Engine) Finish(ctx context.Context) error { defer common.RecoverWithExit(ctx) - // wait for the sources to finish putting chunks onto the chunks channel - sourceErr := e.sourcesWg.Wait() - if sourceErr != nil { - logFunc(sourceErr, "error occurred while collecting chunks") - } + // Wait for the sources to finish putting chunks onto the chunks channel. + err := e.sourcesWg.Wait() + + close(e.chunks) // Source workers are done. - close(e.chunks) - // wait for the workers to finish processing all of the chunks and putting - // results onto the results channel - e.workersWg.Wait() + e.workersWg.Wait() // Wait for the workers to finish scanning chunks. + close(e.detectableChunksChan) + e.wgDetectorWorkers.Wait() // Wait for the detector workers to finish detecting chunks. - // TODO: re-evaluate whether this is needed and investigate why if so - // - // not entirely sure why results don't get processed without this pause - // since we've put all results on the channel at this point. - time.Sleep(time.Second) - close(e.results) + close(e.results) // Detector workers are done, close the results channel and call it a day. + e.WgNotifier.Wait() // Wait for the notifier workers to finish notifying results. 
+ + return err } func (e *Engine) ChunksChan() chan *sources.Chunk { @@ -223,57 +402,21 @@ func (e *Engine) ResultsChan() chan detectors.ResultWithMetadata { return e.results } -func (e *Engine) ChunksScanned() uint64 { - return e.chunksScanned -} - -func (e *Engine) BytesScanned() uint64 { - return e.bytesScanned -} - -func (e *Engine) dedupeAndSend(chunkResults []detectors.ResultWithMetadata) { - dedupeMap := make(map[string]struct{}) - for _, result := range chunkResults { - // dedupe by comparing the detector type, raw result, and source metadata - // NOTE: in order for the PLAIN decoder to maintain precedence, make sure UTF8 is the first decoder in the - // default decoders list - key := fmt.Sprintf("%s%s%s%+v", result.DetectorType.String(), result.Raw, result.RawV2, result.SourceMetadata) - if _, ok := dedupeMap[key]; ok { - continue - } - dedupeMap[key] = struct{}{} - e.results <- result - } - -} - -func (e *Engine) DetectorAvgTime() map[string][]time.Duration { - logger := context.Background().Logger() - avgTime := map[string][]time.Duration{} - e.detectorAvgTime.Range(func(k, v interface{}) bool { - key, ok := k.(string) - if !ok { - logger.Info("expected DetectorAvgTime key to be a string") - return true - } - - value, ok := v.([]time.Duration) - if !ok { - logger.Info("expected DetectorAvgTime value to be []time.Duration") - return true - } - avgTime[key] = value - return true - }) - return avgTime +// detectableChunk is a decoded chunk that is ready to be scanned by its detector. 
+type detectableChunk struct { + detector detectors.Detector + chunk sources.Chunk + decoder detectorspb.DecoderType + wgDoneFn func() } func (e *Engine) detectorWorker(ctx context.Context) { + var wgDetect sync.WaitGroup + for originalChunk := range e.chunks { for chunk := range sources.Chunker(originalChunk) { - var chunkResults []detectors.ResultWithMetadata matchedKeywords := make(map[string]struct{}) - atomic.AddUint64(&e.bytesScanned, uint64(len(chunk.Data))) + atomic.AddUint64(&e.metrics.BytesScanned, uint64(len(chunk.Data))) for _, decoder := range e.decoders { var decoderType detectorspb.DecoderType switch decoder.(type) { @@ -313,65 +456,110 @@ func (e *Engine) detectorWorker(ctx context.Context) { continue } - start := time.Now() - - results, err := func() ([]detectors.Result, error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*10) - defer cancel() - defer common.Recover(ctx) - return detector.FromData(ctx, verify, decoded.Data) - }() - if err != nil { - ctx.Logger().Error(err, "could not scan chunk", - "source_type", decoded.SourceType.String(), - "metadata", decoded.SourceMetadata, - ) - continue - } - - if e.filterUnverified { - results = detectors.CleanResults(results) - } - for _, result := range results { - resultChunk := chunk - ignoreLinePresent := false - if SupportsLineNumbers(chunk.SourceType) { - copyChunk := *chunk - copyMetaDataClone := proto.Clone(chunk.SourceMetadata) - if copyMetaData, ok := copyMetaDataClone.(*source_metadatapb.MetaData); ok { - copyChunk.SourceMetadata = copyMetaData - } - fragStart, mdLine := FragmentFirstLine(©Chunk) - ignoreLinePresent = SetResultLineNumber(©Chunk, &result, fragStart, mdLine) - resultChunk = ©Chunk - } - if ignoreLinePresent { - continue - } - result.DecoderType = decoderType - chunkResults = append(chunkResults, detectors.CopyMetadata(resultChunk, result)) - - } - if len(results) > 0 { - elapsed := time.Since(start) - detectorName := results[0].DetectorType.String() - avgTimeI, ok := 
e.detectorAvgTime.Load(detectorName) - var avgTime []time.Duration - if ok { - avgTime, ok = avgTimeI.([]time.Duration) - if !ok { - continue - } - } - avgTime = append(avgTime, elapsed) - e.detectorAvgTime.Store(detectorName, avgTime) + decoded.Verify = verify + wgDetect.Add(1) + e.detectableChunksChan <- detectableChunk{ + chunk: *decoded, + detector: detector, + decoder: decoderType, + wgDoneFn: wgDetect.Done, } } } } - e.dedupeAndSend(chunkResults) } - atomic.AddUint64(&e.chunksScanned, 1) + atomic.AddUint64(&e.metrics.ChunksScanned, 1) + } + wgDetect.Wait() +} + +func (e *Engine) detectChunks(ctx context.Context) { + for data := range e.detectableChunksChan { + e.detectChunk(ctx, data) + } +} + +func (e *Engine) detectChunk(ctx context.Context, data detectableChunk) { + var start time.Time + if e.printAvgDetectorTime { + start = time.Now() + } + ctx, cancel := context.WithTimeout(ctx, time.Second*10) + defer common.Recover(ctx) + defer cancel() + + results, err := data.detector.FromData(ctx, data.chunk.Verify, data.chunk.Data) + if err != nil { + ctx.Logger().Error(err, "error scanning chunk") + } + if e.printAvgDetectorTime && len(results) > 0 { + elapsed := time.Since(start) + detectorName := results[0].DetectorType.String() + avgTimeI, ok := e.metrics.detectorAvgTime.Load(detectorName) + var avgTime []time.Duration + if ok { + avgTime, ok = avgTimeI.([]time.Duration) + if !ok { + return + } + } + avgTime = append(avgTime, elapsed) + e.metrics.detectorAvgTime.Store(detectorName, avgTime) + } + + if e.filterUnverified { + results = detectors.CleanResults(results) + } + + for _, res := range results { + e.processResult(data, res) + } + data.wgDoneFn() +} + +func (e *Engine) processResult(data detectableChunk, res detectors.Result) { + ignoreLinePresent := false + if SupportsLineNumbers(data.chunk.SourceType) { + copyChunk := data.chunk + copyMetaDataClone := proto.Clone(data.chunk.SourceMetadata) + if copyMetaData, ok := 
copyMetaDataClone.(*source_metadatapb.MetaData); ok { + copyChunk.SourceMetadata = copyMetaData + } + fragStart, mdLine := FragmentFirstLine(©Chunk) + ignoreLinePresent = SetResultLineNumber(©Chunk, &res, fragStart, mdLine) + data.chunk = copyChunk + } + if ignoreLinePresent { + return + } + + secret := detectors.CopyMetadata(&data.chunk, res) + secret.DecoderType = data.decoder + e.results <- secret +} + +func (e *Engine) notifyResults(ctx context.Context) { + for r := range e.ResultsChan() { + if e.onlyVerified && !r.Verified { + continue + } + atomic.AddUint32(&e.numFoundResults, 1) + + key := fmt.Sprintf("%s%s%s%+v", r.DetectorType.String(), r.Raw, r.RawV2, r.SourceMetadata) + if _, ok := e.dedupeCache.Get(key); ok { + continue + } + e.dedupeCache.Add(key, struct{}{}) + + if r.Verified { + atomic.AddUint64(&e.metrics.VerifiedSecretsFound, 1) + } else { + atomic.AddUint64(&e.metrics.UnverifiedSecretsFound, 1) + } + + if err := e.printer.Print(ctx, &r); err != nil { + ctx.Logger().Error(err, "error printing result") + } } } diff --git a/pkg/engine/gcs_test.go b/pkg/engine/gcs_test.go index d1b225f91a67..b62783408a9d 100644 --- a/pkg/engine/gcs_test.go +++ b/pkg/engine/gcs_test.go @@ -4,6 +4,8 @@ import ( "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/decoders" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" @@ -59,11 +61,13 @@ func TestScanGCS(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) defer cancel() - e := Start(ctx, + e, err := Start(ctx, WithConcurrency(1), WithDecoders(decoders.DefaultDecoders()...), WithDetectors(false, DefaultDetectors()...), ) + assert.Nil(t, err) + go func() { resultCount := 0 for range e.ResultsChan() { @@ -71,15 +75,12 @@ func TestScanGCS(t *testing.T) { } }() - err := e.ScanGCS(ctx, test.gcsConfig) + err = e.ScanGCS(ctx, test.gcsConfig) if err != nil && !test.wantErr && 
!strings.Contains(err.Error(), "googleapi: Error 400: Bad Request") { t.Errorf("ScanGCS() got: %v, want: %v", err, nil) return } - logFatalFunc := func(_ error, _ string, _ ...any) { - t.Fatalf("error logging function should not have been called") - } - e.Finish(ctx, logFatalFunc) + assert.Nil(t, e.Finish(ctx)) if err == nil && test.wantErr { t.Errorf("ScanGCS() got: %v, want: %v", err, "error") diff --git a/pkg/engine/git_test.go b/pkg/engine/git_test.go index 48ed385f18f7..821df94a27ac 100644 --- a/pkg/engine/git_test.go +++ b/pkg/engine/git_test.go @@ -4,6 +4,8 @@ import ( "os" "testing" + "github.com/stretchr/testify/assert" + "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/decoders" @@ -19,6 +21,13 @@ type expResult struct { Verified bool } +type discardPrinter struct{} + +func (p *discardPrinter) Print(context.Context, *detectors.ResultWithMetadata) error { + // This method intentionally does nothing. + return nil +} + func TestGitEngine(t *testing.T) { ctx := context.Background() repoUrl := "https://github.com/dustin-decker/secretsandstuff.git" @@ -56,14 +65,13 @@ func TestGitEngine(t *testing.T) { }, } { t.Run(tName, func(t *testing.T) { - e := Start(ctx, + e, err := Start(ctx, WithConcurrency(1), WithDecoders(decoders.DefaultDecoders()...), WithDetectors(true, DefaultDetectors()...), + WithPrinter(new(discardPrinter)), ) - // Make the channels buffered so Finish returns. - e.chunks = make(chan *sources.Chunk, 10) - e.results = make(chan detectors.ResultWithMetadata, 10) + assert.Nil(t, err) cfg := sources.GitConfig{ RepoPath: path, @@ -76,12 +84,8 @@ func TestGitEngine(t *testing.T) { return } - logFatalFunc := func(_ error, _ string, _ ...any) { - t.Fatalf("error logging function should not have been called") - } // Wait for all the chunks to be processed. 
- e.Finish(ctx, logFatalFunc) - resultCount := 0 + assert.Nil(t, e.Finish(ctx)) for result := range e.ResultsChan() { switch meta := result.SourceMetadata.GetData().(type) { case *source_metadatapb.MetaData_Git: @@ -95,12 +99,10 @@ func TestGitEngine(t *testing.T) { t.Errorf("%s: unexpected verification. Got: %v, Expected: %v", tName, result.Verified, tTest.expected[meta.Git.Commit].Verified) } } - resultCount++ } - if resultCount != len(tTest.expected) { - t.Errorf("%s: unexpected number of results. Got: %d, Expected: %d", tName, resultCount, len(tTest.expected)) - } + metrics := e.GetMetrics() + assert.Equal(t, len(tTest.expected), int(metrics.VerifiedSecretsFound)+int(metrics.UnverifiedSecretsFound)) }) } } @@ -117,11 +119,13 @@ func BenchmarkGitEngine(b *testing.B) { ctx, cancel := context.WithCancel(ctx) defer cancel() - e := Start(ctx, + e, err := Start(ctx, WithConcurrency(1), WithDecoders(decoders.DefaultDecoders()...), WithDetectors(false, DefaultDetectors()...), ) + assert.Nil(b, err) + go func() { resultCount := 0 for range e.ResultsChan() { @@ -140,8 +144,5 @@ func BenchmarkGitEngine(b *testing.B) { return } } - logFatalFunc := func(_ error, _ string, _ ...any) { - b.Fatalf("error logging function should not have been called") - } - e.Finish(ctx, logFatalFunc) + assert.Nil(b, e.Finish(ctx)) } diff --git a/pkg/output/github_actions.go b/pkg/output/github_actions.go index 11d0e19b810b..077a4da7f8dc 100644 --- a/pkg/output/github_actions.go +++ b/pkg/output/github_actions.go @@ -5,13 +5,17 @@ import ( "encoding/hex" "fmt" + "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/detectors" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb" ) var dedupeCache = make(map[string]struct{}) -func PrintGitHubActionsOutput(r *detectors.ResultWithMetadata) error { +// GitHubActionsPrinter is a printer that prints results in GitHub Actions format. 
+type GitHubActionsPrinter struct{} + +func (p *GitHubActionsPrinter) Print(_ context.Context, r *detectors.ResultWithMetadata) error { out := gitHubActionsOutputFormat{ DetectorType: r.Result.DetectorType.String(), DecoderType: r.Result.DecoderType.String(), diff --git a/pkg/output/json.go b/pkg/output/json.go index a37c5830e0e4..af15be63bccc 100644 --- a/pkg/output/json.go +++ b/pkg/output/json.go @@ -4,13 +4,17 @@ import ( "encoding/json" "fmt" + "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/detectors" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" ) -func PrintJSON(r *detectors.ResultWithMetadata) error { +// JSONPrinter is a printer that prints results in JSON format. +type JSONPrinter struct{} + +func (p *JSONPrinter) Print(_ context.Context, r *detectors.ResultWithMetadata) error { v := &struct { // SourceMetadata contains source-specific contextual information. SourceMetadata *source_metadatapb.MetaData diff --git a/pkg/output/legacy_json.go b/pkg/output/legacy_json.go index 53f5865b3761..bbeea627a381 100644 --- a/pkg/output/legacy_json.go +++ b/pkg/output/legacy_json.go @@ -19,7 +19,10 @@ import ( "github.com/trufflesecurity/trufflehog/v3/pkg/sources/git" ) -func PrintLegacyJSON(ctx context.Context, r *detectors.ResultWithMetadata) error { +// LegacyJSONPrinter is a printer that prints results in legacy JSON format for backwards compatibility. 
+type LegacyJSONPrinter struct{} + +func (p *LegacyJSONPrinter) Print(ctx context.Context, r *detectors.ResultWithMetadata) error { var repo string switch r.SourceType { case sourcespb.SourceType_SOURCE_TYPE_GIT: @@ -41,7 +44,7 @@ func PrintLegacyJSON(ctx context.Context, r *detectors.ResultWithMetadata) error defer os.RemoveAll(repoPath) } - legacy, err := ConvertToLegacyJSON(r, repoPath) + legacy, err := convertToLegacyJSON(r, repoPath) if err != nil { return fmt.Errorf("could not convert to legacy JSON: %w", err) } @@ -53,7 +56,7 @@ func PrintLegacyJSON(ctx context.Context, r *detectors.ResultWithMetadata) error return nil } -func ConvertToLegacyJSON(r *detectors.ResultWithMetadata, repoPath string) (*LegacyJSONOutput, error) { +func convertToLegacyJSON(r *detectors.ResultWithMetadata, repoPath string) (*LegacyJSONOutput, error) { var source LegacyJSONCompatibleSource switch r.SourceType { case sourcespb.SourceType_SOURCE_TYPE_GIT: diff --git a/pkg/output/plain.go b/pkg/output/plain.go index ac2ef1c60729..ec70f8ae0b8f 100644 --- a/pkg/output/plain.go +++ b/pkg/output/plain.go @@ -10,6 +10,7 @@ import ( "golang.org/x/text/cases" "golang.org/x/text/language" + "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/detectors" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" ) @@ -20,7 +21,10 @@ var ( whitePrinter = color.New(color.FgWhite) ) -func PrintPlainOutput(r *detectors.ResultWithMetadata) error { +// PlainPrinter is a printer that prints results in plain text format. +type PlainPrinter struct{} + +func (p *PlainPrinter) Print(_ context.Context, r *detectors.ResultWithMetadata) error { out := outputFormat{ DetectorType: r.Result.DetectorType.String(), DecoderType: r.Result.DecoderType.String(),