Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SourceManager tests for Run and Wait methods #1530

Merged
merged 3 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 60 additions & 21 deletions pkg/sources/source_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type handle int64

// SourceInitFunc is a function that takes a source and job ID and returns an
// initialized Source.
type SourceInitFunc func(sourceID int64, jobID int64) (Source, error)
type SourceInitFunc func(ctx context.Context, sourceID int64, jobID int64) (Source, error)

type SourceManager struct {
api apiClient
Expand All @@ -29,6 +29,8 @@ type SourceManager struct {
pool errgroup.Group
// Downstream chunks channel to be scanned.
outputChunks chan *Chunk
// Set to true when Wait() returns.
done bool
}

// apiClient is an interface for optionally communicating with an external API.
Expand All @@ -41,30 +43,38 @@ type apiClient interface {

// WithAPI adds an API client to the manager for tracking jobs and progress.
func WithAPI(api apiClient) func(*SourceManager) {
return func(man *SourceManager) { man.api = api }
return func(mgr *SourceManager) { mgr.api = api }
}

// WithConcurrency limits the concurrent number of sources a manager can run.
func WithConcurrency(concurrency int) func(*SourceManager) {
return func(man *SourceManager) { man.pool.SetLimit(concurrency) }
return func(mgr *SourceManager) { mgr.pool.SetLimit(concurrency) }
}

// WithBufferedOutput returns an option that replaces the manager's output
// channel with a buffered channel able to hold up to size chunks.
func WithBufferedOutput(size int) func(*SourceManager) {
	return func(mgr *SourceManager) {
		buffered := make(chan *Chunk, size)
		mgr.outputChunks = buffered
	}
}

// NewManager creates a new manager with the provided options.
func NewManager(outputChunks chan *Chunk, opts ...func(*SourceManager)) *SourceManager {
man := SourceManager{
func NewManager(opts ...func(*SourceManager)) *SourceManager {
mgr := SourceManager{
// Default to the headless API. Can be overwritten by the WithAPI option.
api: &headlessAPI{},
handles: make(map[handle]SourceInitFunc),
outputChunks: outputChunks,
outputChunks: make(chan *Chunk),
}
for _, opt := range opts {
opt(&man)
opt(&mgr)
}
return &man
return &mgr
}

// Enroll informs the SourceManager to track and manage a Source.
func (s *SourceManager) Enroll(ctx context.Context, name string, kind sourcespb.SourceType, f SourceInitFunc) (handle, error) {
if s.done {
return 0, fmt.Errorf("manager is done")
}
id, err := s.api.RegisterSource(ctx, name, kind)
if err != nil {
return 0, err
Expand All @@ -77,16 +87,14 @@ func (s *SourceManager) Enroll(ctx context.Context, name string, kind sourcespb.
return 0, fmt.Errorf("handle ID '%d' already in use", handleID)
}
s.handles[handleID] = f
return 0, nil
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests are great :)

return handleID, nil
}

// Run blocks until a resource is available to run the source, then synchronously runs it.
// Run blocks until a resource is available to run the source, then
// synchronously runs it.
func (s *SourceManager) Run(ctx context.Context, handle handle) error {
// Check the handle is valid before waiting on the pool.
if _, ok := s.getInitFunc(handle); !ok {
return fmt.Errorf("unrecognized handle")
}
if err := ctx.Err(); err != nil {
// Do preflight checks before waiting on the pool.
if err := s.preflightChecks(ctx, handle); err != nil {
return err
}
ch := make(chan error)
Expand All @@ -102,11 +110,8 @@ func (s *SourceManager) Run(ctx context.Context, handle handle) error {
// ScheduleRun blocks until a resource is available to run the source, then
// asynchronously runs it. Error information is lost in this case.
func (s *SourceManager) ScheduleRun(ctx context.Context, handle handle) error {
// Check the handle is valid before waiting on the pool.
if _, ok := s.getInitFunc(handle); !ok {
return fmt.Errorf("unrecognized handle")
}
if err := ctx.Err(); err != nil {
// Do preflight checks before waiting on the pool.
if err := s.preflightChecks(ctx, handle); err != nil {
return err
}
s.pool.Go(func() error {
Expand All @@ -119,6 +124,40 @@ func (s *SourceManager) ScheduleRun(ctx context.Context, handle handle) error {
return nil
}

// Chunks exposes the manager's output channel as receive-only. Every chunk
// produced by the sources this manager runs is delivered on it.
func (s *SourceManager) Chunks() <-chan *Chunk {
	var readOnly <-chan *Chunk = s.outputChunks
	return readOnly
}

// Wait blocks until the pool of running sources has drained and returns the
// pool's error, if any. Once the pool has finished, the manager is marked as
// done and the channel returned by Chunks() is closed, so the manager should
// not be reused afterwards. Calling Wait again on a done manager is a no-op
// that returns nil.
func (s *SourceManager) Wait() error {
	// A previous Wait already finished; the channel must not be closed twice.
	if s.done {
		return nil
	}
	// TODO: Aggregate all errors from all sources.
	defer func() {
		// Flag the manager as done first, then close the output channel.
		s.done = true
		close(s.outputChunks)
	}()
	return s.pool.Wait()
}

// preflightChecks validates that work may be started for handle. It fails if
// the manager has already been Waited, if the handle is not known to the
// manager, or if ctx already carries an error; otherwise it returns nil.
func (s *SourceManager) preflightChecks(ctx context.Context, handle handle) error {
	// A manager that has been Waited accepts no further work.
	if s.done {
		return fmt.Errorf("manager is done")
	}
	// Only handles with a registered init function are runnable.
	_, known := s.getInitFunc(handle)
	if !known {
		return fmt.Errorf("unrecognized handle")
	}
	// Finally, surface the context's error (nil when it is still live).
	return ctx.Err()
}

// run is a helper method to synchronously run the source. It does not check for
// acquired resources.
func (s *SourceManager) run(ctx context.Context, handle handle) error {
Expand All @@ -130,7 +169,7 @@ func (s *SourceManager) run(ctx context.Context, handle handle) error {
if !ok {
return fmt.Errorf("unrecognized handle")
}
source, err := initFunc(jobID, int64(handle))
source, err := initFunc(ctx, jobID, int64(handle))
if err != nil {
return err
}
Expand Down
119 changes: 119 additions & 0 deletions pkg/sources/source_manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package sources

import (
"fmt"
"testing"

"github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"google.golang.org/protobuf/types/known/anypb"
)

// DummySource is a minimal source used as a Source when testing a
// SourceManager. It only records the IDs handed to Init; chunk production is
// delegated to the embedded chunker.
type DummySource struct {
	sourceID int64
	jobID    int64
	chunker
}

// Type reports a fixed, made-up source type.
func (d *DummySource) Type() sourcespb.SourceType {
	return 1337
}

// SourceID returns the source ID stored by Init.
func (d *DummySource) SourceID() int64 {
	return d.sourceID
}

// JobID returns the job ID stored by Init.
func (d *DummySource) JobID() int64 {
	return d.jobID
}

// Init stores the job and source IDs and ignores every other argument. It
// never fails.
func (d *DummySource) Init(_ context.Context, _ string, jobID, sourceID int64, _ bool, _ *anypb.Any, _ int) error {
	d.jobID = jobID
	d.sourceID = sourceID
	return nil
}

// GetProgress always returns nil; the dummy source tracks no progress.
func (d *DummySource) GetProgress() *Progress {
	return nil
}

// chunker abstracts the Chunks method so tests can plug different chunking
// behaviors into a DummySource.
type chunker interface {
	Chunks(ctx context.Context, out chan *Chunk) error
}

// counterChunker is a chunker that emits count single-byte chunks per call.
// The byte value comes from chunkCounter, which keeps incrementing across
// calls.
type counterChunker struct {
	chunkCounter byte
	count        int
}

// Chunks sends count chunks on ch, each carrying the current counter value as
// its only byte, and increments the counter after every send. The context is
// ignored and the returned error is always nil.
func (c *counterChunker) Chunks(_ context.Context, ch chan *Chunk) error {
	for sent := 0; sent < c.count; sent++ {
		next := &Chunk{Data: []byte{c.chunkCounter}}
		ch <- next
		c.chunkCounter++
	}
	return nil
}

// enrollDummy enrolls a DummySource backed by chunkMethod with mgr under the
// name "dummy" and returns whatever Enroll returns.
func enrollDummy(mgr *SourceManager, chunkMethod chunker) (handle, error) {
	// initDummy builds and initializes a fresh DummySource on every call.
	initDummy := func(ctx context.Context, jobID, sourceID int64) (Source, error) {
		dummy := &DummySource{chunker: chunkMethod}
		err := dummy.Init(ctx, "dummy", jobID, sourceID, true, nil, 42)
		if err != nil {
			return nil, err
		}
		return dummy, nil
	}
	return mgr.Enroll(context.Background(), "dummy", 1337, initDummy)
}

// tryRead performs a non-blocking receive on ch. It returns the received
// chunk when one is immediately available and an error otherwise.
func tryRead(ch <-chan *Chunk) (*Chunk, error) {
	select {
	case got := <-ch:
		return got, nil
	default:
		// Nothing ready; fall through to the error below.
	}
	return nil, fmt.Errorf("no chunk available")
}

// TestSourceManagerRun checks that every synchronous Run of an enrolled
// source leaves exactly one chunk on the manager's output channel.
func TestSourceManagerRun(t *testing.T) {
	mgr := NewManager(WithBufferedOutput(8))
	handle, err := enrollDummy(mgr, &counterChunker{count: 1})
	if err != nil {
		t.Fatalf("unexpected error enrolling source: %v", err)
	}
	for i := 0; i < 3; i++ {
		if err := mgr.Run(context.Background(), handle); err != nil {
			t.Fatalf("unexpected error running source: %v", err)
		}
		chunk, err := tryRead(mgr.Chunks())
		if err != nil {
			t.Fatalf("reading chunk failed: %v", err)
		}
		// The same counterChunker is shared by every run, so the i-th run
		// emits a chunk whose single byte equals i.
		if got, want := chunk.Data[0], byte(i); got != want {
			t.Fatalf("unexpected chunk value, wanted %v, got: %v", want, got)
		}

		// The Chunks channel should be empty now.
		if chunk, err := tryRead(mgr.Chunks()); err == nil {
			t.Fatalf("unexpected chunk found: %+v", chunk)
		}
	}
}

// TestSourceManagerWait checks that Wait succeeds once a scheduled source has
// finished, and that the manager then rejects further enrollments and runs.
func TestSourceManagerWait(t *testing.T) {
	mgr := NewManager()
	handle, err := enrollDummy(mgr, &counterChunker{count: 1})
	if err != nil {
		t.Fatalf("unexpected error enrolling source: %v", err)
	}

	// Kick off the source in the background.
	err = mgr.ScheduleRun(context.Background(), handle)
	if err != nil {
		t.Fatalf("unexpected error scheduling run: %v", err)
	}

	// The output channel is unbuffered by default, so drain the single
	// expected chunk; otherwise the source would never finish.
	<-mgr.Chunks()

	// Block until every running source has completed.
	err = mgr.Wait()
	if err != nil {
		t.Fatalf("unexpected error waiting: %v", err)
	}

	// A manager that has been Waited must refuse new work.
	_, enrollErr := enrollDummy(mgr, &counterChunker{count: 1})
	if enrollErr == nil {
		t.Fatalf("expected enroll to fail")
	}
	scheduleErr := mgr.ScheduleRun(context.Background(), handle)
	if scheduleErr == nil {
		t.Fatalf("expected scheduling run to fail")
	}
}
Loading