diff --git a/pkg/sources/filesystem/filesystem.go b/pkg/sources/filesystem/filesystem.go index 1d3d0cc8b7d7..e7a90d9e9188 100644 --- a/pkg/sources/filesystem/filesystem.go +++ b/pkg/sources/filesystem/filesystem.go @@ -270,7 +270,11 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e // ChunkUnit implements SourceUnitChunker interface. func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error { - path := unit.SourceUnitID() + commonUnit, err := sources.IntoCommonUnit(unit) + if err != nil { + return err + } + path := commonUnit.ID logger := ctx.Logger().WithValues("path", path) cleanPath := filepath.Clean(path) diff --git a/pkg/sources/filesystem/filesystem_test.go b/pkg/sources/filesystem/filesystem_test.go index 4a846a6737cc..0eaed5808944 100644 --- a/pkg/sources/filesystem/filesystem_test.go +++ b/pkg/sources/filesystem/filesystem_test.go @@ -172,7 +172,9 @@ func TestEnumerate(t *testing.T) { assert.Equal(t, len(units), len(reporter.Units)) assert.Equal(t, 0, len(reporter.UnitErrs)) for _, unit := range reporter.Units { - assert.Contains(t, units, unit.SourceUnitID()) + commonUnit, err := sources.IntoCommonUnit(unit) + assert.NoError(t, err) + assert.Contains(t, units, commonUnit.ID) } for _, unit := range units { assert.Contains(t, reporter.Units, sources.CommonSourceUnit{ID: unit}) diff --git a/pkg/sources/git/git.go b/pkg/sources/git/git.go index d4577d28a302..bbcabda9b69d 100644 --- a/pkg/sources/git/git.go +++ b/pkg/sources/git/git.go @@ -1249,9 +1249,9 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e } func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error { - gitUnit, ok := unit.(SourceUnit) - if !ok { - return fmt.Errorf("unsupported unit type: %T", unit) + gitUnit, err := sources.IntoUnit[SourceUnit](unit) + if err != nil { + return err } switch gitUnit.Kind { diff --git a/pkg/sources/git/git_test.go b/pkg/sources/git/git_test.go index cc9d31756d3d..27a4a1e5d04b 100644 --- a/pkg/sources/git/git_test.go +++ b/pkg/sources/git/git_test.go @@ -533,7 +533,9 @@ func TestEnumerate(t *testing.T) { assert.Equal(t, len(units), len(reporter.Units)) assert.Equal(t, 0, len(reporter.UnitErrs)) for _, unit := range reporter.Units { - assert.Contains(t, units, unit.SourceUnitID()) + gitUnit, err := sources.IntoUnit[SourceUnit](unit) + assert.NoError(t, err) + assert.Contains(t, units, gitUnit.ID) } for _, unit := range units[:3] { assert.Contains(t, reporter.Units, SourceUnit{ID: unit, Kind: UnitRepo}) diff --git a/pkg/sources/git/unit.go b/pkg/sources/git/unit.go index 9457e23a6ef5..01c8b7033a9e 100644 --- a/pkg/sources/git/unit.go +++ b/pkg/sources/git/unit.go @@ -3,6 +3,9 @@ package git import ( "encoding/json" "fmt" + "net/url" + "path/filepath" + "strings" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" ) @@ -24,7 +27,33 @@ type SourceUnit struct { // Implement sources.SourceUnit interface. func (u SourceUnit) SourceUnitID() string { - return u.ID + return fmt.Sprintf("%s:%s", u.Kind, u.ID) +} + +// Provide a custom Display method. +func (u SourceUnit) Display() string { + switch u.Kind { + case UnitRepo: + repo := u.ID + if parsedURL, err := url.Parse(u.ID); err == nil { + // scheme://host/owner/repo + repo = strings.TrimPrefix(parsedURL.Path, "/") + } else if _, path, found := strings.Cut(u.ID, ":"); found { + // git@host:owner/repo + // TODO: Is this possible? We should maybe canonicalize + // the URL before getting here. + repo = path + } + return strings.TrimSuffix(repo, ".git") + case UnitDir: + return filepath.Base(u.ID) + default: + return "mysterious git unit" + } +} + +func (u SourceUnit) Type() string { + return u.Kind } // Helper function to unmarshal raw bytes into our SourceUnit struct. diff --git a/pkg/sources/git/unit_test.go b/pkg/sources/git/unit_test.go new file mode 100644 index 000000000000..e5ac45843c21 --- /dev/null +++ b/pkg/sources/git/unit_test.go @@ -0,0 +1,42 @@ +package git + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUnmarshalUnit(t *testing.T) { + s := `{"kind":"repo","id":"https://github.com/trufflesecurity/test_keys.git"}` + expectedUnit := SourceUnit{ID: "https://github.com/trufflesecurity/test_keys.git", Kind: UnitRepo} + gotUnit, err := UnmarshalUnit([]byte(s)) + assert.NoError(t, err) + + assert.Equal(t, expectedUnit, gotUnit) + + _, err = UnmarshalUnit(nil) + assert.Error(t, err) + + _, err = UnmarshalUnit([]byte(`{"kind":"idk","id":"id"}`)) + assert.Error(t, err) +} + +func TestMarshalUnit(t *testing.T) { + unit := SourceUnit{ID: "https://github.com/trufflesecurity/test_keys.git", Kind: UnitRepo} + b, err := json.Marshal(unit) + assert.NoError(t, err) + + assert.Equal(t, `{"kind":"repo","id":"https://github.com/trufflesecurity/test_keys.git"}`, string(b)) +} + +func TestDisplayUnit(t *testing.T) { + unit := SourceUnit{ID: "https://github.com/trufflesecurity/test_keys.git", Kind: UnitRepo} + assert.Equal(t, "trufflesecurity/test_keys", unit.Display()) + + unit = SourceUnit{ID: "/path/to/repo", Kind: UnitDir} + assert.Equal(t, "repo", unit.Display()) + + unit = SourceUnit{ID: "ssh://github.com/trufflesecurity/test_keys", Kind: UnitRepo} + assert.Equal(t, "trufflesecurity/test_keys", unit.Display()) +} diff --git a/pkg/sources/gitlab/gitlab.go b/pkg/sources/gitlab/gitlab.go index 850d153b3091..ba75e029b505 100644 --- a/pkg/sources/gitlab/gitlab.go +++ b/pkg/sources/gitlab/gitlab.go @@ -186,7 +186,11 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ . }) reporter := sources.VisitorReporter{ VisitUnit: func(ctx context.Context, unit sources.SourceUnit) error { - repos = append(repos, unit.SourceUnitID()) + gitUnit, err := sources.IntoUnit[git.SourceUnit](unit) + if err != nil { + return err + } + repos = append(repos, gitUnit.ID) return ctx.Err() }, } @@ -255,7 +259,11 @@ func (s *Source) Validate(ctx context.Context) []error { var repos []string visitor := sources.VisitorReporter{ VisitUnit: func(ctx context.Context, unit sources.SourceUnit) error { - repos = append(repos, unit.SourceUnitID()) + gitUnit, err := sources.IntoUnit[git.SourceUnit](unit) + if err != nil { + return err + } + repos = append(repos, gitUnit.ID) return nil }, } @@ -362,7 +370,7 @@ func (s *Source) getAllProjectRepos( continue } // Report the unit. - unit := sources.CommonSourceUnit{ID: proj.HTTPURLToRepo} + unit := git.SourceUnit{Kind: git.UnitRepo, ID: proj.HTTPURLToRepo} if err := reporter.UnitOk(ctx, unit); err != nil { return err } @@ -616,7 +624,7 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e // Report all repos if specified. if len(repos) > 0 { for _, repo := range repos { - unit := sources.CommonSourceUnit{ID: repo} + unit := git.SourceUnit{Kind: git.UnitRepo, ID: repo} if err := reporter.UnitOk(ctx, unit); err != nil { return err } @@ -635,11 +643,14 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e // ChunkUnit downloads and reports chunks for the given GitLab repository unit. func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error { - repoURL := unit.SourceUnitID() + gitUnit, err := sources.IntoUnit[git.SourceUnit](unit) + if err != nil { + return err + } + repoURL := gitUnit.ID var path string var repo *gogit.Repository - var err error if s.authMethod == "UNAUTHENTICATED" { path, repo, err = git.CloneRepoUsingUnauthenticated(ctx, repoURL) } else { diff --git a/pkg/sources/job_progress_test.go b/pkg/sources/job_progress_test.go index 05d4c6e6ae86..20ba0678550a 100644 --- a/pkg/sources/job_progress_test.go +++ b/pkg/sources/job_progress_test.go @@ -81,7 +81,7 @@ func TestJobProgressHook(t *testing.T) { startChunk := time.Now().Add(40 * time.Second) endChunk := time.Now().Add(50 * time.Second) reportErr := fmt.Errorf("reporting error") - reportUnit := CommonSourceUnit{"reporting unit"} + reportUnit := CommonSourceUnit{ID: "reporting unit"} reportChunk := &Chunk{Data: []byte("reporting chunk")} hook.EXPECT().Start(gomock.Any(), startTime) diff --git a/pkg/sources/source_manager_test.go b/pkg/sources/source_manager_test.go index 815f0507e5bc..1e35d45fbe11 100644 --- a/pkg/sources/source_manager_test.go +++ b/pkg/sources/source_manager_test.go @@ -62,6 +62,8 @@ type countChunk byte func (c countChunk) SourceUnitID() string { return fmt.Sprintf("countChunk(%d)", c) } +func (c countChunk) Display() string { return c.SourceUnitID() } + func (c *counterChunker) Enumerate(ctx context.Context, reporter UnitReporter) error { for i := 0; i < c.count; i++ { if err := reporter.UnitOk(ctx, countChunk(byte(i))); err != nil { @@ -187,7 +189,7 @@ func (c *unitChunker) Chunks(ctx context.Context, ch chan *Chunk, _ ...ChunkingT } func (c *unitChunker) Enumerate(ctx context.Context, rep UnitReporter) error { for _, step := range c.steps { - if err := rep.UnitOk(ctx, CommonSourceUnit{step.unit}); err != nil { + if err := rep.UnitOk(ctx, CommonSourceUnit{ID: step.unit}); err != nil { return err } } @@ -195,7 +197,11 @@ func (c *unitChunker) Enumerate(ctx context.Context, rep UnitReporter) error { } func (c *unitChunker) ChunkUnit(ctx context.Context, unit SourceUnit, rep ChunkReporter) error { for _, step := range c.steps { - if unit.SourceUnitID() != step.unit { + commonUnit, err := IntoCommonUnit(unit) + if err != nil { + return err + } + if commonUnit.ID != step.unit { continue } if step.err != "" { @@ -345,7 +351,7 @@ func TestSourceManagerUnitHook(t *testing.T) { }) m0, m1, m2 := metrics[0], metrics[1], metrics[2] - assert.Equal(t, "1 one", m0.Unit.SourceUnitID()) + assert.Equal(t, "unit:1 one", m0.Unit.SourceUnitID()) assert.Equal(t, uint64(1), m0.TotalChunks) assert.Equal(t, uint64(3), m0.TotalBytes) assert.NotZero(t, m0.StartTime) @@ -353,7 +359,7 @@ func TestSourceManagerUnitHook(t *testing.T) { assert.NotZero(t, m0.ElapsedTime()) assert.Equal(t, 0, len(m0.Errors)) - assert.Equal(t, "2 two", m1.Unit.SourceUnitID()) + assert.Equal(t, "unit:2 two", m1.Unit.SourceUnitID()) assert.Equal(t, uint64(0), m1.TotalChunks) assert.Equal(t, uint64(0), m1.TotalBytes) assert.NotZero(t, m1.StartTime) @@ -361,7 +367,7 @@ func TestSourceManagerUnitHook(t *testing.T) { assert.NotZero(t, m1.ElapsedTime()) assert.Equal(t, 1, len(m1.Errors)) - assert.Equal(t, "3 three", m2.Unit.SourceUnitID()) + assert.Equal(t, "unit:3 three", m2.Unit.SourceUnitID()) assert.Equal(t, uint64(0), m2.TotalChunks) assert.Equal(t, uint64(0), m2.TotalBytes) assert.NotZero(t, m2.StartTime) diff --git a/pkg/sources/source_unit.go b/pkg/sources/source_unit.go index 8eac88430431..b6399c40c481 100644 --- a/pkg/sources/source_unit.go +++ b/pkg/sources/source_unit.go @@ -11,11 +11,20 @@ var _ SourceUnit = CommonSourceUnit{} // CommonSourceUnit is a common implementation of SourceUnit that Sources can // use instead of implementing their own types. type CommonSourceUnit struct { - ID string `json:"source_unit_id"` + Kind string `json:"kind,omitempty"` + ID string `json:"id"` } // SourceUnitID implements the SourceUnit interface. func (c CommonSourceUnit) SourceUnitID() string { + kind := "unit" + if c.Kind != "" { + kind = c.Kind + } + return fmt.Sprintf("%s:%s", kind, c.ID) +} + +func (c CommonSourceUnit) Display() string { return c.ID } @@ -35,3 +44,16 @@ func (c CommonSourceUnitUnmarshaller) UnmarshalSourceUnit(data []byte) (SourceUn } return unit, nil } + +func IntoUnit[T any](unit SourceUnit) (T, error) { + tUnit, ok := unit.(T) + if !ok { + var t T + return t, fmt.Errorf("unsupported unit type: %T", unit) + } + return tUnit, nil +} + +func IntoCommonUnit(unit SourceUnit) (CommonSourceUnit, error) { + return IntoUnit[CommonSourceUnit](unit) +} diff --git a/pkg/sources/sources.go b/pkg/sources/sources.go index 6d895887101a..72e6fd2e66d1 100644 --- a/pkg/sources/sources.go +++ b/pkg/sources/sources.go @@ -133,8 +133,13 @@ type ChunkReporter interface { // SourceUnit is an object that represents a Source's unit of work. This is // used as the output of enumeration, progress reporting, and job distribution. type SourceUnit interface { - // SourceUnitID uniquely identifies a source unit. + // SourceUnitID uniquely identifies a source unit. It does not need to + // be human readable or two-way, however, it should be canonical and + // stable across runs. SourceUnitID() string + + // Display is the human readable representation of the SourceUnit. + Display() string } // GCSConfig defines the optional configuration for a GCS source. diff --git a/pkg/sources/travisci/travisci.go b/pkg/sources/travisci/travisci.go index cfe457dbc1cc..7bafc9c574f2 100644 --- a/pkg/sources/travisci/travisci.go +++ b/pkg/sources/travisci/travisci.go @@ -111,7 +111,8 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e for _, repo := range repositories { err = reporter.UnitOk(ctx, sources.CommonSourceUnit{ - ID: strconv.Itoa(int(*repo.Id)), + ID: strconv.Itoa(int(*repo.Id)), + Kind: "repo", }) if err != nil { return fmt.Errorf("error reporting unit: %w", err) @@ -125,7 +126,11 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e // ChunkUnit implements SourceUnitChunker interface. func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error { - repo, _, err := s.client.Repositories.Find(ctx, unit.SourceUnitID(), nil) + commonUnit, err := sources.IntoCommonUnit(unit) + if err != nil { + return err + } + repo, _, err := s.client.Repositories.Find(ctx, commonUnit.ID, nil) if err != nil { return fmt.Errorf("error finding repository: %w", err) }