Skip to content

Commit

Permalink
[chore] Add SourceUnitEnumChunker filesystem tests (trufflesecurity#1873)
Browse files Browse the repository at this point in the history

* [chore] Add SourceUnitEnumChunker filesystem tests

* Ensure reported units are exactly what is expected
  • Loading branch information
mcastorina authored and Phoenix591 committed Oct 27, 2023
1 parent 543b828 commit d50b686
Show file tree
Hide file tree
Showing 3 changed files with 257 additions and 12 deletions.
3 changes: 1 addition & 2 deletions pkg/sources/filesystem/filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ type Source struct {
// Ensure the Source satisfies the interfaces at compile time
var _ sources.Source = (*Source)(nil)
var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
var _ sources.SourceUnitEnumerator = (*Source)(nil)
var _ sources.SourceUnitChunker = (*Source)(nil)
var _ sources.SourceUnitEnumChunker = (*Source)(nil)

// Type returns the type of source.
// It is used for matching source types in configuration and job input.
Expand Down
205 changes: 195 additions & 10 deletions pkg/sources/filesystem/filesystem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
"github.com/trufflesecurity/trufflehog/v3/pkg/sourcestest"
)

func TestSource_Scan(t *testing.T) {
Expand Down Expand Up @@ -84,23 +85,15 @@ func TestSource_Scan(t *testing.T) {
}

func TestScanFile(t *testing.T) {
tmpfile, err := os.CreateTemp("", "example.txt")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name())

chunkSize := sources.ChunkSize
secretPart1 := "SECRET"
secretPart2 := "SPLIT"
// Split the secret into two parts and pad the rest of the chunk with A's.
data := strings.Repeat("A", chunkSize-len(secretPart1)) + secretPart1 + secretPart2 + strings.Repeat("A", chunkSize-len(secretPart2))

_, err = tmpfile.Write([]byte(data))
assert.Nil(t, err)

err = tmpfile.Close()
tmpfile, cleanup, err := createTempFile("", data)
assert.Nil(t, err)
defer cleanup()

source := &Source{}
chunksChan := make(chan *sources.Chunk, 2)
Expand All @@ -120,3 +113,195 @@ func TestScanFile(t *testing.T) {

assert.Contains(t, foundSecret, secretPart1+secretPart2)
}

// TestEnumerate checks that Enumerate reports exactly one unit for every
// configured path and directory, and reports no unit errors.
func TestEnumerate(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	// Configure the connection with a mix of file paths and directories.
	paths := []string{"/one", "/two", "/three"}
	dirs := []string{"/path/to/dir/", "/path/to/another/dir/"}
	want := make([]string, 0, len(paths)+len(dirs))
	want = append(want, paths...)
	want = append(want, dirs...)

	conn, err := anypb.New(&sourcespb.Filesystem{Paths: paths, Directories: dirs})
	assert.NoError(t, err)

	// Initialize the source.
	var src Source
	assert.NoError(t, src.Init(ctx, "test enumerate", 0, 0, true, conn, 1))

	rec := sourcestest.TestReporter{}
	assert.NoError(t, src.Enumerate(ctx, &rec))

	assert.Len(t, rec.Units, len(want))
	assert.Empty(t, rec.UnitErrs)

	// Check both directions so the reported set matches the expected set
	// exactly: nothing unexpected reported, and nothing expected missing.
	for i := range rec.Units {
		assert.Contains(t, want, rec.Units[i].SourceUnitID())
	}
	for _, id := range want {
		assert.Contains(t, rec.Units, sources.CommonSourceUnit{ID: id})
	}
}

// TestChunkUnit exercises ChunkUnit against a single file, a directory of
// files, and a path that does not exist. It expects the contents of every
// fixture file to be reported as a chunk and the missing path to be reported
// through the reporter as a chunk error rather than as a returned error.
func TestChunkUnit(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	// Set up the fixtures to chunk. Abort immediately if creation fails: the
	// helpers return a nil file and a nil cleanup func on error, so carrying
	// on would panic (nil dereference / deferred nil func) instead of failing
	// the test cleanly.
	fileContents := "TestChunkUnit"
	tmpfile, cleanupFile, err := createTempFile("", fileContents)
	if err != nil {
		t.Fatal(err)
	}
	defer cleanupFile()

	tmpdir, cleanupDir, err := createTempDir("", "foo", "bar", "baz")
	if err != nil {
		t.Fatal(err)
	}
	defer cleanupDir()

	conn, err := anypb.New(&sourcespb.Filesystem{})
	assert.NoError(t, err)

	// Initialize the source.
	s := Source{}
	err = s.Init(ctx, "test chunk unit", 0, 0, true, conn, 1)
	assert.NoError(t, err)

	// Happy path single file.
	reporter := sourcestest.TestReporter{}
	err = s.ChunkUnit(ctx, sources.CommonSourceUnit{
		ID: tmpfile.Name(),
	}, &reporter)
	assert.NoError(t, err)

	// Happy path directory.
	err = s.ChunkUnit(ctx, sources.CommonSourceUnit{
		ID: tmpdir,
	}, &reporter)
	assert.NoError(t, err)

	// Error path. The failure is expected to surface via the reporter
	// (counted in ChunkErrs below), so the call itself returns no error.
	err = s.ChunkUnit(ctx, sources.CommonSourceUnit{
		ID: "/file/not/found",
	}, &reporter)
	assert.NoError(t, err)

	// One chunk for the single file plus one per file in the directory.
	assert.Equal(t, 4, len(reporter.Chunks))
	assert.Equal(t, 1, len(reporter.ChunkErrs))

	// Chunk order is not asserted, so compare as a set of contents.
	dataFound := make(map[string]struct{}, 4)
	for _, chunk := range reporter.Chunks {
		dataFound[string(chunk.Data)] = struct{}{}
	}
	assert.Contains(t, dataFound, fileContents)
	assert.Contains(t, dataFound, "foo")
	assert.Contains(t, dataFound, "bar")
	assert.Contains(t, dataFound, "baz")
}

// TestEnumerateReporterErr checks that Enumerate returns an error when the
// reporter it is given returns an error.
func TestEnumerateReporterErr(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	// Configure the connection with a mix of file paths and directories.
	paths := []string{"/one", "/two", "/three"}
	dirs := []string{"/path/to/dir/", "/path/to/another/dir/"}
	conn, err := anypb.New(&sourcespb.Filesystem{Paths: paths, Directories: dirs})
	assert.NoError(t, err)

	// Initialize the source.
	var src Source
	assert.NoError(t, src.Init(ctx, "test enumerate", 0, 0, true, conn, 1))

	// Every reporter call fails, so enumeration as a whole must fail too.
	assert.Error(t, src.Enumerate(ctx, &sourcestest.ErrReporter{}))
}

// TestChunkUnitReporterErr checks that ChunkUnit returns an error whenever
// the reporter returns an error, both for a readable file and for a path that
// does not exist.
func TestChunkUnitReporterErr(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	// Write a small fixture file for ChunkUnit to read.
	fixture, err := os.CreateTemp("", "example.txt")
	if err != nil {
		t.Fatal(err)
	}
	defer os.Remove(fixture.Name())

	_, err = fixture.WriteString("TestChunkUnit")
	assert.NoError(t, err)
	assert.NoError(t, fixture.Close())

	conn, err := anypb.New(&sourcespb.Filesystem{})
	assert.NoError(t, err)

	// Initialize the source.
	var src Source
	assert.NoError(t, src.Init(ctx, "test chunk unit", 0, 0, true, conn, 1))

	failing := &sourcestest.ErrReporter{}

	// Happy path: the file is readable, but the reporter rejects the chunk.
	existing := sources.CommonSourceUnit{ID: fixture.Name()}
	assert.Error(t, src.ChunkUnit(ctx, existing, failing))

	// Error path: the file is missing, and the reporter rejects the error
	// report as well.
	missing := sources.CommonSourceUnit{ID: "/file/not/found"}
	assert.Error(t, src.ChunkUnit(ctx, missing, failing))
}

// createTempFile is a helper function to create a temporary file in the given
// directory with the provided contents. If dir is "", the operating system's
// temp directory is used.
//
// On success it returns the file handle (already closed; use Name() for the
// path on disk) and a cleanup function that removes the file. On failure it
// returns a nil file, a nil cleanup function, and the error; any partially
// created file is closed and removed before returning.
func createTempFile(dir string, contents string) (*os.File, func(), error) {
	tmpfile, err := os.CreateTemp(dir, "trufflehogtest")
	if err != nil {
		return nil, nil, err
	}
	// Removal is best-effort: there is nothing useful to do if it fails.
	remove := func() { _ = os.Remove(tmpfile.Name()) }

	if _, err := tmpfile.WriteString(contents); err != nil {
		// Close before removing so the descriptor is not leaked; the write
		// error is the one worth reporting, so the close error is dropped.
		_ = tmpfile.Close()
		remove()
		return nil, nil, err
	}
	if err := tmpfile.Close(); err != nil {
		remove()
		return nil, nil, err
	}
	return tmpfile, remove, nil
}

// createTempDir makes a new temporary directory inside dir and fills it with
// one file per entry in contents, each holding that entry's text. If dir is
// "", the operating system's temp directory is used.
//
// It returns the directory path and a cleanup function that removes the whole
// tree. If any step fails, whatever was created is removed and the result is
// an empty path, a nil cleanup function, and the error.
func createTempDir(dir string, contents ...string) (string, func(), error) {
	root, err := os.MkdirTemp(dir, "trufflehogtest")
	if err != nil {
		return "", nil, err
	}
	removeAll := func() { _ = os.RemoveAll(root) }

	for i := range contents {
		// The per-file handle and cleanup are not needed: removing the
		// directory tree takes care of every file inside it.
		_, _, err := createTempFile(root, contents[i])
		if err == nil {
			continue
		}
		removeAll()
		return "", nil, err
	}
	return root, removeAll, nil
}
61 changes: 61 additions & 0 deletions pkg/sourcestest/sourcestest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package sourcestest

import (
"fmt"

"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)

// reporter combines sources.UnitReporter and sources.ChunkReporter so that a
// single compile-time assertion can check a type against both.
type reporter interface {
sources.UnitReporter
sources.ChunkReporter
}

// Ensure both helper reporters satisfy the combined interface at compile time.
var _ reporter = (*TestReporter)(nil)
var _ reporter = (*ErrReporter)(nil)

// TestReporter is a recording helper that implements both UnitReporter and
// ChunkReporter. Every value passed to one of its methods is appended to the
// matching slice so a test can inspect what was reported; the methods never
// return an error.
type TestReporter struct {
	Units     []sources.SourceUnit
	UnitErrs  []error
	Chunks    []sources.Chunk
	ChunkErrs []error
}

// UnitOk records unit in Units.
func (r *TestReporter) UnitOk(_ context.Context, unit sources.SourceUnit) error {
	r.Units = append(r.Units, unit)
	return nil
}

// UnitErr records err in UnitErrs.
func (r *TestReporter) UnitErr(_ context.Context, err error) error {
	r.UnitErrs = append(r.UnitErrs, err)
	return nil
}

// ChunkOk records chunk in Chunks.
func (r *TestReporter) ChunkOk(_ context.Context, chunk sources.Chunk) error {
	r.Chunks = append(r.Chunks, chunk)
	return nil
}

// ChunkErr records err in ChunkErrs.
func (r *TestReporter) ChunkErr(_ context.Context, err error) error {
	r.ChunkErrs = append(r.ChunkErrs, err)
	return nil
}

// ErrReporter is a helper that implements UnitReporter and ChunkReporter by
// failing unconditionally: each method ignores its arguments and returns an
// error naming the method that was called.
type ErrReporter struct{}

// UnitOk always returns an error.
func (ErrReporter) UnitOk(_ context.Context, _ sources.SourceUnit) error {
	return fmt.Errorf("ErrReporter: UnitOk error")
}

// UnitErr always returns an error.
func (ErrReporter) UnitErr(_ context.Context, _ error) error {
	return fmt.Errorf("ErrReporter: UnitErr error")
}

// ChunkOk always returns an error.
func (ErrReporter) ChunkOk(_ context.Context, _ sources.Chunk) error {
	return fmt.Errorf("ErrReporter: ChunkOk error")
}

// ChunkErr always returns an error.
func (ErrReporter) ChunkErr(_ context.Context, _ error) error {
	return fmt.Errorf("ErrReporter: ChunkErr error")
}

0 comments on commit d50b686

Please sign in to comment.