Initial implementation of JobReport with SourceManager usage (#1557)
* Initial implementation of JobReport with SourceManager usage

* Limit concurrent units

* Only save the last JobReport per handle
mcastorina authored Jul 27, 2023
1 parent 3897454 commit e391e89
Showing 4 changed files with 339 additions and 19 deletions.
37 changes: 37 additions & 0 deletions pkg/sources/job_report.go
@@ -0,0 +1,37 @@
package sources

import (
"errors"
"sync"
"time"
)

// JobReport aggregates information about a run of a Source.
type JobReport struct {
SourceID int64
JobID int64
StartTime time.Time
EndTime time.Time
TotalChunks uint64
errors []error
errorsLock sync.Mutex
}

// AddError adds a non-nil error to the aggregate of errors encountered during
// scanning.
func (jr *JobReport) AddError(err error) {
if err == nil {
return
}
jr.errorsLock.Lock()
defer jr.errorsLock.Unlock()
jr.errors = append(jr.errors, err)
}

// Errors joins all aggregated errors into one. If there were no errors, nil is
// returned. errors.Is can be used to check for specific errors.
func (jr *JobReport) Errors() error {
jr.errorsLock.Lock()
defer jr.errorsLock.Unlock()
return errors.Join(jr.errors...)
}
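
For orientation, here is a small, self-contained sketch of how a caller might exercise JobReport's concurrency-safe error aggregation. The sentinel error and the program itself are illustrative only and are not part of this commit:

package main

import (
	"errors"
	"fmt"
	"sync"

	"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)

// errRateLimit is a hypothetical sentinel error used only for this example.
var errRateLimit = errors.New("rate limited")

func main() {
	report := &sources.JobReport{SourceID: 1, JobID: 1}

	// AddError is safe to call from multiple goroutines; nil errors are ignored.
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			if i == 0 {
				report.AddError(errRateLimit)
			} else {
				report.AddError(nil) // no-op
			}
		}(i)
	}
	wg.Wait()

	// Errors joins everything that was recorded; errors.Is still matches
	// the individual sentinel values because of errors.Join.
	if err := report.Errors(); errors.Is(err, errRateLimit) {
		fmt.Println("the job hit a rate limit:", err)
	}
}

The same Join semantics are what later let SourceManager.Wait surface every source's failures as a single error.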
239 changes: 220 additions & 19 deletions pkg/sources/source_manager.go
@@ -1,9 +1,11 @@
package sources

import (
"errors"
"fmt"
"sync"
"sync/atomic"
"time"

"github.com/trufflesecurity/trufflehog/v3/pkg/common"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
@@ -25,12 +27,20 @@ type SourceManager struct {
// Map of handle to source initializer.
handles map[handle]SourceInitFunc
handlesLock sync.Mutex
// Map of handle to job reports.
// TODO: Manage culling and flushing to the API.
report map[handle]*JobReport
reportLock sync.Mutex
// Pool limiting the amount of concurrent sources running.
- pool errgroup.Group
pool errgroup.Group
concurrentUnits int
// Run the sources using source unit enumeration / chunking if available.
useSourceUnits bool
// Downstream chunks channel to be scanned.
outputChunks chan *Chunk
- // Set to true when Wait() returns.
- done bool
// Set when Wait() returns.
done bool
doneErr error
}

// apiClient is an interface for optionally communicating with an external API.
@@ -56,12 +66,25 @@ func WithBufferedOutput(size int) func(*SourceManager) {
return func(mgr *SourceManager) { mgr.outputChunks = make(chan *Chunk, size) }
}

// WithSourceUnits enables using source unit enumeration and chunking if the
// source supports it.
func WithSourceUnits() func(*SourceManager) {
return func(mgr *SourceManager) { mgr.useSourceUnits = true }
}

// WithConcurrentUnits limits the number of units to be scanned concurrently.
// The default is unlimited.
func WithConcurrentUnits(n int) func(*SourceManager) {
return func(mgr *SourceManager) { mgr.concurrentUnits = n }
}

// NewManager creates a new manager with the provided options.
func NewManager(opts ...func(*SourceManager)) *SourceManager {
mgr := SourceManager{
// Default to the headless API. Can be overwritten by the WithAPI option.
api: &headlessAPI{},
handles: make(map[handle]SourceInitFunc),
report: make(map[handle]*JobReport),
outputChunks: make(chan *Chunk),
}
for _, opt := range opts {
@@ -100,24 +123,39 @@ func (s *SourceManager) Run(ctx context.Context, handle handle) error {
ch := make(chan error)
s.pool.Go(func() error {
defer common.Recover(ctx)
- // TODO: The manager should record these errors.
- ch <- s.run(ctx, handle)
report, err := s.run(ctx, handle)
if report != nil {
s.reportLock.Lock()
s.report[handle] = report
s.reportLock.Unlock()
}
if err != nil {
ch <- err
return nil
}
ch <- report.Errors()
return nil
})
return <-ch
}

// ScheduleRun blocks until a resource is available to run the source, then
- // asynchronously runs it. Error information is lost in this case.
// asynchronously runs it. Error information is stored and returned by Wait().
func (s *SourceManager) ScheduleRun(ctx context.Context, handle handle) error {
// Do preflight checks before waiting on the pool.
if err := s.preflightChecks(ctx, handle); err != nil {
return err
}
s.pool.Go(func() error {
defer common.Recover(ctx)
- // TODO: The manager should record these errors.
- _ = s.run(ctx, handle)
// The error is already saved in the report, so we can ignore
// it here.
report, _ := s.run(ctx, handle)
if report != nil {
s.reportLock.Lock()
s.report[handle] = report
s.reportLock.Unlock()
}
return nil
})
// TODO: Maybe wait for a signal here that initialization was successful?
@@ -136,12 +174,34 @@ func (s *SourceManager) Chunks() <-chan *Chunk {
func (s *SourceManager) Wait() error {
// Check if the manager has been Waited.
if s.done {
- return nil
return s.doneErr
}
- // TODO: Aggregate all errors from all sources.
defer close(s.outputChunks)
defer func() { s.done = true }()
- return s.pool.Wait()

// We are only using the errgroup for limiting concurrency.
// TODO: Maybe switch to using a semaphore.Weighted.
_ = s.pool.Wait()

// Aggregate all errors from all job reports.
// TODO: This should probably only be the fatal errors. We'll also need
// to rewrite this for when the reports start getting culled.
s.reportLock.Lock()
defer s.reportLock.Unlock()
errs := make([]error, 0, len(s.report))
for _, report := range s.report {
errs = append(errs, report.Errors())
}
s.doneErr = errors.Join(errs...)
return s.doneErr
}

// Report retrieves a scan report for a given handle. If no report exists or
// the Source has not finished, nil will be returned.
func (s *SourceManager) Report(handle handle) *JobReport {
s.reportLock.Lock()
defer s.reportLock.Unlock()
return s.report[handle]
}

// preflightChecks is a helper method to check the Manager or the context isn't
@@ -159,23 +219,127 @@ func (s *SourceManager) preflightChecks(ctx context.Context, handle handle) erro
}

// run is a helper method to synchronously run the source. It does not check for
- // acquired resources.
- func (s *SourceManager) run(ctx context.Context, handle handle) error {
// acquired resources. Possible return values are:
//
// - *JobReport, nil
// Successfully ran the source, but the report could have errors.
//
// - *JobReport, error
// There was an error calling Init or Chunks. This sort of error indicates
// a fatal error and is also recorded in the report.
//
// - nil, error:
// There was an error from the API or the handle is invalid. The latter of
// which should never happen due to the preflightChecks.
func (s *SourceManager) run(ctx context.Context, handle handle) (*JobReport, error) {
jobID, err := s.api.GetJobID(ctx, int64(handle))
if err != nil {
- return err
return nil, err
}
initFunc, ok := s.getInitFunc(handle)
if !ok {
- return fmt.Errorf("unrecognized handle")
return nil, fmt.Errorf("unrecognized handle")
}
// Create a report for this run.
report := &JobReport{
SourceID: int64(handle),
JobID: jobID,
StartTime: time.Now(),
}
defer func() { report.EndTime = time.Now() }()

// Initialize the source.
source, err := initFunc(ctx, jobID, int64(handle))
if err != nil {
- return err
report.AddError(err)
return report, err
}
// Check for the preferred method of tracking source units.
if enumChunker, ok := source.(SourceUnitEnumChunker); ok && s.useSourceUnits {
return s.runWithUnits(ctx, handle, enumChunker, report)
}
return s.runWithoutUnits(ctx, handle, source, report)
}

// runWithoutUnits is a helper method to run a Source. It has coarse-grained
// job reporting.
func (s *SourceManager) runWithoutUnits(ctx context.Context, handle handle, source Source, report *JobReport) (*JobReport, error) {
// Introspect on the chunks we get from the Chunks method.
ch := make(chan *Chunk)
var wg sync.WaitGroup
// Consume chunks and export chunks.
wg.Add(1)
go func() {
defer wg.Done()
for chunk := range ch {
atomic.AddUint64(&report.TotalChunks, 1)
_ = common.CancellableWrite(ctx, s.outputChunks, chunk)
}
}()
// Don't return from this function until the goroutine has finished
// outputting chunks to the downstream channel. Closing the channel
// will stop the goroutine, so that needs to happen first in the defer
// stack.
defer wg.Wait()
defer close(ch)
if err := source.Chunks(ctx, ch); err != nil {
report.AddError(err)
return report, err
}
- // TODO: Support UnitChunker and SourceUnitEnumerator.
- // TODO: This is where we can introspect on the chunks collected.
- return source.Chunks(ctx, s.outputChunks)
return report, nil
}

// runWithUnits is a helper method to run a Source that is also a
// SourceUnitEnumChunker. This allows better introspection of what is getting
// scanned and any errors encountered.
func (s *SourceManager) runWithUnits(ctx context.Context, handle handle, source SourceUnitEnumChunker, report *JobReport) (*JobReport, error) {
reporter := &mgrUnitReporter{
unitCh: make(chan SourceUnit),
}
// Produce units.
go func() {
// TODO: Catch panics and add to report.
defer close(reporter.unitCh)
if err := source.Enumerate(ctx, reporter); err != nil {
report.AddError(err)
}
}()
var wg sync.WaitGroup
// TODO: Maybe switch to using a semaphore.Weighted.
var unitPool errgroup.Group
if s.concurrentUnits != 0 {
// Negative values indicate no limit.
unitPool.SetLimit(s.concurrentUnits)
}
for unit := range reporter.unitCh {
reporter := &mgrChunkReporter{
unitID: unit.SourceUnitID(),
chunkCh: make(chan *Chunk),
}
unit := unit
// Consume units and produce chunks.
unitPool.Go(func() error {
// TODO: Catch panics and add to report.
defer close(reporter.chunkCh)
if err := source.ChunkUnit(ctx, unit, reporter); err != nil {
report.AddError(err)
}
return nil
})
// Consume chunks and export chunks.
wg.Add(1)
go func() {
defer wg.Done()
for chunk := range reporter.chunkCh {
// TODO: Introspect on the chunks we got from this unit.
atomic.AddUint64(&report.TotalChunks, 1)
_ = common.CancellableWrite(ctx, s.outputChunks, chunk)
}
}()
}
wg.Wait()
// TODO: Return fatal errors.
return report, nil
}
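
For context, runWithUnits above combines three pieces: a producer goroutine that enumerates units, an errgroup with SetLimit bounding how many ChunkUnit calls run at once, and per-unit consumers that forward chunks downstream. A stripped-down, self-contained sketch of that bounded producer/consumer pattern, with placeholder work standing in for real sources, might look roughly like this:

package main

import (
	"fmt"
	"sync"

	"golang.org/x/sync/errgroup"
)

func main() {
	units := make(chan int)
	// Producer: enumerate units, then close the channel so the loop below ends.
	go func() {
		defer close(units)
		for i := 0; i < 10; i++ {
			units <- i
		}
	}()

	results := make(chan string)
	// Consumer: drain results until the channel is closed.
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		for r := range results {
			fmt.Println(r)
		}
	}()

	// Worker pool: at most 3 units are processed concurrently.
	var pool errgroup.Group
	pool.SetLimit(3)
	for unit := range units {
		unit := unit // capture the loop variable (pre-Go 1.22 idiom, as in the diff)
		pool.Go(func() error {
			results <- fmt.Sprintf("chunked unit %d", unit)
			return nil
		})
	}
	_ = pool.Wait() // all workers done
	close(results)  // stop the consumer
	wg.Wait()       // wait for the consumer to finish printing
}

The sketch uses one shared results channel for brevity; the commit instead gives each unit its own chunk channel and consumer so errors and chunk counts can be attributed per unit.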

// getInitFunc is a helper method for safe concurrent access to the
@@ -201,3 +365,40 @@ func (api *headlessAPI) RegisterSource(ctx context.Context, name string, kind so
func (api *headlessAPI) GetJobID(ctx context.Context, id int64) (int64, error) {
return atomic.AddInt64(&api.jobIDCounter, 1), nil
}

// mgrUnitReporter implements the UnitReporter interface.
type mgrUnitReporter struct {
unitCh chan SourceUnit
unitErrs []error
unitErrsLock sync.Mutex
}

func (s *mgrUnitReporter) UnitOk(ctx context.Context, unit SourceUnit) error {
return common.CancellableWrite(ctx, s.unitCh, unit)
}

func (s *mgrUnitReporter) UnitErr(ctx context.Context, err error) error {
s.unitErrsLock.Lock()
defer s.unitErrsLock.Unlock()
s.unitErrs = append(s.unitErrs, err)
return nil
}

// mgrChunkReporter implements the ChunkReporter interface.
type mgrChunkReporter struct {
unitID string
chunkCh chan *Chunk
chunkErrs []error
chunkErrsLock sync.Mutex
}

func (s *mgrChunkReporter) ChunkOk(ctx context.Context, chunk Chunk) error {
return common.CancellableWrite(ctx, s.chunkCh, &chunk)
}

func (s *mgrChunkReporter) ChunkErr(ctx context.Context, err error) error {
s.chunkErrsLock.Lock()
defer s.chunkErrsLock.Unlock()
s.chunkErrs = append(s.chunkErrs, err)
return nil
}
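
Finally, a rough end-to-end sketch of how the new options and report accessors fit together. It is written as if it lived inside package sources, since handle registration is not part of this diff; the handle here is a zero-value stand-in for whatever the manager's registration API returns, and the function is illustrative rather than authoritative:

package sources

import "github.com/trufflesecurity/trufflehog/v3/pkg/context"

// exampleRun is an illustrative sketch only; it is not part of the commit.
func exampleRun(ctx context.Context) error {
	mgr := NewManager(
		WithBufferedOutput(64), // buffer chunks headed downstream
		WithSourceUnits(),      // prefer unit enumeration / chunking when supported
		WithConcurrentUnits(8), // scan at most 8 units at a time
	)

	// A handle normally comes from registering a source with the manager;
	// that API is outside this diff, so the zero value stands in here.
	var h handle

	// Consume chunks until Wait() closes the channel.
	go func() {
		for chunk := range mgr.Chunks() {
			_ = chunk // hand off to detection, etc.
		}
	}()

	// Run blocks until the source finishes and returns the report's joined errors.
	if err := mgr.Run(ctx, h); err != nil {
		return err
	}
	// Wait closes the chunk channel and joins errors across all stored reports.
	if err := mgr.Wait(); err != nil {
		return err
	}
	// Report exposes per-handle timing, chunk counts, and errors.
	if report := mgr.Report(h); report != nil {
		_ = report.TotalChunks
	}
	return nil
}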