From d50b68674a2d721f71d312d31e2bae7b615ace3b Mon Sep 17 00:00:00 2001 From: Miccah Date: Mon, 16 Oct 2023 10:42:18 -0700 Subject: [PATCH] [chore] Add SourceUnitEnumChunker filesystem tests (#1873) * [chore] Add SourceUnitEnumChunker filesystem tests * Ensure reported units are exactly what is expected --- pkg/sources/filesystem/filesystem.go | 3 +- pkg/sources/filesystem/filesystem_test.go | 205 ++++++++++++++++++++-- pkg/sourcestest/sourcestest.go | 61 +++++++ 3 files changed, 257 insertions(+), 12 deletions(-) create mode 100644 pkg/sourcestest/sourcestest.go diff --git a/pkg/sources/filesystem/filesystem.go b/pkg/sources/filesystem/filesystem.go index 2e7b71785a14..64eca62ce0cb 100644 --- a/pkg/sources/filesystem/filesystem.go +++ b/pkg/sources/filesystem/filesystem.go @@ -39,8 +39,7 @@ type Source struct { // Ensure the Source satisfies the interfaces at compile time var _ sources.Source = (*Source)(nil) var _ sources.SourceUnitUnmarshaller = (*Source)(nil) -var _ sources.SourceUnitEnumerator = (*Source)(nil) -var _ sources.SourceUnitChunker = (*Source)(nil) +var _ sources.SourceUnitEnumChunker = (*Source)(nil) // Type returns the type of source. // It is used for matching source types in configuration and job input. 
diff --git a/pkg/sources/filesystem/filesystem_test.go b/pkg/sources/filesystem/filesystem_test.go index ef5dfbdd3059..e6925fafe45a 100644 --- a/pkg/sources/filesystem/filesystem_test.go +++ b/pkg/sources/filesystem/filesystem_test.go @@ -15,6 +15,7 @@ import ( "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" + "github.com/trufflesecurity/trufflehog/v3/pkg/sourcestest" ) func TestSource_Scan(t *testing.T) { @@ -84,23 +85,15 @@ func TestSource_Scan(t *testing.T) { } func TestScanFile(t *testing.T) { - tmpfile, err := os.CreateTemp("", "example.txt") - if err != nil { - t.Fatal(err) - } - defer os.Remove(tmpfile.Name()) - chunkSize := sources.ChunkSize secretPart1 := "SECRET" secretPart2 := "SPLIT" // Split the secret into two parts and pad the rest of the chunk with A's. data := strings.Repeat("A", chunkSize-len(secretPart1)) + secretPart1 + secretPart2 + strings.Repeat("A", chunkSize-len(secretPart2)) - _, err = tmpfile.Write([]byte(data)) - assert.Nil(t, err) - - err = tmpfile.Close() + tmpfile, cleanup, err := createTempFile("", data) assert.Nil(t, err) + defer cleanup() source := &Source{} chunksChan := make(chan *sources.Chunk, 2) @@ -120,3 +113,195 @@ func TestScanFile(t *testing.T) { assert.Contains(t, foundSecret, secretPart1+secretPart2) } + +func TestEnumerate(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Setup the connection to test enumeration. + units := []string{ + "/one", "/two", "/three", + "/path/to/dir/", "/path/to/another/dir/", + } + conn, err := anypb.New(&sourcespb.Filesystem{ + Paths: units[0:3], + Directories: units[3:], + }) + assert.NoError(t, err) + + // Initialize the source. 
+ s := Source{} + err = s.Init(ctx, "test enumerate", 0, 0, true, conn, 1) + assert.NoError(t, err) + + reporter := sourcestest.TestReporter{} + err = s.Enumerate(ctx, &reporter) + assert.NoError(t, err) + + assert.Equal(t, len(units), len(reporter.Units)) + assert.Equal(t, 0, len(reporter.UnitErrs)) + for _, unit := range reporter.Units { + assert.Contains(t, units, unit.SourceUnitID()) + } + for _, unit := range units { + assert.Contains(t, reporter.Units, sources.CommonSourceUnit{ID: unit}) + } +} + +func TestChunkUnit(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Setup test file to chunk. + fileContents := "TestChunkUnit" + tmpfile, cleanup, err := createTempFile("", fileContents) + assert.NoError(t, err) + defer cleanup() + + tmpdir, cleanup, err := createTempDir("", "foo", "bar", "baz") + assert.NoError(t, err) + defer cleanup() + + conn, err := anypb.New(&sourcespb.Filesystem{}) + assert.NoError(t, err) + + // Initialize the source. + s := Source{} + err = s.Init(ctx, "test chunk unit", 0, 0, true, conn, 1) + assert.NoError(t, err) + + // Happy path single file. + reporter := sourcestest.TestReporter{} + err = s.ChunkUnit(ctx, sources.CommonSourceUnit{ + ID: tmpfile.Name(), + }, &reporter) + assert.NoError(t, err) + + // Happy path directory. + err = s.ChunkUnit(ctx, sources.CommonSourceUnit{ + ID: tmpdir, + }, &reporter) + assert.NoError(t, err) + + // Error path. 
+ err = s.ChunkUnit(ctx, sources.CommonSourceUnit{ + ID: "/file/not/found", + }, &reporter) + assert.NoError(t, err) + + assert.Equal(t, 4, len(reporter.Chunks)) + assert.Equal(t, 1, len(reporter.ChunkErrs)) + dataFound := make(map[string]struct{}, 4) + for _, chunk := range reporter.Chunks { + dataFound[string(chunk.Data)] = struct{}{} + } + assert.Contains(t, dataFound, fileContents) + assert.Contains(t, dataFound, "foo") + assert.Contains(t, dataFound, "bar") + assert.Contains(t, dataFound, "baz") +} + +func TestEnumerateReporterErr(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Setup the connection to test enumeration. + units := []string{ + "/one", "/two", "/three", + "/path/to/dir/", "/path/to/another/dir/", + } + conn, err := anypb.New(&sourcespb.Filesystem{ + Paths: units[0:3], + Directories: units[3:], + }) + assert.NoError(t, err) + + // Initialize the source. + s := Source{} + err = s.Init(ctx, "test enumerate", 0, 0, true, conn, 1) + assert.NoError(t, err) + + // Enumerate should always return an error if the reporter returns an + // error. + reporter := sourcestest.ErrReporter{} + err = s.Enumerate(ctx, &reporter) + assert.Error(t, err) +} + +func TestChunkUnitReporterErr(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Setup test file to chunk. + tmpfile, err := os.CreateTemp("", "example.txt") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpfile.Name()) + + fileContents := []byte("TestChunkUnit") + _, err = tmpfile.Write(fileContents) + assert.NoError(t, err) + assert.NoError(t, tmpfile.Close()) + + conn, err := anypb.New(&sourcespb.Filesystem{}) + assert.NoError(t, err) + + // Initialize the source. + s := Source{} + err = s.Init(ctx, "test chunk unit", 0, 0, true, conn, 1) + assert.NoError(t, err) + + // Happy path. ChunkUnit should always return an error if the reporter + // returns an error. 
+ reporter := sourcestest.ErrReporter{} + err = s.ChunkUnit(ctx, sources.CommonSourceUnit{ + ID: tmpfile.Name(), + }, &reporter) + assert.Error(t, err) + + // Error path. ChunkUnit should always return an error if the reporter + // returns an error. + err = s.ChunkUnit(ctx, sources.CommonSourceUnit{ + ID: "/file/not/found", + }, &reporter) + assert.Error(t, err) +} + +// createTempFile is a helper function to create a temporary file in the given +// directory with the provided contents. If dir is "", the operating system's +// temp directory is used. +func createTempFile(dir string, contents string) (*os.File, func(), error) { + tmpfile, err := os.CreateTemp(dir, "trufflehogtest") + if err != nil { + return nil, nil, err + } + + if _, err := tmpfile.Write([]byte(contents)); err != nil { + _ = os.Remove(tmpfile.Name()) + return nil, nil, err + } + if err := tmpfile.Close(); err != nil { + _ = os.Remove(tmpfile.Name()) + return nil, nil, err + } + return tmpfile, func() { _ = os.Remove(tmpfile.Name()) }, nil +} + +// createTempDir is a helper function to create a temporary directory in the +// given directory with files containing the provided contents. If dir is "", +// the operating system's temp directory is used. 
+func createTempDir(dir string, contents ...string) (string, func(), error) { + tmpdir, err := os.MkdirTemp(dir, "trufflehogtest") + if err != nil { + return "", nil, err + } + + for _, content := range contents { + if _, _, err := createTempFile(tmpdir, content); err != nil { + _ = os.RemoveAll(tmpdir) + return "", nil, err + } + } + return tmpdir, func() { _ = os.RemoveAll(tmpdir) }, nil +} diff --git a/pkg/sourcestest/sourcestest.go b/pkg/sourcestest/sourcestest.go new file mode 100644 index 000000000000..6631ad5403f6 --- /dev/null +++ b/pkg/sourcestest/sourcestest.go @@ -0,0 +1,61 @@ +package sourcestest + +import ( + "fmt" + + "github.com/trufflesecurity/trufflehog/v3/pkg/context" + "github.com/trufflesecurity/trufflehog/v3/pkg/sources" +) + +type reporter interface { + sources.UnitReporter + sources.ChunkReporter +} + +var ( + _ reporter = (*TestReporter)(nil) + _ reporter = (*ErrReporter)(nil) +) + +// TestReporter is a helper struct that implements both UnitReporter and +// ChunkReporter by simply recording the values passed in the methods. +type TestReporter struct { + Units []sources.SourceUnit + UnitErrs []error + Chunks []sources.Chunk + ChunkErrs []error +} + +func (t *TestReporter) UnitOk(_ context.Context, unit sources.SourceUnit) error { + t.Units = append(t.Units, unit) + return nil +} +func (t *TestReporter) UnitErr(_ context.Context, err error) error { + t.UnitErrs = append(t.UnitErrs, err) + return nil +} +func (t *TestReporter) ChunkOk(_ context.Context, chunk sources.Chunk) error { + t.Chunks = append(t.Chunks, chunk) + return nil +} +func (t *TestReporter) ChunkErr(_ context.Context, err error) error { + t.ChunkErrs = append(t.ChunkErrs, err) + return nil +} + +// ErrReporter implements UnitReporter and ChunkReporter but always returns an +// error. 
+type ErrReporter struct{}
+
+// UnitOk discards the unit it is given and returns a fixed, non-nil error so
+// that callers' handling of a failing reporter can be exercised.
+func (ErrReporter) UnitOk(context.Context, sources.SourceUnit) error { return fmt.Errorf("ErrReporter: UnitOk error") }
+
+// UnitErr discards the error it is given and returns a fixed, non-nil error.
+func (ErrReporter) UnitErr(context.Context, error) error { return fmt.Errorf("ErrReporter: UnitErr error") }
+
+// ChunkOk discards the chunk it is given and returns a fixed, non-nil error.
+func (ErrReporter) ChunkOk(context.Context, sources.Chunk) error { return fmt.Errorf("ErrReporter: ChunkOk error") }
+
+// ChunkErr discards the error it is given and returns a fixed, non-nil error.
+func (ErrReporter) ChunkErr(context.Context, error) error { return fmt.Errorf("ErrReporter: ChunkErr error") }