Skip to content

Commit

Permalink
Refactor git source to support scanning units
Browse files Browse the repository at this point in the history
  • Loading branch information
mcastorina committed Nov 1, 2023
1 parent be8e254 commit ef89775
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 73 deletions.
192 changes: 119 additions & 73 deletions pkg/sources/git/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,11 @@ func NewGit(sourceType sourcespb.SourceType, jobID sources.JobID, sourceID sourc
}

// Ensure the Source satisfies the interfaces at compile time.
var _ sources.Source = (*Source)(nil)
var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
var _ interface {
sources.Source
sources.SourceUnitEnumChunker
sources.SourceUnitUnmarshaller
} = (*Source)(nil)

// Type returns the type of source.
// It is used for matching source types in configuration and job input.
Expand Down Expand Up @@ -179,71 +182,52 @@ func (s *Source) scanRepos(ctx context.Context, reporter sources.ChunkReporter)
return nil
}
totalRepos := len(s.conn.Repositories) + len(s.conn.Directories)
// TODO: refactor to remove duplicate code
for i, repoURI := range s.conn.Repositories {
s.SetProgressComplete(i, totalRepos, fmt.Sprintf("Repo: %s", repoURI), "")
if len(repoURI) == 0 {
continue
}
if err := s.scanRepo(ctx, repoURI, reporter); err != nil {
ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err)
continue
}
}
return nil
}

// scanRepo scans a single provided repository.
func (s *Source) scanRepo(ctx context.Context, repoURI string, reporter sources.ChunkReporter) error {
var cloneFunc func() (string, *git.Repository, error)
switch cred := s.conn.GetCredential().(type) {
case *sourcespb.Git_BasicAuth:
user := cred.BasicAuth.Username
token := cred.BasicAuth.Password

for i, repoURI := range s.conn.Repositories {
s.SetProgressComplete(i, totalRepos, fmt.Sprintf("Repo: %s", repoURI), "")
if len(repoURI) == 0 {
continue
}
err := func(repoURI string) error {
path, repo, err := CloneRepoUsingToken(ctx, token, repoURI, user)
defer os.RemoveAll(path)
if err != nil {
return err
}
return s.git.ScanRepo(ctx, repo, path, s.scanOptions, reporter)
}(repoURI)
if err != nil {
ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err)
continue
}
cloneFunc = func() (string, *git.Repository, error) {
user := cred.BasicAuth.Username
token := cred.BasicAuth.Password
return CloneRepoUsingToken(ctx, token, repoURI, user)
}
case *sourcespb.Git_Unauthenticated:
for i, repoURI := range s.conn.Repositories {
s.SetProgressComplete(i, totalRepos, fmt.Sprintf("Repo: %s", repoURI), "")
if len(repoURI) == 0 {
continue
}
err := func(repoURI string) error {
path, repo, err := CloneRepoUsingUnauthenticated(ctx, repoURI)
defer os.RemoveAll(path)
if err != nil {
return err
}
return s.git.ScanRepo(ctx, repo, path, s.scanOptions, reporter)
}(repoURI)
if err != nil {
ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err)
continue
}
cloneFunc = func() (string, *git.Repository, error) {
return CloneRepoUsingUnauthenticated(ctx, repoURI)
}
case *sourcespb.Git_SshAuth:
for i, repoURI := range s.conn.Repositories {
s.SetProgressComplete(i, totalRepos, fmt.Sprintf("Repo: %s", repoURI), "")
if len(repoURI) == 0 {
continue
}
err := func(repoURI string) error {
path, repo, err := CloneRepoUsingSSH(ctx, repoURI)
defer os.RemoveAll(path)
if err != nil {
return err
}
return s.git.ScanRepo(ctx, repo, path, s.scanOptions, reporter)
}(repoURI)
if err != nil {
ctx.Logger().Info("error scanning repository", "repo", repoURI, "error", err)
continue
}
cloneFunc = func() (string, *git.Repository, error) {
return CloneRepoUsingSSH(ctx, repoURI)
}
default:
return errors.New("invalid connection type for git source")
}

err := func() error {
path, repo, err := cloneFunc()
defer os.RemoveAll(path)
if err != nil {
return err
}
return s.git.ScanRepo(ctx, repo, path, s.scanOptions, reporter)
}()
if err != nil {
return reporter.ChunkErr(ctx, err)
}
return nil
}

Expand All @@ -256,29 +240,35 @@ func (s *Source) scanDirs(ctx context.Context, reporter sources.ChunkReporter) e
if len(gitDir) == 0 {
continue
}
if !s.scanOptions.Bare && strings.HasSuffix(gitDir, "git") {
// TODO: Figure out why we skip directories ending in "git".
continue
}
// try paths instead of url
repo, err := RepoFromPath(gitDir, s.scanOptions.Bare)
if err != nil {
if err := s.scanDir(ctx, gitDir, reporter); err != nil {
ctx.Logger().Info("error scanning repository", "repo", gitDir, "error", err)
continue
}
}
return nil
}

