Skip to content

Commit

Permalink
Add Display method to SourceUnit and Kind member to the CommonSourceUnit
Browse files Browse the repository at this point in the history
  • Loading branch information
mcastorina committed Feb 13, 2024
1 parent af7f811 commit 6a5da27
Show file tree
Hide file tree
Showing 12 changed files with 151 additions and 23 deletions.
6 changes: 5 additions & 1 deletion pkg/sources/filesystem/filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,11 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e

// ChunkUnit implements SourceUnitChunker interface.
func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error {
path := unit.SourceUnitID()
commonUnit, err := sources.IntoCommonUnit(unit)
if err != nil {
return err
}
path := commonUnit.ID
logger := ctx.Logger().WithValues("path", path)

cleanPath := filepath.Clean(path)
Expand Down
4 changes: 3 additions & 1 deletion pkg/sources/filesystem/filesystem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ func TestEnumerate(t *testing.T) {
assert.Equal(t, len(units), len(reporter.Units))
assert.Equal(t, 0, len(reporter.UnitErrs))
for _, unit := range reporter.Units {
assert.Contains(t, units, unit.SourceUnitID())
commonUnit, err := sources.IntoCommonUnit(unit)
assert.NoError(t, err)
assert.Contains(t, units, commonUnit.ID)
}
for _, unit := range units {
assert.Contains(t, reporter.Units, sources.CommonSourceUnit{ID: unit})
Expand Down
6 changes: 3 additions & 3 deletions pkg/sources/git/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -1249,9 +1249,9 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e
}

func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error {
gitUnit, ok := unit.(SourceUnit)
if !ok {
return fmt.Errorf("unsupported unit type: %T", unit)
gitUnit, err := sources.IntoUnit[SourceUnit](unit)
if err != nil {
return err
}

switch gitUnit.Kind {
Expand Down
4 changes: 3 additions & 1 deletion pkg/sources/git/git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,9 @@ func TestEnumerate(t *testing.T) {
assert.Equal(t, len(units), len(reporter.Units))
assert.Equal(t, 0, len(reporter.UnitErrs))
for _, unit := range reporter.Units {
assert.Contains(t, units, unit.SourceUnitID())
gitUnit, err := sources.IntoUnit[SourceUnit](unit)
assert.NoError(t, err)
assert.Contains(t, units, gitUnit.ID)
}
for _, unit := range units[:3] {
assert.Contains(t, reporter.Units, SourceUnit{ID: unit, Kind: UnitRepo})
Expand Down
31 changes: 30 additions & 1 deletion pkg/sources/git/unit.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package git
import (
"encoding/json"
"fmt"
"net/url"
"path/filepath"
"strings"

"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)
Expand All @@ -24,7 +27,33 @@ type SourceUnit struct {

// Implement sources.SourceUnit interface.
func (u SourceUnit) SourceUnitID() string {
return u.ID
return fmt.Sprintf("%s:%s", u.Kind, u.ID)
}

// Provide a custom Display method.
func (u SourceUnit) Display() string {
switch u.Kind {
case UnitRepo:
repo := u.ID
if parsedURL, err := url.Parse(u.ID); err == nil {
// scheme://host/owner/repo
repo = strings.TrimPrefix(parsedURL.Path, "/")
} else if _, path, found := strings.Cut(u.ID, ":"); found {
// git@host:owner/repo
// TODO: Is this possible? We should maybe canonicalize
// the URL before getting here.
repo = path
}
return strings.TrimSuffix(repo, ".git")
case UnitDir:
return filepath.Base(u.ID)
default:
return "mysterious git unit"
}
}

func (u SourceUnit) Type() string {
return u.Kind
}

// Helper function to unmarshal raw bytes into our SourceUnit struct.
Expand Down
42 changes: 42 additions & 0 deletions pkg/sources/git/unit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package git

import (
"encoding/json"
"testing"

"github.com/stretchr/testify/assert"
)

func TestUnmarshalUnit(t *testing.T) {
s := `{"kind":"repo","id":"https://github.com/trufflesecurity/test_keys.git"}`
expectedUnit := SourceUnit{ID: "https://github.com/trufflesecurity/test_keys.git", Kind: UnitRepo}
gotUnit, err := UnmarshalUnit([]byte(s))
assert.NoError(t, err)

assert.Equal(t, expectedUnit, gotUnit)

_, err = UnmarshalUnit(nil)
assert.Error(t, err)

_, err = UnmarshalUnit([]byte(`{"kind":"idk","id":"id"}`))
assert.Error(t, err)
}

func TestMarshalUnit(t *testing.T) {
unit := SourceUnit{ID: "https://github.com/trufflesecurity/test_keys.git", Kind: UnitRepo}
b, err := json.Marshal(unit)
assert.NoError(t, err)

assert.Equal(t, `{"kind":"repo","id":"https://github.com/trufflesecurity/test_keys.git"}`, string(b))
}

func TestDisplayUnit(t *testing.T) {
unit := SourceUnit{ID: "https://github.com/trufflesecurity/test_keys.git", Kind: UnitRepo}
assert.Equal(t, "trufflesecurity/test_keys", unit.Display())

unit = SourceUnit{ID: "/path/to/repo", Kind: UnitDir}
assert.Equal(t, "repo", unit.Display())

unit = SourceUnit{ID: "ssh://github.com/trufflesecurity/test_keys", Kind: UnitRepo}
assert.Equal(t, "trufflesecurity/test_keys", unit.Display())
}
23 changes: 17 additions & 6 deletions pkg/sources/gitlab/gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,11 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ .
})
reporter := sources.VisitorReporter{
VisitUnit: func(ctx context.Context, unit sources.SourceUnit) error {
repos = append(repos, unit.SourceUnitID())
gitUnit, err := sources.IntoUnit[git.SourceUnit](unit)
if err != nil {
return err
}
repos = append(repos, gitUnit.ID)
return ctx.Err()
},
}
Expand Down Expand Up @@ -255,7 +259,11 @@ func (s *Source) Validate(ctx context.Context) []error {
var repos []string
visitor := sources.VisitorReporter{
VisitUnit: func(ctx context.Context, unit sources.SourceUnit) error {
repos = append(repos, unit.SourceUnitID())
gitUnit, err := sources.IntoUnit[git.SourceUnit](unit)
if err != nil {
return err
}
repos = append(repos, gitUnit.ID)
return nil
},
}
Expand Down Expand Up @@ -362,7 +370,7 @@ func (s *Source) getAllProjectRepos(
continue
}
// Report the unit.
unit := sources.CommonSourceUnit{ID: proj.HTTPURLToRepo}
unit := git.SourceUnit{Kind: git.UnitRepo, ID: proj.HTTPURLToRepo}
if err := reporter.UnitOk(ctx, unit); err != nil {
return err
}
Expand Down Expand Up @@ -616,7 +624,7 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e
// Report all repos if specified.
if len(repos) > 0 {
for _, repo := range repos {
unit := sources.CommonSourceUnit{ID: repo}
unit := git.SourceUnit{Kind: git.UnitRepo, ID: repo}
if err := reporter.UnitOk(ctx, unit); err != nil {
return err
}
Expand All @@ -635,11 +643,14 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e

// ChunkUnit downloads and reports chunks for the given GitLab repository unit.
func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error {
repoURL := unit.SourceUnitID()
gitUnit, err := sources.IntoUnit[git.SourceUnit](unit)
if err != nil {
return err
}
repoURL := gitUnit.ID

var path string
var repo *gogit.Repository
var err error
if s.authMethod == "UNAUTHENTICATED" {
path, repo, err = git.CloneRepoUsingUnauthenticated(ctx, repoURL)
} else {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sources/job_progress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func TestJobProgressHook(t *testing.T) {
startChunk := time.Now().Add(40 * time.Second)
endChunk := time.Now().Add(50 * time.Second)
reportErr := fmt.Errorf("reporting error")
reportUnit := CommonSourceUnit{"reporting unit"}
reportUnit := CommonSourceUnit{ID: "reporting unit"}
reportChunk := &Chunk{Data: []byte("reporting chunk")}

hook.EXPECT().Start(gomock.Any(), startTime)
Expand Down
16 changes: 11 additions & 5 deletions pkg/sources/source_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ type countChunk byte

func (c countChunk) SourceUnitID() string { return fmt.Sprintf("countChunk(%d)", c) }

func (c countChunk) Display() string { return c.SourceUnitID() }

func (c *counterChunker) Enumerate(ctx context.Context, reporter UnitReporter) error {
for i := 0; i < c.count; i++ {
if err := reporter.UnitOk(ctx, countChunk(byte(i))); err != nil {
Expand Down Expand Up @@ -187,15 +189,19 @@ func (c *unitChunker) Chunks(ctx context.Context, ch chan *Chunk, _ ...ChunkingT
}
func (c *unitChunker) Enumerate(ctx context.Context, rep UnitReporter) error {
for _, step := range c.steps {
if err := rep.UnitOk(ctx, CommonSourceUnit{step.unit}); err != nil {
if err := rep.UnitOk(ctx, CommonSourceUnit{ID: step.unit}); err != nil {
return err
}
}
return nil
}
func (c *unitChunker) ChunkUnit(ctx context.Context, unit SourceUnit, rep ChunkReporter) error {
for _, step := range c.steps {
if unit.SourceUnitID() != step.unit {
commonUnit, err := IntoCommonUnit(unit)
if err != nil {
return err
}
if commonUnit.ID != step.unit {
continue
}
if step.err != "" {
Expand Down Expand Up @@ -345,23 +351,23 @@ func TestSourceManagerUnitHook(t *testing.T) {
})
m0, m1, m2 := metrics[0], metrics[1], metrics[2]

assert.Equal(t, "1 one", m0.Unit.SourceUnitID())
assert.Equal(t, "unit:1 one", m0.Unit.SourceUnitID())
assert.Equal(t, uint64(1), m0.TotalChunks)
assert.Equal(t, uint64(3), m0.TotalBytes)
assert.NotZero(t, m0.StartTime)
assert.NotZero(t, m0.EndTime)
assert.NotZero(t, m0.ElapsedTime())
assert.Equal(t, 0, len(m0.Errors))

assert.Equal(t, "2 two", m1.Unit.SourceUnitID())
assert.Equal(t, "unit:2 two", m1.Unit.SourceUnitID())
assert.Equal(t, uint64(0), m1.TotalChunks)
assert.Equal(t, uint64(0), m1.TotalBytes)
assert.NotZero(t, m1.StartTime)
assert.NotZero(t, m1.EndTime)
assert.NotZero(t, m1.ElapsedTime())
assert.Equal(t, 1, len(m1.Errors))

assert.Equal(t, "3 three", m2.Unit.SourceUnitID())
assert.Equal(t, "unit:3 three", m2.Unit.SourceUnitID())
assert.Equal(t, uint64(0), m2.TotalChunks)
assert.Equal(t, uint64(0), m2.TotalBytes)
assert.NotZero(t, m2.StartTime)
Expand Down
24 changes: 23 additions & 1 deletion pkg/sources/source_unit.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,20 @@ var _ SourceUnit = CommonSourceUnit{}
// CommonSourceUnit is a common implementation of SourceUnit that Sources can
// use instead of implementing their own types.
type CommonSourceUnit struct {
ID string `json:"source_unit_id"`
Kind string `json:"kind,omitempty"`
ID string `json:"id"`
}

// SourceUnitID implements the SourceUnit interface.
func (c CommonSourceUnit) SourceUnitID() string {
kind := "unit"
if c.Kind != "" {
kind = c.Kind
}
return fmt.Sprintf("%s:%s", kind, c.ID)
}

func (c CommonSourceUnit) Display() string {
return c.ID
}

Expand All @@ -35,3 +44,16 @@ func (c CommonSourceUnitUnmarshaller) UnmarshalSourceUnit(data []byte) (SourceUn
}
return unit, nil
}

func IntoUnit[T any](unit SourceUnit) (T, error) {
tUnit, ok := unit.(T)
if !ok {
var t T
return t, fmt.Errorf("unsupported unit type: %T", unit)
}
return tUnit, nil
}

func IntoCommonUnit(unit SourceUnit) (CommonSourceUnit, error) {
return IntoUnit[CommonSourceUnit](unit)
}
7 changes: 6 additions & 1 deletion pkg/sources/sources.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,13 @@ type ChunkReporter interface {
// SourceUnit is an object that represents a Source's unit of work. This is
// used as the output of enumeration, progress reporting, and job distribution.
type SourceUnit interface {
// SourceUnitID uniquely identifies a source unit.
// SourceUnitID uniquely identifies a source unit. It does not need to
// be human readable or two-way, however, it should be canonical and
// stable across runs.
SourceUnitID() string

// Display is the human readable representation of the SourceUnit.
Display() string
}

// GCSConfig defines the optional configuration for a GCS source.
Expand Down
9 changes: 7 additions & 2 deletions pkg/sources/travisci/travisci.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e

for _, repo := range repositories {
err = reporter.UnitOk(ctx, sources.CommonSourceUnit{
ID: strconv.Itoa(int(*repo.Id)),
ID: strconv.Itoa(int(*repo.Id)),
Kind: "repo",
})
if err != nil {
return fmt.Errorf("error reporting unit: %w", err)
Expand All @@ -125,7 +126,11 @@ func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) e

// ChunkUnit implements SourceUnitChunker interface.
func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error {
repo, _, err := s.client.Repositories.Find(ctx, unit.SourceUnitID(), nil)
commonUnit, err := sources.IntoCommonUnit(unit)
if err != nil {
return err
}
repo, _, err := s.client.Repositories.Find(ctx, commonUnit.ID, nil)
if err != nil {
return fmt.Errorf("error finding repository: %w", err)
}
Expand Down

0 comments on commit 6a5da27

Please sign in to comment.