diff --git a/pkg/sources/git/git.go b/pkg/sources/git/git.go index c01fee4e58c4..8b04fe392abd 100644 --- a/pkg/sources/git/git.go +++ b/pkg/sources/git/git.go @@ -80,8 +80,11 @@ func NewGit(sourceType sourcespb.SourceType, jobID sources.JobID, sourceID sourc } // Ensure the Source satisfies the interfaces at compile time. -var _ sources.Source = (*Source)(nil) -var _ sources.SourceUnitUnmarshaller = (*Source)(nil) +var _ interface { + sources.Source + sources.SourceUnitEnumChunker + sources.SourceUnitUnmarshaller +} = (*Source)(nil) // Type returns the type of source. // It is used for matching source types in configuration and job input. @@ -179,71 +182,52 @@ func (s *Source) scanRepos(ctx context.Context, reporter sources.ChunkReporter) return nil } totalRepos := len(s.conn.Repositories) + len(s.conn.Directories) - // TODO: refactor to remove duplicate code + for i, repoURI := range s.conn.Repositories { + s.SetProgressComplete(i, totalRepos, fmt.Sprintf("Repo: %s", repoURI), "") + if len(repoURI) == 0 { + continue + } + if err := s.scanRepo(ctx, repoURI, reporter); err != nil { + ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err) + continue + } + } + return nil +} + +// scanRepo scans a single provided repository. +func (s *Source) scanRepo(ctx context.Context, repoURI string, reporter sources.ChunkReporter) error { + var cloneFunc func() (string, *git.Repository, error) switch cred := s.conn.GetCredential().(type) { case *sourcespb.Git_BasicAuth: - user := cred.BasicAuth.Username - token := cred.BasicAuth.Password - - for i, repoURI := range s.conn.Repositories { - s.SetProgressComplete(i, totalRepos, fmt.Sprintf("Repo: %s", repoURI), "") - if len(repoURI) == 0 { - continue - } - err := func(repoURI string) error { - path, repo, err := CloneRepoUsingToken(ctx, token, repoURI, user) - defer os.RemoveAll(path) - if err != nil { - return err - } - return s.git.ScanRepo(ctx, repo, path, s.scanOptions, reporter) - }(repoURI) - if err != nil { - ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err) - continue - } + cloneFunc = func() (string, *git.Repository, error) { + user := cred.BasicAuth.Username + token := cred.BasicAuth.Password + return CloneRepoUsingToken(ctx, token, repoURI, user) } case *sourcespb.Git_Unauthenticated: - for i, repoURI := range s.conn.Repositories { - s.SetProgressComplete(i, totalRepos, fmt.Sprintf("Repo: %s", repoURI), "") - if len(repoURI) == 0 { - continue - } - err := func(repoURI string) error { - path, repo, err := CloneRepoUsingUnauthenticated(ctx, repoURI) - defer os.RemoveAll(path) - if err != nil { - return err - } - return s.git.ScanRepo(ctx, repo, path, s.scanOptions, reporter) - }(repoURI) - if err != nil { - ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err) - continue - } + cloneFunc = func() (string, *git.Repository, error) { + return CloneRepoUsingUnauthenticated(ctx, repoURI) } case *sourcespb.Git_SshAuth: - for i, repoURI := range s.conn.Repositories { - s.SetProgressComplete(i, totalRepos, fmt.Sprintf("Repo: %s", repoURI), "") - if len(repoURI) == 0 { - continue - } - err := func(repoURI string) error { - path, repo, err := CloneRepoUsingSSH(ctx, repoURI) - defer os.RemoveAll(path) - if err != nil { - return err - } - return s.git.ScanRepo(ctx, repo, path, s.scanOptions, reporter) - }(repoURI) - if err != nil { - ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err) - continue - } + cloneFunc = func() (string, *git.Repository, error) { + return CloneRepoUsingSSH(ctx, repoURI) } default: return errors.New("invalid connection type for git source") } + + err := func() error { + path, repo, err := cloneFunc() + defer os.RemoveAll(path) + if err != nil { + return err + } + return s.git.ScanRepo(ctx, repo, path, s.scanOptions, reporter) + }() + if err != nil { + return reporter.ChunkErr(ctx, err) + } return nil } @@ -256,29 +240,35 @@ func (s *Source) scanDirs(ctx context.Context, reporter sources.ChunkReporter) e if len(gitDir) == 0 { continue } - if !s.scanOptions.Bare && strings.HasSuffix(gitDir, "git") { - // TODO: Figure out why we skip directories ending in "git". - continue - } - // try paths instead of url - repo, err := RepoFromPath(gitDir, s.scanOptions.Bare) - if err != nil { + if err := s.scanDir(ctx, gitDir, reporter); err != nil { ctx.Logger().Info("error scanning repository", "repo", gitDir, "error", err) continue } + } + return nil +} - err = func(repoPath string) error { - if !s.preserveTempDirs && strings.HasPrefix(repoPath, filepath.Join(os.TempDir(), "trufflehog")) { - defer os.RemoveAll(repoPath) - } +// scanDir scans a single provided directory. +func (s *Source) scanDir(ctx context.Context, gitDir string, reporter sources.ChunkReporter) error { + if !s.scanOptions.Bare && strings.HasSuffix(gitDir, "git") { + // TODO: Figure out why we skip directories ending in "git". + return nil + } + // try paths instead of url + repo, err := RepoFromPath(gitDir, s.scanOptions.Bare) + if err != nil { + return reporter.ChunkErr(ctx, err) + } - return s.git.ScanRepo(ctx, repo, repoPath, s.scanOptions, reporter) - }(gitDir) - if err != nil { - ctx.Logger().Info("error scanning repository", "repo", gitDir, "error", err) - continue + err = func() error { + if !s.preserveTempDirs && strings.HasPrefix(gitDir, filepath.Join(os.TempDir(), "trufflehog")) { + defer os.RemoveAll(gitDir) } + return s.git.ScanRepo(ctx, repo, gitDir, s.scanOptions, reporter) + }() + if err != nil { + return reporter.ChunkErr(ctx, err) } return nil } @@ -1003,6 +993,44 @@ func handleBinary(ctx context.Context, repo *git.Repository, reporter sources.Ch return nil } +func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) error { + for _, repo := range s.conn.GetDirectories() { + if repo == "" { + continue + } + unit := SourceUnit{ID: repo, Kind: UnitDir} + if err := reporter.UnitOk(ctx, unit); err != nil { + return err + } + } + for _, repo := range s.conn.GetRepositories() { + if repo == "" { + continue + } + unit := SourceUnit{ID: repo, Kind: UnitRepo} + if err := reporter.UnitOk(ctx, unit); err != nil { + return err + } + } + return nil +} + +func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error { + gitUnit, ok := unit.(SourceUnit) + if !ok { + return fmt.Errorf("unsupported unit type: %T", unit) + } + + switch gitUnit.Kind { + case UnitRepo: + return s.scanRepo(ctx, gitUnit.ID, reporter) + case UnitDir: + return s.scanDir(ctx, gitUnit.ID, reporter) + default: + return fmt.Errorf("unexpected git unit kind: %q", gitUnit.Kind) + } +} + func (s *Source) UnmarshalSourceUnit(data []byte) (sources.SourceUnit, error) { return UnmarshalUnit(data) } diff --git a/pkg/sources/git/git_test.go b/pkg/sources/git/git_test.go index e9ec417f3d54..cc9d31756d3d 100644 --- a/pkg/sources/git/git_test.go +++ b/pkg/sources/git/git_test.go @@ -16,6 +16,7 @@ import ( "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" + "github.com/trufflesecurity/trufflehog/v3/pkg/sourcestest" ) func TestSource_Scan(t *testing.T) { @@ -504,3 +505,72 @@ func TestGitURLParse(t *testing.T) { assert.Equal(t, tt.scheme, u.Scheme) } } + +func TestEnumerate(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Setup the connection to test enumeration. + units := []string{ + "foo", "bar", "baz", + "/path/to/dir/", "/path/to/another/dir/", + } + conn, err := anypb.New(&sourcespb.Git{ + Repositories: units[0:3], + Directories: units[3:], + }) + assert.NoError(t, err) + + // Initialize the source. + s := Source{} + err = s.Init(ctx, "test enumerate", 0, 0, true, conn, 1) + assert.NoError(t, err) + + reporter := sourcestest.TestReporter{} + err = s.Enumerate(ctx, &reporter) + assert.NoError(t, err) + + assert.Equal(t, len(units), len(reporter.Units)) + assert.Equal(t, 0, len(reporter.UnitErrs)) + for _, unit := range reporter.Units { + assert.Contains(t, units, unit.SourceUnitID()) + } + for _, unit := range units[:3] { + assert.Contains(t, reporter.Units, SourceUnit{ID: unit, Kind: UnitRepo}) + } + for _, unit := range units[3:] { + assert.Contains(t, reporter.Units, SourceUnit{ID: unit, Kind: UnitDir}) + } +} + +func TestChunkUnit(t *testing.T) { + t.Parallel() + ctx := context.Background() + // Initialize the source. + s := Source{} + conn, err := anypb.New(&sourcespb.Git{ + Credential: &sourcespb.Git_Unauthenticated{}, + }) + assert.NoError(t, err) + err = s.Init(ctx, "test chunk", 0, 0, true, conn, 1) + assert.NoError(t, err) + + reporter := sourcestest.TestReporter{} + + // Happy path single repository. + err = s.ChunkUnit(ctx, SourceUnit{ + ID: "https://github.com/dustin-decker/secretsandstuff.git", + Kind: UnitRepo, + }, &reporter) + assert.NoError(t, err) + + // Error path. + err = s.ChunkUnit(ctx, SourceUnit{ + ID: "/file/not/found", + Kind: UnitDir, + }, &reporter) + assert.NoError(t, err) + + assert.Equal(t, 11, len(reporter.Chunks)) + assert.Equal(t, 1, len(reporter.ChunkErrs)) +}