err = func(repoPath string) error {
if !s.preserveTempDirs && strings.HasPrefix(repoPath, filepath.Join(os.TempDir(), "trufflehog")) {
defer os.RemoveAll(repoPath)
}
// scanDir scans a single provided directory.
func (s *Source) scanDir(ctx context.Context, gitDir string, reporter sources.ChunkReporter) error {
if !s.scanOptions.Bare && strings.HasSuffix(gitDir, "git") {
// TODO: Figure out why we skip directories ending in "git".
return nil
}
// try paths instead of url
repo, err := RepoFromPath(gitDir, s.scanOptions.Bare)
if err != nil {
return reporter.ChunkErr(ctx, err)
}

return s.git.ScanRepo(ctx, repo, repoPath, s.scanOptions, reporter)
}(gitDir)
if err != nil {
ctx.Logger().Info("error scanning repository", "repo", gitDir, "error", err)
continue
err = func() error {
if !s.preserveTempDirs && strings.HasPrefix(gitDir, filepath.Join(os.TempDir(), "trufflehog")) {
defer os.RemoveAll(gitDir)
}

return s.git.ScanRepo(ctx, repo, gitDir, s.scanOptions, reporter)
}()
if err != nil {
return reporter.ChunkErr(ctx, err)
}
return nil
}
Expand Down Expand Up @@ -528,6 +518,9 @@ func (s *Git) ScanCommits(ctx context.Context, repo *git.Repository, path string
if err := reporter.ChunkOk(ctx, chunk); err != nil {
return err
}
if err := reporter.ChunkOk(ctx, chunk); err != nil {
return err
}
}
}
return nil
Expand Down Expand Up @@ -559,6 +552,10 @@ func (s *Git) gitChunk(ctx context.Context, diff gitparse.Diff, fileName, email,
// TODO: Return error.
return
}
if err := reporter.ChunkOk(ctx, chunk); err != nil {
// TODO: Return error.
return
}

newChunkBuffer.Reset()
lastOffset = offset
Expand All @@ -579,6 +576,10 @@ func (s *Git) gitChunk(ctx context.Context, diff gitparse.Diff, fileName, email,
// TODO: Return error.
return
}
if err := reporter.ChunkOk(ctx, chunk); err != nil {
// TODO: Return error.
return
}
continue
}
}
Expand All @@ -603,6 +604,10 @@ func (s *Git) gitChunk(ctx context.Context, diff gitparse.Diff, fileName, email,
// TODO: Return error.
return
}
if err := reporter.ChunkOk(ctx, chunk); err != nil {
// TODO: Return error.
return
}
}
}

Expand Down Expand Up @@ -687,6 +692,9 @@ func (s *Git) ScanStaged(ctx context.Context, repo *git.Repository, path string,
if err := reporter.ChunkOk(ctx, chunk); err != nil {
return err
}
if err := reporter.ChunkOk(ctx, chunk); err != nil {
return err
}
}
}
return nil
Expand Down Expand Up @@ -1003,6 +1011,44 @@ func handleBinary(ctx context.Context, repo *git.Repository, reporter sources.Ch
return nil
}

