Skip to content

Commit

Permalink
Fix race.
Browse files Browse the repository at this point in the history
  • Loading branch information
ahrav committed Jul 31, 2023
1 parent 661c6b4 commit 1545d65
Show file tree
Hide file tree
Showing 11 changed files with 22 additions and 38 deletions.
2 changes: 1 addition & 1 deletion hack/snifftest/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func main() {
})

logger.Info("scanning repo", "repo", r)
err = s.ScanRepo(ctx, repo, path, git.NewScanOptions(), chunksChan)
err = s.ScanRepo(ctx, repo, path, *git.NewScanOptions(), chunksChan)
if err != nil {
logFatal(err, "error scanning repo")
}
Expand Down
4 changes: 0 additions & 4 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,6 @@ func filterDetectors(filterFunc func(detectors.Detector) bool, input []detectors
return output
}

func (e *Engine) setFoundResults() {
atomic.StoreUint32(&e.numFoundResults, 1)
}

// HasFoundResults returns true if any results are found.
func (e *Engine) HasFoundResults() bool {
return atomic.LoadUint32(&e.numFoundResults) > 0
Expand Down
2 changes: 1 addition & 1 deletion pkg/engine/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (e *Engine) ScanGit(ctx context.Context, c sources.GitConfig) error {
if c.ExcludeGlobs != nil {
opts = append(opts, git.ScanOptionExcludeGlobs(c.ExcludeGlobs))
}
scanOptions := git.NewScanOptions(opts...)
scanOptions := *git.NewScanOptions(opts...)

gitSource := git.NewGit(sourcespb.SourceType_SOURCE_TYPE_GIT, 0, 0, "trufflehog - git", true, runtime.NumCPU(),
func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData {
Expand Down
2 changes: 1 addition & 1 deletion pkg/engine/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (e *Engine) ScanGitHub(ctx context.Context, c sources.GithubConfig) error {
git.ScanOptionFilter(c.Filter),
git.ScanOptionLogOptions(logOptions),
}
scanOptions := git.NewScanOptions(opts...)
scanOptions := *git.NewScanOptions(opts...)
source.WithScanOptions(scanOptions)

e.sourcesWg.Go(func() error {
Expand Down
19 changes: 8 additions & 11 deletions pkg/sources/git/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
if err != nil {
return err
}
return s.git.ScanRepo(ctx, repo, path, NewScanOptions(), chunksChan)
return s.git.ScanRepo(ctx, repo, path, *NewScanOptions(), chunksChan)
}(repoURI)
if err != nil {
ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err)
Expand All @@ -172,7 +172,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
if err != nil {
return err
}
return s.git.ScanRepo(ctx, repo, path, NewScanOptions(), chunksChan)
return s.git.ScanRepo(ctx, repo, path, *NewScanOptions(), chunksChan)
}(repoURI)
if err != nil {
ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err)
Expand All @@ -191,7 +191,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
if err != nil {
return err
}
return s.git.ScanRepo(ctx, repo, path, NewScanOptions(), chunksChan)
return s.git.ScanRepo(ctx, repo, path, *NewScanOptions(), chunksChan)
}(repoURI)
if err != nil {
ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err)
Expand Down Expand Up @@ -221,7 +221,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err
defer os.RemoveAll(repoPath)
}

return s.git.ScanRepo(ctx, repo, repoPath, NewScanOptions(), chunksChan)
return s.git.ScanRepo(ctx, repo, repoPath, *NewScanOptions(), chunksChan)
}(gitDir)
if err != nil {
ctx.Logger().Info("error scanning repository", "repo", gitDir, "error", err)
Expand Down Expand Up @@ -349,7 +349,7 @@ func (s *Git) CommitsScanned() uint64 {
return atomic.LoadUint64(&s.metrics.commitsScanned)
}

func (s *Git) ScanCommits(ctx context.Context, repo *git.Repository, path string, scanOptions *ScanOptions, chunksChan chan *sources.Chunk) error {
func (s *Git) ScanCommits(ctx context.Context, repo *git.Repository, path string, scanOptions ScanOptions, chunksChan chan *sources.Chunk) error {
if err := GitCmdCheck(); err != nil {
return err
}
Expand Down Expand Up @@ -490,7 +490,7 @@ func (s *Git) gitChunk(ctx context.Context, diff gitparse.Diff, fileName, email,
}

// ScanStaged chunks staged changes.
func (s *Git) ScanStaged(ctx context.Context, repo *git.Repository, path string, scanOptions *ScanOptions, chunksChan chan *sources.Chunk) error {
func (s *Git) ScanStaged(ctx context.Context, repo *git.Repository, path string, scanOptions ScanOptions, chunksChan chan *sources.Chunk) error {
// Get the URL metadata for reporting (may be empty).
urlMetadata := getSafeRemoteURL(repo, "origin")

Expand Down Expand Up @@ -570,10 +570,7 @@ func (s *Git) ScanStaged(ctx context.Context, repo *git.Repository, path string,
return nil
}

func (s *Git) ScanRepo(ctx context.Context, repo *git.Repository, repoPath string, scanOptions *ScanOptions, chunksChan chan *sources.Chunk) error {
if scanOptions == nil {
scanOptions = NewScanOptions()
}
func (s *Git) ScanRepo(ctx context.Context, repo *git.Repository, repoPath string, scanOptions ScanOptions, chunksChan chan *sources.Chunk) error {
if err := normalizeConfig(scanOptions, repo); err != nil {
return err
}
Expand All @@ -599,7 +596,7 @@ func (s *Git) ScanRepo(ctx context.Context, repo *git.Repository, repoPath strin
return nil
}

func normalizeConfig(scanOptions *ScanOptions, repo *git.Repository) (err error) {
func normalizeConfig(scanOptions ScanOptions, repo *git.Repository) (err error) {
var baseCommit *object.Commit
if len(scanOptions.BaseHash) > 0 {
baseHash := plumbing.NewHash(scanOptions.BaseHash)
Expand Down
2 changes: 1 addition & 1 deletion pkg/sources/git/git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ func TestSource_Chunks_Integration(t *testing.T) {
if err != nil {
panic(err)
}
err = s.git.ScanRepo(ctx, repo, repoPath, &tt.scanOptions, chunksCh)
err = s.git.ScanRepo(ctx, repo, repoPath, tt.scanOptions, chunksCh)
if err != nil {
panic(err)
}
Expand Down
1 change: 1 addition & 0 deletions pkg/sources/git/scan_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package git

import (
"github.com/go-git/go-git/v5"

"github.com/trufflesecurity/trufflehog/v3/pkg/common"
)

Expand Down
20 changes: 5 additions & 15 deletions pkg/sources/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ type Source struct {
totalRepoSize int // total size in bytes of all repos
git *git.Git

scanOptMu sync.Mutex // protects the scanOptions
scanOptions *git.ScanOptions
scanOptions git.ScanOptions

httpClient *http.Client
log logr.Logger
Expand All @@ -83,17 +82,10 @@ type Source struct {
sources.CommonSourceUnitUnmarshaller
}

func (s *Source) WithScanOptions(scanOptions *git.ScanOptions) {
func (s *Source) WithScanOptions(scanOptions git.ScanOptions) {
s.scanOptions = scanOptions
}

func (s *Source) setScanOptions(base, head string) {
s.scanOptMu.Lock()
defer s.scanOptMu.Unlock()
s.scanOptions.BaseHash = base
s.scanOptions.HeadHash = head
}

// Ensure the Source satisfies the interfaces at compile time
var _ sources.Source = (*Source)(nil)
var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
Expand Down Expand Up @@ -664,9 +656,6 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch

scanErrs := sources.NewScanErrors()
// Setup scan options if it wasn't provided.
if s.scanOptions == nil {
s.scanOptions = &git.ScanOptions{}
}

for i, repoURL := range s.repos {
i, repoURL := i, repoURL
Expand Down Expand Up @@ -707,7 +696,8 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch
return nil
}

s.setScanOptions(s.conn.Base, s.conn.Head)
s.scanOptions.HeadHash = s.conn.Head
s.scanOptions.BaseHash = s.conn.Base

repoSize := s.repoSizes.getRepo(repoURL)
logger.V(2).Info(fmt.Sprintf("scanning repo %d/%d", i, len(s.repos)), "repo_size", repoSize)
Expand Down Expand Up @@ -1186,7 +1176,7 @@ func (s *Source) chunkGistComments(ctx context.Context, gistUrl string, comments
Timestamp: sanitizer.UTF8(comment.GetCreatedAt().String()),
// Fetching this information requires making an additional API call.
// We may want to include this in the future.
//Visibility: s.visibilityOf(ctx, repoPath),
// Visibility: s.visibilityOf(ctx, repoPath),
},
},
},
Expand Down
2 changes: 1 addition & 1 deletion pkg/sources/github/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (s *Source) cloneRepo(
func (s *Source) getUserAndToken(ctx context.Context, repoURL string, installationClient *github.Client) error {
// We never refresh user provided tokens, so if we already have them, we never need to try and fetch them again.
s.userMu.Lock()
defer s.mu.Unlock()
defer s.userMu.Unlock()
if s.githubUser == "" || s.githubToken == "" {
var err error
s.githubUser, s.githubToken, err = s.userAndToken(ctx, installationClient)
Expand Down
4 changes: 2 additions & 2 deletions pkg/sources/gitlab/gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type Source struct {
repos []string
ignoreRepos []string
git *git.Git
scanOptions *git.ScanOptions
scanOptions git.ScanOptions
resumeInfoSlice []string
resumeInfoMutex sync.Mutex
sources.Progress
Expand Down Expand Up @@ -443,5 +443,5 @@ func (s *Source) setProgressCompleteWithRepo(index int, offset int, repoURL stri
}

func (s *Source) WithScanOptions(scanOptions *git.ScanOptions) {
s.scanOptions = scanOptions
s.scanOptions = *scanOptions
}
2 changes: 1 addition & 1 deletion pkg/sources/gitlab/gitlab_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ func Test_scanRepos_SetProgressComplete(t *testing.T) {
repos: tc.repos,
}
src.jobPool = &errgroup.Group{}
src.scanOptions = &git.ScanOptions{}
src.scanOptions = git.ScanOptions{}

_ = src.scanRepos(context.Background(), nil)
if !tc.wantErr {
Expand Down

0 comments on commit 1545d65

Please sign in to comment.