Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GitHub enumeration & rate-limiting logic #2625

Merged
merged 2 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 45 additions & 26 deletions pkg/sources/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,31 +404,59 @@ func (s *Source) enumerate(ctx context.Context, apiEndpoint string) (*github.Cli
}

s.repos = make([]string, 0, s.filteredRepoCache.Count())

RepoLoop:
for _, repo := range s.filteredRepoCache.Values() {
repoCtx := context.WithValue(ctx, "repo", repo)

r, ok := repo.(string)
if !ok {
ctx.Logger().Error(fmt.Errorf("type assertion failed"), "unexpected value in cache", "repo", repo)
repoCtx.Logger().Error(fmt.Errorf("type assertion failed"), "Unexpected value in cache")
continue
}

_, urlParts, err := getRepoURLParts(r)
if err != nil {
ctx.Logger().Error(err, "failed to parse repository URL")
continue
}
// Ensure that |s.repoInfoCache| contains an entry for |repo|.
// This compensates for differences in enumeration logic between `--org` and `--repo`.
// See: https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788
if _, ok := s.repoInfoCache.get(r); !ok {
repoCtx.Logger().V(2).Info("Caching repository info")

// Ignore any gists in |s.filteredRepoCache|.
// Repos have three parts (github.com, owner, name), gists have two.
if len(urlParts) == 3 {
// Ensure that individual repos specified in --repo are cached.
// Gists should be cached elsewhere.
// https://github.com/trufflesecurity/trufflehog/pull/2379#discussion_r1487454788
ghRepo, _, err := s.apiClient.Repositories.Get(ctx, urlParts[1], urlParts[2])
_, urlParts, err := getRepoURLParts(r)
if err != nil {
ctx.Logger().Error(err, "failed to fetch repository")
repoCtx.Logger().Error(err, "Failed to parse repository URL")
continue
}
s.cacheRepoInfo(ghRepo)

if strings.EqualFold(urlParts[0], "gist.github.com") {
// Cache gist info.
for {
gistID := extractGistID(urlParts)
gist, _, err := s.apiClient.Gists.Get(repoCtx, gistID)
if s.handleRateLimit(err) {
continue
}
if err != nil {
repoCtx.Logger().Error(err, "Failed to fetch gist")
continue RepoLoop
}
s.cacheGistInfo(gist)
break
}
} else {
// Cache repository info.
for {
ghRepo, _, err := s.apiClient.Repositories.Get(repoCtx, urlParts[1], urlParts[2])
if s.handleRateLimit(err) {
continue
}
if err != nil {
repoCtx.Logger().Error(err, "Failed to fetch repository")
continue RepoLoop
}
s.cacheRepoInfo(ghRepo)
break
}
}
}
s.repos = append(s.repos, r)
}
Expand Down Expand Up @@ -902,16 +930,7 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string) error {

for _, gist := range gists {
s.filteredRepoCache.Set(gist.GetID(), gist.GetGitPullURL())

info := repoInfo{
owner: gist.GetOwner().GetLogin(),
}
if gist.GetPublic() {
info.visibility = source_metadatapb.Visibility_public
} else {
info.visibility = source_metadatapb.Visibility_private
}
s.repoInfoCache.put(gist.GetGitPullURL(), info)
s.cacheGistInfo(gist)
}

if res == nil || res.NextPage == 0 {
Expand Down Expand Up @@ -998,7 +1017,7 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) {
logger := s.log.WithValues("user", user)
for {
orgs, resp, err := s.apiClient.Organizations.List(ctx, "", orgOpts)
if handled := s.handleRateLimit(err); handled {
if s.handleRateLimit(err) {
continue
}
if err != nil {
Expand Down
53 changes: 44 additions & 9 deletions pkg/sources/github/github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,20 +451,17 @@ func BenchmarkEnumerateWithToken(b *testing.B) {
func TestEnumerate(t *testing.T) {
defer gock.Off()

// Arrange
gock.New("https://api.github.com").
Get("/user").
Reply(200).
JSON(map[string]string{"login": "super-secret-user"})

//
gock.New("https://api.github.com").
Get("/users/super-secret-user/repos").
Reply(200).
JSON([]map[string]string{{"clone_url": "https://github.com/super-secret-user/super-secret-repo.git", "full_name": "super-secret-user/super-secret-repo"}})

gock.New("https://api.github.com").
Get("/repos/super-secret-user/super-secret-repo").
Reply(200).
JSON(`{"owner": {"login": "super-secret-user"}, "name": "super-secret-repo", "full_name": "super-secret-user/super-secret-repo", "has_wiki": false, "size": 1}`)
JSON(`[{"name": "super-secret-repo", "full_name": "super-secret-user/super-secret-repo", "owner": {"login": "super-secret-user"}, "clone_url": "https://github.com/super-secret-user/super-secret-repo.git", "has_wiki": false, "size": 1}]`)

gock.New("https://api.github.com").
Get("/user/orgs").
Expand All @@ -483,12 +480,50 @@ func TestEnumerate(t *testing.T) {
},
})

// Manually cache a repository to ensure that enumerate
// doesn't make duplicate API calls.
// See https://github.com/trufflesecurity/trufflehog/pull/2625
repo := func() *github.Repository {
var (
name = "cached-repo"
fullName = "cached-user/cached-repo"
login = "cached-user"
cloneUrl = "https://github.com/cached-user/cached-repo.git"
owner = &github.User{
Login: &login,
}
hasWiki = false
size = 1234
)
return &github.Repository{
Name: &name,
FullName: &fullName,
Owner: owner,
HasWiki: &hasWiki,
Size: &size,
CloneURL: &cloneUrl,
}
}()
s.cacheRepoInfo(repo)
s.filteredRepoCache.Set(repo.GetFullName(), repo.GetCloneURL())

// Act
_, err := s.enumerate(context.Background(), "https://api.github.com")

// Assert
assert.Nil(t, err)
assert.Equal(t, 2, s.filteredRepoCache.Count())
ok := s.filteredRepoCache.Exists("super-secret-user/super-secret-repo")
// Enumeration found all repos.
assert.Equal(t, 3, s.filteredRepoCache.Count())
assert.True(t, s.filteredRepoCache.Exists("super-secret-user/super-secret-repo"))
assert.True(t, s.filteredRepoCache.Exists("cached-user/cached-repo"))
assert.True(t, s.filteredRepoCache.Exists("2801a2b0523099d0614a951579d99ba9"))
// Enumeration cached all repos.
assert.Equal(t, 3, len(s.repoInfoCache.cache))
_, ok := s.repoInfoCache.get("https://github.com/super-secret-user/super-secret-repo.git")
assert.True(t, ok)
_, ok = s.repoInfoCache.get("https://github.com/cached-user/cached-repo.git")
assert.True(t, ok)
ok = s.filteredRepoCache.Exists("2801a2b0523099d0614a951579d99ba9")
_, ok = s.repoInfoCache.get("https://gist.github.com/2801a2b0523099d0614a951579d99ba9.git")
assert.True(t, ok)
assert.True(t, gock.IsDone())
}
Expand Down
31 changes: 25 additions & 6 deletions pkg/sources/github/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,18 @@ func (s *Source) cacheRepoInfo(r *github.Repository) {
s.repoInfoCache.put(r.GetCloneURL(), info)
}

// cacheGistInfo stores the owner login and visibility of g in
// s.repoInfoCache, keyed by the gist's git pull URL.
func (s *Source) cacheGistInfo(g *github.Gist) {
	// Default to private; switch to public only when the gist says so.
	visibility := source_metadatapb.Visibility_private
	if g.GetPublic() {
		visibility = source_metadatapb.Visibility_public
	}

	s.repoInfoCache.put(g.GetGitPullURL(), repoInfo{
		owner:      g.GetOwner().GetLogin(),
		visibility: visibility,
	})
}

// wikiIsReachable returns true if https://github.com/$org/$repo/wiki is not redirected.
// Unfortunately, this isn't 100% accurate. Some repositories have `has_wiki: true` and don't redirect their wiki page,
// but still don't have a cloneable wiki.
Expand Down Expand Up @@ -329,12 +341,19 @@ type commitQuery struct {
// getDiffForFileInCommit retrieves the diff for a specified file in a commit.
// If the file or its diff is not found, it returns an error.
func (s *Source) getDiffForFileInCommit(ctx context.Context, query commitQuery) (string, error) {
commit, _, err := s.apiClient.Repositories.GetCommit(ctx, query.owner, query.repo, query.sha, nil)
if s.handleRateLimit(err) {
return "", fmt.Errorf("error fetching commit %s due to rate limit: %w", query.sha, err)
}
if err != nil {
return "", fmt.Errorf("error fetching commit %s: %w", query.sha, err)
var (
commit *github.RepositoryCommit
err error
)
for {
commit, _, err = s.apiClient.Repositories.GetCommit(ctx, query.owner, query.repo, query.sha, nil)
if s.handleRateLimit(err) {
continue
}
if err != nil {
return "", fmt.Errorf("error fetching commit %s: %w", query.sha, err)
}
break
}

if len(commit.Files) == 0 {
Expand Down
Loading