From 966778f36e3fc3d876f78da800cba6c99341d2dd Mon Sep 17 00:00:00 2001 From: Richard Gomez Date: Tue, 26 Mar 2024 11:21:24 -0400 Subject: [PATCH] fix(github): enumeration & ratelimit issues --- pkg/sources/github/github.go | 51 +++++++++++++++++++++++++----------- pkg/sources/github/repo.go | 19 +++++++++----- 2 files changed, 48 insertions(+), 22 deletions(-) diff --git a/pkg/sources/github/github.go b/pkg/sources/github/github.go index 112e7bc82cac..76a885d945dd 100644 --- a/pkg/sources/github/github.go +++ b/pkg/sources/github/github.go @@ -404,31 +404,50 @@ func (s *Source) enumerate(ctx context.Context, apiEndpoint string) (*github.Cli } s.repos = make([]string, 0, s.filteredRepoCache.Count()) + +RepoLoop: for _, repo := range s.filteredRepoCache.Values() { + repoCtx := context.WithValue(ctx, "repo", repo) + r, ok := repo.(string) if !ok { - ctx.Logger().Error(fmt.Errorf("type assertion failed"), "unexpected value in cache", "repo", repo) + repoCtx.Logger().Error(fmt.Errorf("type assertion failed"), "Unexpected value in cache") continue } - _, urlParts, err := getRepoURLParts(r) - if err != nil { - ctx.Logger().Error(err, "failed to parse repository URL") - continue - } + // Ensure that |s.repoInfoCache| contains an entry for |repo|. + // This compensates for differences in enumeration logic between `--org` and `--repo`. + // See: https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788 + if _, ok := s.repoInfoCache.get(r); !ok { + repoCtx.Logger().V(2).Info("Caching repository info") - // Ignore any gists in |s.filteredRepoCache|. - // Repos have three parts (github.com, owner, name), gists have two. - if len(urlParts) == 3 { - // Ensure that individual repos specified in --repo are cached. - // Gists should be cached elsewhere. - // https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788 - ghRepo, _, err := s.apiClient.Repositories.Get(ctx, urlParts[1], urlParts[2]) + _, urlParts, err := getRepoURLParts(r) if err != nil { - ctx.Logger().Error(err, "failed to fetch repository") + repoCtx.Logger().Error(err, "Failed to parse repository URL") continue } - s.cacheRepoInfo(ghRepo) + + // Ignore any gists in |s.filteredRepoCache|. + // (Repos have three parts: [github.com, owner, name], gists have two.) + if len(urlParts) != 3 { + // Gists _should_ be cached elsewhere. + err = fmt.Errorf("missing cached info for gist: %s", r) + repoCtx.Logger().Error(err, "Unable to cache repository info") + continue RepoLoop + } + + for { + ghRepo, _, err := s.apiClient.Repositories.Get(repoCtx, urlParts[1], urlParts[2]) + if s.handleRateLimit(err) { + continue + } + if err != nil { + repoCtx.Logger().Error(err, "Failed to fetch repository") + continue RepoLoop + } + s.cacheRepoInfo(ghRepo) + break + } } s.repos = append(s.repos, r) } @@ -998,7 +1017,7 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) { logger := s.log.WithValues("user", user) for { orgs, resp, err := s.apiClient.Organizations.List(ctx, "", orgOpts) - if handled := s.handleRateLimit(err); handled { + if s.handleRateLimit(err) { continue } if err != nil { diff --git a/pkg/sources/github/repo.go b/pkg/sources/github/repo.go index cc217a6e8f5a..bce512c21152 100644 --- a/pkg/sources/github/repo.go +++ b/pkg/sources/github/repo.go @@ -329,12 +329,19 @@ type commitQuery struct { // getDiffForFileInCommit retrieves the diff for a specified file in a commit. // If the file or its diff is not found, it returns an error. func (s *Source) getDiffForFileInCommit(ctx context.Context, query commitQuery) (string, error) { - commit, _, err := s.apiClient.Repositories.GetCommit(ctx, query.owner, query.repo, query.sha, nil) - if s.handleRateLimit(err) { - return "", fmt.Errorf("error fetching commit %s due to rate limit: %w", query.sha, err) - } - if err != nil { - return "", fmt.Errorf("error fetching commit %s: %w", query.sha, err) + var ( + commit *github.RepositoryCommit + err error + ) + for { + commit, _, err = s.apiClient.Repositories.GetCommit(ctx, query.owner, query.repo, query.sha, nil) + if s.handleRateLimit(err) { + continue + } + if err != nil { + return "", fmt.Errorf("error fetching commit %s: %w", query.sha, err) + } + break } if len(commit.Files) == 0 {