diff --git a/go.mod b/go.mod index ccbdc4c35f97..8c5fcdef268a 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( github.com/golang-jwt/jwt/v4 v4.5.0 github.com/google/go-cmp v0.6.0 github.com/google/go-containerregistry v0.17.0 - github.com/google/go-github/v42 v42.0.0 + github.com/google/go-github/v57 v57.0.0 github.com/google/uuid v1.5.0 github.com/googleapis/gax-go/v2 v2.12.0 github.com/h2non/filetype v1.1.3 @@ -174,7 +174,6 @@ require ( github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v23.1.21+incompatible // indirect - github.com/google/go-github/v57 v57.0.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect github.com/google/s2a-go v0.1.7 // indirect diff --git a/go.sum b/go.sum index 6c9dbde9dec1..f119c67a0caa 100644 --- a/go.sum +++ b/go.sum @@ -366,8 +366,6 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-containerregistry v0.17.0 h1:5p+zYs/R4VGHkhyvgWurWrpJ2hW4Vv9fQI+GzdcwXLk= github.com/google/go-containerregistry v0.17.0/go.mod h1:u0qB2l7mvtWVR5kNcbFIhFY1hLbf8eeGapA+vbFDCtQ= -github.com/google/go-github/v42 v42.0.0 h1:YNT0FwjPrEysRkLIiKuEfSvBPCGKphW5aS5PxwaoLec= -github.com/google/go-github/v42 v42.0.0/go.mod h1:jgg/jvyI0YlDOM1/ps6XYh04HNQ3vKf0CVko62/EhRg= github.com/google/go-github/v57 v57.0.0 h1:L+Y3UPTY8ALM8x+TV0lg+IEBI+upibemtBD8Q9u7zHs= github.com/google/go-github/v57 v57.0.0/go.mod h1:s0omdnye0hvK/ecLvpsGfJMiRt85PimQh4oygmLIxHw= github.com/google/go-querystring v0.0.0-20170111101155-53e6ce116135/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= diff --git a/pkg/sources/git/git.go b/pkg/sources/git/git.go index 9843b860cf6c..344202ee3340 100644 --- a/pkg/sources/git/git.go +++ b/pkg/sources/git/git.go @@ -18,13 +18,14 @@ import ( "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/plumbing" "github.com/go-git/go-git/v5/plumbing/object" - "github.com/google/go-github/v42/github" - diskbufferreader "github.com/trufflesecurity/disk-buffer-reader" + "github.com/google/go-github/v57/github" "golang.org/x/oauth2" "golang.org/x/sync/semaphore" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + diskbufferreader "github.com/trufflesecurity/disk-buffer-reader" + "github.com/trufflesecurity/trufflehog/v3/pkg/cleantemp" "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/context" diff --git a/pkg/sources/github/github.go b/pkg/sources/github/github.go index 555c69ab7d8c..1271d729396e 100644 --- a/pkg/sources/github/github.go +++ b/pkg/sources/github/github.go @@ -14,10 +14,12 @@ import ( "sync/atomic" "time" + "golang.org/x/exp/rand" + "github.com/bradleyfalzon/ghinstallation/v2" "github.com/go-logr/logr" "github.com/gobwas/glob" - "github.com/google/go-github/v42/github" + "github.com/google/go-github/v57/github" "golang.org/x/oauth2" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/proto" @@ -388,7 +390,6 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) (visibility s return } - var resp *github.Response urlPathParts := strings.Split(u.Path, "/") switch len(urlPathParts) { case 2: @@ -397,8 +398,8 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) (visibility s repoName := urlPathParts[1] repoName = strings.TrimSuffix(repoName, ".git") for { - gist, resp, err = s.apiClient.Gists.Get(ctx, repoName) - if !s.handleRateLimit(err, resp) { + gist, _, err = s.apiClient.Gists.Get(ctx, repoName) + if !s.handleRateLimit(err) { break } } @@ -415,8 +416,8 @@ func (s *Source) visibilityOf(ctx context.Context, repoURL string) (visibility s repoName := urlPathParts[2] repoName = strings.TrimSuffix(repoName, ".git") for { - repo, resp, err = s.apiClient.Repositories.Get(ctx, owner, repoName) - if !s.handleRateLimit(err, resp) { + repo, _, err = s.apiClient.Repositories.Get(ctx, owner, repoName) + if !s.handleRateLimit(err) { break } } @@ -584,13 +585,12 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri var ( ghUser *github.User - resp *github.Response ) ctx.Logger().V(1).Info("Enumerating with token", "endpoint", apiEndpoint) for { - ghUser, resp, err = s.apiClient.Users.Get(ctx, "") - if handled := s.handleRateLimit(err, resp); handled { + ghUser, _, err = s.apiClient.Users.Get(ctx, "") + if s.handleRateLimit(err) { continue } if err != nil { @@ -869,53 +869,70 @@ func (s *Source) cloneAndScanRepo(ctx context.Context, client *github.Client, re return duration, nil } -// handleRateLimit returns true if a rate limit was handled -// Unauthenticated access to most github endpoints has a rate limit of 60 requests per hour. -// This will likely only be exhausted if many users/orgs are scanned without auth -func (s *Source) handleRateLimit(errIn error, res *github.Response) bool { - var ( - knownWait = true - remaining = 0 - retryAfter time.Duration - ) +var ( + rateLimitMu sync.RWMutex + rateLimitResumeTime time.Time +) - // GitHub has both primary (RateLimit) and secondary (AbuseRateLimit) errors. - var rateLimit *github.RateLimitError - var abuseLimit *github.AbuseRateLimitError - if errors.As(errIn, &rateLimit) { - // Do nothing - } else if errors.As(errIn, &abuseLimit) { - retryAfter = abuseLimit.GetRetryAfter() - } else { +// handleRateLimit returns true if a rate limit was handled +// +// Unauthenticated users have a rate limit of 60 requests per hour. +// Authenticated users have a rate limit of 5,000 requests per hour, +// however, certain actions are subject to a stricter "secondary" limit. +// https://docs.github.com/en/rest/overview/rate-limits-for-the-rest-api +func (s *Source) handleRateLimit(errIn error) bool { + if errIn == nil { return false } - githubNumRateLimitEncountered.WithLabelValues(s.name).Inc() - // Parse retry information from response headers, unless a Retry-After value was already provided. - // https://docs.github.com/en/rest/overview/resources-in-the-rest-api#exceeding-the-rate-limit - if retryAfter <= 0 && res != nil { - var err error - remaining, err = strconv.Atoi(res.Header.Get("x-ratelimit-remaining")) - if err != nil { - knownWait = false + rateLimitMu.RLock() + resumeTime := rateLimitResumeTime + rateLimitMu.RUnlock() + + var retryAfter time.Duration + if resumeTime.IsZero() || time.Now().After(resumeTime) { + rateLimitMu.Lock() + + var ( + now = time.Now() + + // GitHub has both primary (RateLimit) and secondary (AbuseRateLimit) errors. + limitType string + rateLimit *github.RateLimitError + abuseLimit *github.AbuseRateLimitError + ) + if errors.As(errIn, &rateLimit) { + limitType = "primary" + rate := rateLimit.Rate + if rate.Remaining == 0 { // TODO: Will we ever receive a |RateLimitError| when remaining > 0? + retryAfter = rate.Reset.Sub(now) + } + } else if errors.As(errIn, &abuseLimit) { + limitType = "secondary" + retryAfter = abuseLimit.GetRetryAfter() + } else { + rateLimitMu.Unlock() + return false } - resetTime, err := strconv.Atoi(res.Header.Get("x-ratelimit-reset")) - if err != nil || resetTime == 0 { - knownWait = false - } else if resetTime > 0 { - retryAfter = time.Duration(int64(resetTime)-time.Now().Unix()) * time.Second + jitter := time.Duration(rand.Intn(10)+1) * time.Second + if retryAfter > 0 { + retryAfter = retryAfter + jitter + rateLimitResumeTime = now.Add(retryAfter) + s.log.V(0).Info(fmt.Sprintf("exceeded %s rate limit", limitType), "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339)) + } else { + retryAfter = (5 * time.Minute) + jitter + rateLimitResumeTime = now.Add(retryAfter) + // TODO: Use exponential backoff instead of static retry time. + s.log.V(0).Error(errIn, "unexpected rate limit error", "retry_after", retryAfter.String(), "resume_time", rateLimitResumeTime.Format(time.RFC3339)) } - } - resumeTime := time.Now().Add(retryAfter).String() - if knownWait && remaining == 0 && retryAfter > 0 { - s.log.V(2).Info("rate limited", "retry_after", retryAfter.String(), "resume_time", resumeTime) + rateLimitMu.Unlock() } else { - // TODO: Use exponential backoff instead of static retry time. - retryAfter = time.Minute * 5 - s.log.V(2).Error(errIn, "unexpected rate limit error", "retry_after", retryAfter.String(), "resume_time", resumeTime) + retryAfter = resumeTime.Sub(time.Now()) } + + githubNumRateLimitEncountered.WithLabelValues(s.name).Inc() time.Sleep(retryAfter) githubSecondsSpentRateLimited.WithLabelValues(s.name).Add(retryAfter.Seconds()) return true @@ -940,10 +957,7 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string) error { logger := s.log.WithValues("user", user) for { gists, res, err := s.apiClient.Gists.List(ctx, user, gistOpts) - if err == nil { - res.Body.Close() - } - if handled := s.handleRateLimit(err, res); handled { + if s.handleRateLimit(err) { continue } if err != nil { @@ -996,11 +1010,8 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) { }, } for { - orgs, resp, err := s.apiClient.Organizations.ListAll(ctx, orgOpts) - if err == nil { - resp.Body.Close() - } - if handled := s.handleRateLimit(err, resp); handled { + orgs, _, err := s.apiClient.Organizations.ListAll(ctx, orgOpts) + if s.handleRateLimit(err) { continue } if err != nil { @@ -1037,10 +1048,7 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) { logger := s.log.WithValues("user", user) for { orgs, resp, err := s.apiClient.Organizations.List(ctx, "", orgOpts) - if err == nil { - resp.Body.Close() - } - if handled := s.handleRateLimit(err, resp); handled { + if handled := s.handleRateLimit(err); handled { continue } if err != nil { @@ -1075,10 +1083,7 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error { logger := s.log.WithValues("org", org) for { members, res, err := s.apiClient.Organizations.ListMembers(ctx, org, opts) - if err == nil { - defer res.Body.Close() - } - if handled := s.handleRateLimit(err, res); handled { + if s.handleRateLimit(err) { continue } if err != nil || len(members) == 0 { @@ -1150,8 +1155,8 @@ func (s *Source) processGistComments(ctx context.Context, repoPath string, trimm Page: initialPage, } for { - comments, resp, err := s.apiClient.Gists.ListComments(ctx, gistID, options) - if s.handleRateLimit(err, resp) { + comments, _, err := s.apiClient.Gists.ListComments(ctx, gistID, options) + if s.handleRateLimit(err) { break } if err != nil { @@ -1254,8 +1259,8 @@ func (s *Source) processIssues(ctx context.Context, info repoInfo, chunksChan ch } for { - issues, resp, err := s.apiClient.Issues.ListByRepo(ctx, info.owner, info.repo, bodyTextsOpts) - if s.handleRateLimit(err, resp) { + issues, _, err := s.apiClient.Issues.ListByRepo(ctx, info.owner, info.repo, bodyTextsOpts) + if s.handleRateLimit(err) { break } @@ -1287,8 +1292,8 @@ func (s *Source) processIssueComments(ctx context.Context, info repoInfo, chunks } for { - issueComments, resp, err := s.apiClient.Issues.ListComments(ctx, info.owner, info.repo, allComments, issueOpts) - if s.handleRateLimit(err, resp) { + issueComments, _, err := s.apiClient.Issues.ListComments(ctx, info.owner, info.repo, allComments, issueOpts) + if s.handleRateLimit(err) { break } @@ -1321,8 +1326,8 @@ func (s *Source) processPRs(ctx context.Context, info repoInfo, chunksChan chan } for { - prs, resp, err := s.apiClient.PullRequests.List(ctx, info.owner, info.repo, prOpts) - if s.handleRateLimit(err, resp) { + prs, _, err := s.apiClient.PullRequests.List(ctx, info.owner, info.repo, prOpts) + if s.handleRateLimit(err) { break } @@ -1354,8 +1359,8 @@ func (s *Source) processPRComments(ctx context.Context, info repoInfo, chunksCha } for { - prComments, resp, err := s.apiClient.PullRequests.ListComments(ctx, info.owner, info.repo, allComments, prOpts) - if s.handleRateLimit(err, resp) { + prComments, _, err := s.apiClient.PullRequests.ListComments(ctx, info.owner, info.repo, allComments, prOpts) + if s.handleRateLimit(err) { break } diff --git a/pkg/sources/github/github_test.go b/pkg/sources/github/github_test.go index 8b5fd27e1909..fba7ccd5494c 100644 --- a/pkg/sources/github/github_test.go +++ b/pkg/sources/github/github_test.go @@ -17,7 +17,7 @@ import ( "github.com/go-logr/logr" "github.com/google/go-cmp/cmp" - "github.com/google/go-github/v42/github" + "github.com/google/go-github/v57/github" "github.com/stretchr/testify/assert" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/types/known/anypb" diff --git a/pkg/sources/github/repo.go b/pkg/sources/github/repo.go index f76b3100643b..9e6ed9a9592e 100644 --- a/pkg/sources/github/repo.go +++ b/pkg/sources/github/repo.go @@ -7,7 +7,7 @@ import ( "strings" gogit "github.com/go-git/go-git/v5" - "github.com/google/go-github/v42/github" + "github.com/google/go-github/v57/github" "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/giturl" @@ -98,12 +98,11 @@ func (s *Source) userAndToken(ctx context.Context, installationClient *github.Cl case *sourcespb.GitHub_Token: var ( ghUser *github.User - resp *github.Response err error ) for { - ghUser, resp, err = s.apiClient.Users.Get(ctx, "") - if handled := s.handleRateLimit(err, resp); handled { + ghUser, _, err = s.apiClient.Users.Get(ctx, "") + if s.handleRateLimit(err) { continue } if err != nil { @@ -204,10 +203,7 @@ func (s *Source) processRepos(ctx context.Context, target string, listRepos repo for { someRepos, res, err := listRepos(ctx, target, listOpts) - if err == nil { - res.Body.Close() - } - if handled := s.handleRateLimit(err, res); handled { + if s.handleRateLimit(err) { continue } if err != nil { @@ -287,8 +283,8 @@ type commitQuery struct { // getDiffForFileInCommit retrieves the diff for a specified file in a commit. // If the file or its diff is not found, it returns an error. func (s *Source) getDiffForFileInCommit(ctx context.Context, query commitQuery) (string, error) { - commit, resp, err := s.apiClient.Repositories.GetCommit(ctx, query.owner, query.repo, query.sha, nil) - if handled := s.handleRateLimit(err, resp); handled { + commit, _, err := s.apiClient.Repositories.GetCommit(ctx, query.owner, query.repo, query.sha, nil) + if s.handleRateLimit(err) { return "", fmt.Errorf("error fetching commit %s due to rate limit: %w", query.sha, err) } if err != nil {