func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) error {
for _, repo := range s.conn.GetDirectories() {
if repo == "" {
continue
}
unit := SourceUnit{ID: repo, Kind: UnitDir}
if err := reporter.UnitOk(ctx, unit); err != nil {
return err
}
}
for _, repo := range s.conn.GetRepositories() {
if repo == "" {
continue
}
unit := SourceUnit{ID: repo, Kind: UnitRepo}
if err := reporter.UnitOk(ctx, unit); err != nil {
return err
}
}
return nil
}

func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error {
gitUnit, ok := unit.(SourceUnit)
if !ok {
return fmt.Errorf("unsupported unit type: %T", unit)
}

switch gitUnit.Kind {
case UnitRepo:
return s.scanRepo(ctx, gitUnit.ID, reporter)
case UnitDir:
return s.scanDir(ctx, gitUnit.ID, reporter)
default:
return fmt.Errorf("unexpected git unit kind: %q", gitUnit.Kind)
}
}

func (s *Source) UnmarshalSourceUnit(data []byte) (sources.SourceUnit, error) {
return UnmarshalUnit(data)
}
70 changes: 70 additions & 0 deletions pkg/sources/git/git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
"github.com/trufflesecurity/trufflehog/v3/pkg/sourcestest"
)

func TestSource_Scan(t *testing.T) {
Expand Down Expand Up @@ -504,3 +505,72 @@ func TestGitURLParse(t *testing.T) {
assert.Equal(t, tt.scheme, u.Scheme)
}
}

func TestEnumerate(t *testing.T) {
t.Parallel()
ctx := context.Background()

// Setup the connection to test enumeration.
units := []string{
"foo", "bar", "baz",
"/path/to/dir/", "/path/to/another/dir/",
}
conn, err := anypb.New(&sourcespb.Git{
Repositories: units[0:3],
Directories: units[3:],
})
assert.NoError(t, err)

// Initialize the source.
s := Source{}
err = s.Init(ctx, "test enumerate", 0, 0, true, conn, 1)
assert.NoError(t, err)

reporter := sourcestest.TestReporter{}
err = s.Enumerate(ctx, &reporter)
assert.NoError(t, err)

assert.Equal(t, len(units), len(reporter.Units))
assert.Equal(t, 0, len(reporter.UnitErrs))
for _, unit := range reporter.Units {
assert.Contains(t, units, unit.SourceUnitID())
}
for _, unit := range units[:3] {
assert.Contains(t, reporter.Units, SourceUnit{ID: unit, Kind: UnitRepo})
}
for _, unit := range units[3:] {
assert.Contains(t, reporter.Units, SourceUnit{ID: unit, Kind: UnitDir})
}
}

func TestChunkUnit(t *testing.T) {
t.Parallel()
ctx := context.Background()
// Initialize the source.
s := Source{}
conn, err := anypb.New(&sourcespb.Git{
Credential: &sourcespb.Git_Unauthenticated{},
})
assert.NoError(t, err)
err = s.Init(ctx, "test chunk", 0, 0, true, conn, 1)
assert.NoError(t, err)

reporter := sourcestest.TestReporter{}

// Happy path single repository.
err = s.ChunkUnit(ctx, SourceUnit{
ID: "https://github.com/dustin-decker/secretsandstuff.git",
Kind: UnitRepo,
}, &reporter)
assert.NoError(t, err)

// Error path.
err = s.ChunkUnit(ctx, SourceUnit{
ID: "/file/not/found",
Kind: UnitDir,
}, &reporter)
assert.NoError(t, err)

assert.Equal(t, 11, len(reporter.Chunks))
assert.Equal(t, 1, len(reporter.ChunkErrs))
}

0 comments on commit ef89775

Please sign in to comment.