From 1545d65f0a7013808f3ff7e208186da928d3f637 Mon Sep 17 00:00:00 2001 From: Ahrav Dutta Date: Mon, 31 Jul 2023 13:34:11 -0700 Subject: [PATCH] Fix race. --- hack/snifftest/main.go | 2 +- pkg/engine/engine.go | 4 ---- pkg/engine/git.go | 2 +- pkg/engine/github.go | 2 +- pkg/sources/git/git.go | 19 ++++++++----------- pkg/sources/git/git_test.go | 2 +- pkg/sources/git/scan_options.go | 1 + pkg/sources/github/github.go | 20 +++++--------------- pkg/sources/github/repo.go | 2 +- pkg/sources/gitlab/gitlab.go | 4 ++-- pkg/sources/gitlab/gitlab_test.go | 2 +- 11 files changed, 22 insertions(+), 38 deletions(-) diff --git a/hack/snifftest/main.go b/hack/snifftest/main.go index 6d8f704b2db6..7cc5e5b0f4c2 100644 --- a/hack/snifftest/main.go +++ b/hack/snifftest/main.go @@ -204,7 +204,7 @@ func main() { }) logger.Info("scanning repo", "repo", r) - err = s.ScanRepo(ctx, repo, path, git.NewScanOptions(), chunksChan) + err = s.ScanRepo(ctx, repo, path, *git.NewScanOptions(), chunksChan) if err != nil { logFatal(err, "error scanning repo") } diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index b4fda8918a02..68c912d17722 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -176,10 +176,6 @@ func filterDetectors(filterFunc func(detectors.Detector) bool, input []detectors return output } -func (e *Engine) setFoundResults() { - atomic.StoreUint32(&e.numFoundResults, 1) -} - // HasFoundResults returns true if any results are found. func (e *Engine) HasFoundResults() bool { return atomic.LoadUint32(&e.numFoundResults) > 0 diff --git a/pkg/engine/git.go b/pkg/engine/git.go index 370b4642a19b..229bd020ac9b 100644 --- a/pkg/engine/git.go +++ b/pkg/engine/git.go @@ -44,7 +44,7 @@ func (e *Engine) ScanGit(ctx context.Context, c sources.GitConfig) error { if c.ExcludeGlobs != nil { opts = append(opts, git.ScanOptionExcludeGlobs(c.ExcludeGlobs)) } - scanOptions := git.NewScanOptions(opts...) + scanOptions := *git.NewScanOptions(opts...) gitSource := git.NewGit(sourcespb.SourceType_SOURCE_TYPE_GIT, 0, 0, "trufflehog - git", true, runtime.NumCPU(), func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData { diff --git a/pkg/engine/github.go b/pkg/engine/github.go index 79d34d76806f..5797d8913130 100644 --- a/pkg/engine/github.go +++ b/pkg/engine/github.go @@ -57,7 +57,7 @@ func (e *Engine) ScanGitHub(ctx context.Context, c sources.GithubConfig) error { git.ScanOptionFilter(c.Filter), git.ScanOptionLogOptions(logOptions), } - scanOptions := git.NewScanOptions(opts...) + scanOptions := *git.NewScanOptions(opts...) source.WithScanOptions(scanOptions) e.sourcesWg.Go(func() error { diff --git a/pkg/sources/git/git.go b/pkg/sources/git/git.go index 80417b50a27e..8724d339a633 100644 --- a/pkg/sources/git/git.go +++ b/pkg/sources/git/git.go @@ -153,7 +153,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err if err != nil { return err } - return s.git.ScanRepo(ctx, repo, path, NewScanOptions(), chunksChan) + return s.git.ScanRepo(ctx, repo, path, *NewScanOptions(), chunksChan) }(repoURI) if err != nil { ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err) @@ -172,7 +172,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err if err != nil { return err } - return s.git.ScanRepo(ctx, repo, path, NewScanOptions(), chunksChan) + return s.git.ScanRepo(ctx, repo, path, *NewScanOptions(), chunksChan) }(repoURI) if err != nil { ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err) @@ -191,7 +191,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err if err != nil { return err } - return s.git.ScanRepo(ctx, repo, path, NewScanOptions(), chunksChan) + return s.git.ScanRepo(ctx, repo, path, *NewScanOptions(), chunksChan) }(repoURI) if err != nil { ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err) @@ -221,7 +221,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err defer os.RemoveAll(repoPath) } - return s.git.ScanRepo(ctx, repo, repoPath, NewScanOptions(), chunksChan) + return s.git.ScanRepo(ctx, repo, repoPath, *NewScanOptions(), chunksChan) }(gitDir) if err != nil { ctx.Logger().Info("error scanning repository", "repo", gitDir, "error", err) @@ -349,7 +349,7 @@ func (s *Git) CommitsScanned() uint64 { return atomic.LoadUint64(&s.metrics.commitsScanned) } -func (s *Git) ScanCommits(ctx context.Context, repo *git.Repository, path string, scanOptions *ScanOptions, chunksChan chan *sources.Chunk) error { +func (s *Git) ScanCommits(ctx context.Context, repo *git.Repository, path string, scanOptions ScanOptions, chunksChan chan *sources.Chunk) error { if err := GitCmdCheck(); err != nil { return err } @@ -490,7 +490,7 @@ func (s *Git) gitChunk(ctx context.Context, diff gitparse.Diff, fileName, email, } // ScanStaged chunks staged changes. -func (s *Git) ScanStaged(ctx context.Context, repo *git.Repository, path string, scanOptions *ScanOptions, chunksChan chan *sources.Chunk) error { +func (s *Git) ScanStaged(ctx context.Context, repo *git.Repository, path string, scanOptions ScanOptions, chunksChan chan *sources.Chunk) error { // Get the URL metadata for reporting (may be empty). urlMetadata := getSafeRemoteURL(repo, "origin") @@ -570,10 +570,7 @@ func (s *Git) ScanStaged(ctx context.Context, repo *git.Repository, path string, return nil } -func (s *Git) ScanRepo(ctx context.Context, repo *git.Repository, repoPath string, scanOptions *ScanOptions, chunksChan chan *sources.Chunk) error { - if scanOptions == nil { - scanOptions = NewScanOptions() - } +func (s *Git) ScanRepo(ctx context.Context, repo *git.Repository, repoPath string, scanOptions ScanOptions, chunksChan chan *sources.Chunk) error { if err := normalizeConfig(scanOptions, repo); err != nil { return err } @@ -599,7 +596,7 @@ func (s *Git) ScanRepo(ctx context.Context, repo *git.Repository, repoPath strin return nil } -func normalizeConfig(scanOptions *ScanOptions, repo *git.Repository) (err error) { +func normalizeConfig(scanOptions ScanOptions, repo *git.Repository) (err error) { var baseCommit *object.Commit if len(scanOptions.BaseHash) > 0 { baseHash := plumbing.NewHash(scanOptions.BaseHash) diff --git a/pkg/sources/git/git_test.go b/pkg/sources/git/git_test.go index d8d54adfd603..172546843dbc 100644 --- a/pkg/sources/git/git_test.go +++ b/pkg/sources/git/git_test.go @@ -285,7 +285,7 @@ func TestSource_Chunks_Integration(t *testing.T) { if err != nil { panic(err) } - err = s.git.ScanRepo(ctx, repo, repoPath, &tt.scanOptions, chunksCh) + err = s.git.ScanRepo(ctx, repo, repoPath, tt.scanOptions, chunksCh) if err != nil { panic(err) } diff --git a/pkg/sources/git/scan_options.go b/pkg/sources/git/scan_options.go index 09572a6fb408..bc4a7be79c1a 100644 --- a/pkg/sources/git/scan_options.go +++ b/pkg/sources/git/scan_options.go @@ -2,6 +2,7 @@ package git import ( "github.com/go-git/go-git/v5" + "github.com/trufflesecurity/trufflehog/v3/pkg/common" ) diff --git a/pkg/sources/github/github.go b/pkg/sources/github/github.go index c4b0df8517ff..9d86dcce4480 100644 --- a/pkg/sources/github/github.go +++ b/pkg/sources/github/github.go @@ -62,8 +62,7 @@ type Source struct { totalRepoSize int // total size in bytes of all repos git *git.Git - scanOptMu sync.Mutex // protects the scanOptions - scanOptions *git.ScanOptions + scanOptions git.ScanOptions httpClient *http.Client log logr.Logger @@ -83,17 +82,10 @@ type Source struct { sources.CommonSourceUnitUnmarshaller } -func (s *Source) WithScanOptions(scanOptions *git.ScanOptions) { +func (s *Source) WithScanOptions(scanOptions git.ScanOptions) { s.scanOptions = scanOptions } -func (s *Source) setScanOptions(base, head string) { - s.scanOptMu.Lock() - defer s.scanOptMu.Unlock() - s.scanOptions.BaseHash = base - s.scanOptions.HeadHash = head -} - // Ensure the Source satisfies the interfaces at compile time var _ sources.Source = (*Source)(nil) var _ sources.SourceUnitUnmarshaller = (*Source)(nil) @@ -664,9 +656,6 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch scanErrs := sources.NewScanErrors() // Setup scan options if it wasn't provided. - if s.scanOptions == nil { - s.scanOptions = &git.ScanOptions{} - } for i, repoURL := range s.repos { i, repoURL := i, repoURL @@ -707,7 +696,8 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch return nil } - s.setScanOptions(s.conn.Base, s.conn.Head) + s.scanOptions.HeadHash = s.conn.Head + s.scanOptions.BaseHash = s.conn.Base repoSize := s.repoSizes.getRepo(repoURL) logger.V(2).Info(fmt.Sprintf("scanning repo %d/%d", i, len(s.repos)), "repo_size", repoSize) @@ -1186,7 +1176,7 @@ func (s *Source) chunkGistComments(ctx context.Context, gistUrl string, comments Timestamp: sanitizer.UTF8(comment.GetCreatedAt().String()), // Fetching this information requires making an additional API call. // We may want to include this in the future. - //Visibility: s.visibilityOf(ctx, repoPath), + // Visibility: s.visibilityOf(ctx, repoPath), }, }, }, diff --git a/pkg/sources/github/repo.go b/pkg/sources/github/repo.go index c0c7709d0055..502a897970cd 100644 --- a/pkg/sources/github/repo.go +++ b/pkg/sources/github/repo.go @@ -65,7 +65,7 @@ func (s *Source) cloneRepo( func (s *Source) getUserAndToken(ctx context.Context, repoURL string, installationClient *github.Client) error { // We never refresh user provided tokens, so if we already have them, we never need to try and fetch them again. s.userMu.Lock() - defer s.mu.Unlock() + defer s.userMu.Unlock() if s.githubUser == "" || s.githubToken == "" { var err error s.githubUser, s.githubToken, err = s.userAndToken(ctx, installationClient) diff --git a/pkg/sources/gitlab/gitlab.go b/pkg/sources/gitlab/gitlab.go index 9ce1057ff6ae..5581de19790a 100644 --- a/pkg/sources/gitlab/gitlab.go +++ b/pkg/sources/gitlab/gitlab.go @@ -42,7 +42,7 @@ type Source struct { repos []string ignoreRepos []string git *git.Git - scanOptions *git.ScanOptions + scanOptions git.ScanOptions resumeInfoSlice []string resumeInfoMutex sync.Mutex sources.Progress @@ -443,5 +443,5 @@ func (s *Source) setProgressCompleteWithRepo(index int, offset int, repoURL stri } func (s *Source) WithScanOptions(scanOptions *git.ScanOptions) { - s.scanOptions = scanOptions + s.scanOptions = *scanOptions } diff --git a/pkg/sources/gitlab/gitlab_test.go b/pkg/sources/gitlab/gitlab_test.go index 68088c52210e..523f00a9e466 100644 --- a/pkg/sources/gitlab/gitlab_test.go +++ b/pkg/sources/gitlab/gitlab_test.go @@ -285,7 +285,7 @@ func Test_scanRepos_SetProgressComplete(t *testing.T) { repos: tc.repos, } src.jobPool = &errgroup.Group{} - src.scanOptions = &git.ScanOptions{} + src.scanOptions = git.ScanOptions{} _ = src.scanRepos(context.Background(), nil) if !tc.wantErr {