Refactor UnitHook to block the scan if finished metrics aren't handled #2309

Merged: 5 commits, Feb 8, 2024
3 changes: 2 additions & 1 deletion pkg/sources/job_progress.go
@@ -188,7 +188,8 @@ func (jp *JobProgress) executeHooks(todo func(hook JobProgressHook)) {
hooksExecTime.WithLabelValues().Observe(float64(elapsed))
}(time.Now())
for _, hook := range jp.hooks {
// TODO: Non-blocking?
// Execute hooks synchronously so they can provide
// back-pressure to the source.
todo(hook)
}
}
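The new comment above captures the core design decision of this PR: hooks are executed synchronously so that a slow metrics consumer naturally slows the source instead of metrics being dropped. A minimal, self-contained sketch of that back-pressure mechanism, using toy types rather than the actual trufflehog ones:

	package main

	import (
		"fmt"
		"time"
	)

	// hook forwards finished work onto a bounded channel; when the channel
	// is full, the send blocks, and so does the caller.
	type hook struct{ finished chan int }

	func (h *hook) onFinish(id int) { h.finished <- id }

	func main() {
		h := &hook{finished: make(chan int, 2)} // small buffer to force back-pressure

		// Producer: calls the hook synchronously, like executeHooks does.
		go func() {
			for i := 0; i < 5; i++ {
				h.onFinish(i) // blocks once the buffer is full
				fmt.Println("produced", i)
			}
			close(h.finished)
		}()

		// Slow consumer: drains the channel slower than the producer fills it,
		// so the producer is throttled instead of work being lost.
		for id := range h.finished {
			time.Sleep(100 * time.Millisecond)
			fmt.Println("handled", id)
		}
	}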
152 changes: 87 additions & 65 deletions pkg/sources/job_progress_hook.go
@@ -3,44 +3,61 @@ package sources
import (
"errors"
"fmt"
"strings"
"runtime"
"sync"
"time"

lru "github.com/hashicorp/golang-lru/v2"
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
)

// UnitHook implements JobProgressHook for tracking the progress of each
// individual unit.
type UnitHook struct {
metrics *lru.Cache[string, *UnitMetrics]
mu sync.Mutex
metrics map[string]*UnitMetrics
mu sync.Mutex
finishedMetrics chan UnitMetrics
logBackPressure func()
NoopHook
}

type UnitHookOpt func(*UnitHook)

func WithUnitHookCache(cache *lru.Cache[string, *UnitMetrics]) UnitHookOpt {
return func(hook *UnitHook) { hook.metrics = cache }
// WithUnitHookFinishBufferSize sets the buffer size for handling finished
// metrics (default is 1024). If the buffer fills, then scanning will stop
// until there is room.
func WithUnitHookFinishBufferSize(buf int) UnitHookOpt {
return func(hook *UnitHook) {
hook.finishedMetrics = make(chan UnitMetrics, buf)
}
}

func NewUnitHook(ctx context.Context, opts ...UnitHookOpt) *UnitHook {
// lru.NewWithEvict can only fail if the size is < 0.
cache, _ := lru.NewWithEvict(1024, func(key string, value *UnitMetrics) {
if value.handled {
return
}
ctx.Logger().Error(fmt.Errorf("eviction"), "dropping unit metric",
"id", key,
"metric", value,
)
})
hook := UnitHook{metrics: cache}
func NewUnitHook(ctx context.Context, opts ...UnitHookOpt) (*UnitHook, <-chan UnitMetrics) {
var once sync.Once
hook := UnitHook{
metrics: make(map[string]*UnitMetrics, runtime.NumCPU()),
finishedMetrics: make(chan UnitMetrics, 1024),
logBackPressure: func() {
once.Do(func() {
ctx.Logger().Info("back pressure detected in unit hook")
})
},
}
for _, opt := range opts {
opt(&hook)
}
return &hook
go func() {
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
hooksChannelSize.WithLabelValues().Set(float64(len(hook.finishedMetrics)))
case <-ctx.Done():
return
}
}
}()
return &hook, hook.finishedMetrics
}

// id is a helper method to generate an ID for the given job and unit.
@@ -52,36 +69,59 @@ func (u *UnitHook) id(ref JobProgressRef, unit SourceUnit) string {
return fmt.Sprintf("%d/%d/%s", ref.SourceID, ref.JobID, unitID)
}

func (u *UnitHook) ejectFinishedMetrics(metrics UnitMetrics) {
// Intentionally block the hook from returning to supply back-pressure
// to the source.
select {
case u.finishedMetrics <- metrics:
return
default:
u.logBackPressure()
}
u.finishedMetrics <- metrics
}

func (u *UnitHook) StartUnitChunking(ref JobProgressRef, unit SourceUnit, start time.Time) {
id := u.id(ref, unit)
u.mu.Lock()
defer u.mu.Unlock()

u.metrics.Add(id, &UnitMetrics{
u.metrics[id] = &UnitMetrics{
Unit: unit,
Parent: ref,
StartTime: &start,
})
}
}

func (u *UnitHook) EndUnitChunking(ref JobProgressRef, unit SourceUnit, end time.Time) {
id := u.id(ref, unit)
u.mu.Lock()
defer u.mu.Unlock()

metrics, ok := u.metrics.Get(id)
metrics, ok := u.finishUnit(id)
if !ok {
return
}
metrics.EndTime = &end
u.ejectFinishedMetrics(*metrics)
}

func (u *UnitHook) finishUnit(id string) (*UnitMetrics, bool) {
u.mu.Lock()
defer u.mu.Unlock()

metrics, ok := u.metrics[id]
if !ok {
return nil, false
}
delete(u.metrics, id)
return metrics, true
}

func (u *UnitHook) ReportChunk(ref JobProgressRef, unit SourceUnit, chunk *Chunk) {
id := u.id(ref, unit)
u.mu.Lock()
defer u.mu.Unlock()

metrics, ok := u.metrics.Get(id)
metrics, ok := u.metrics[id]
if !ok && unit != nil {
// The unit has been evicted.
return
@@ -92,7 +132,7 @@ func (u *UnitHook) ReportChunk(ref JobProgressRef, unit SourceUnit, chunk *Chunk
Parent: ref,
StartTime: ref.Snapshot().StartTime,
}
u.metrics.Add(id, metrics)
u.metrics[id] = metrics
}
metrics.TotalChunks++
metrics.TotalBytes += uint64(len(chunk.Data))
@@ -103,7 +143,7 @@ func (u *UnitHook) ReportError(ref JobProgressRef, err error) {
defer u.mu.Unlock()

// Always add the error to the nil unit if it exists.
if metrics, ok := u.metrics.Get(u.id(ref, nil)); ok {
if metrics, ok := u.metrics[u.id(ref, nil)]; ok {
metrics.Errors = append(metrics.Errors, err)
}

@@ -114,59 +154,44 @@ func (u *UnitHook) ReportError(ref JobProgressRef, err error) {
}
id := u.id(ref, chunkErr.Unit)

metrics, ok := u.metrics.Get(id)
metrics, ok := u.metrics[id]
if !ok {
return
}
metrics.Errors = append(metrics.Errors, err)
}

func (u *UnitHook) Finish(ref JobProgressRef) {
u.mu.Lock()
defer u.mu.Unlock()
// Clear out any metrics on this job. This covers the case for the
// source running without unit support.
prefix := u.id(ref, nil)
for _, id := range u.metrics.Keys() {
if !strings.HasPrefix(id, prefix) {
continue
}
metric, ok := u.metrics.Get(id)
if !ok {
continue
}
// If the unit is nil, the source does not support units.
// Use the overall job metrics instead.
if metric.Unit == nil {
snap := ref.Snapshot()
metric.StartTime = snap.StartTime
metric.EndTime = snap.EndTime
metric.Errors = snap.Errors
}
id := u.id(ref, nil)
metrics, ok := u.finishUnit(id)
if !ok {
return
}
snap := ref.Snapshot()
metrics.StartTime = snap.StartTime
metrics.EndTime = snap.EndTime
metrics.Errors = snap.Errors
u.ejectFinishedMetrics(*metrics)
}

// UnitMetrics gets all the currently active or newly finished metrics for this
// job. If a unit returned from this method has finished, it will be removed
// from the cache and no longer returned in successive calls to UnitMetrics().
func (u *UnitHook) UnitMetrics() []UnitMetrics {
// InProgressSnapshot gets all the currently active metrics across all jobs.
func (u *UnitHook) InProgressSnapshot() []UnitMetrics {
u.mu.Lock()
defer u.mu.Unlock()
output := make([]UnitMetrics, 0, u.metrics.Len())
for _, id := range u.metrics.Keys() {
metric, ok := u.metrics.Get(id)
if !ok {
continue
}
output = append(output, *metric)
if metric.IsFinished() {
metric.handled = true
u.metrics.Remove(id)
Comment on lines -162 to -164

Reviewer: Was this being done here before because there wasn't really a better place for it? But now there is, so you moved it? Or am I misunderstanding?

mcastorina (author), Jan 20, 2024: This was the previous mechanism to drain the cache, which is now in EndUnitChunking and Finish. The idea was that you would periodically call UnitMetrics() and get the list of in-progress and finished metrics, and the finished ones would be considered "handled" and removed from the cache.

This PR moves the handling to be synchronous, so this function is now only getting a snapshot of the in-progress units (hence the name change).

}
output := make([]UnitMetrics, 0, len(u.metrics))
for _, metrics := range u.metrics {
output = append(output, *metrics)
}
return output
}

func (u *UnitHook) Close() error {
Reviewer: Is this used anywhere?

mcastorina (author): Yes, it gets called when the SourceManager is waited:

	for _, hook := range s.hooks {
		if hookCloser, ok := hook.(io.Closer); ok {
			_ = hookCloser.Close()
		}
	}

mcastorina (author): Some synchronization is required to make the goroutine reading the channel finish before everything gets shut down, but that doesn't happen in this PR since UnitHook isn't used yet.

I figured it made sense for the source manager to close out the hooks since the hook initialized it.

Reviewer: Quoting "Some synchronization is required to make the goroutine reading the channel finish before everything gets shut down, but that doesn't happen in this PR since UnitHook isn't used yet." You read my mind :)

close(u.finishedMetrics)
return nil
}

type UnitMetrics struct {
Unit SourceUnit `json:"unit,omitempty"`
Parent JobProgressRef `json:"parent,omitempty"`
@@ -179,9 +204,6 @@ type UnitMetrics struct {
TotalBytes uint64 `json:"total_bytes"`
// All errors encountered by this unit.
Errors []error `json:"errors"`
// Flag to mark that these metrics were intentionally evicted from
// the cache.
handled bool
}

func (u UnitMetrics) IsFinished() bool {
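Taken together, NewUnitHook now returns the hook plus a receive-only channel of finished metrics, and the channel is closed when the SourceManager closes the hook during Wait. Below is a rough sketch of how a caller might wire this up, based on the tests later in this diff. The option for registering the hook with the manager is not shown in this diff, so WithReportHook is an assumption, as are the package layout and field formatting; treat this as an illustration of the intended flow rather than the project's actual integration code.

	package scan

	import (
		"log"
		"sync"

		"github.com/trufflesecurity/trufflehog/v3/pkg/context"
		"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
	)

	func runScan(ctx context.Context, name string, source sources.Source) error {
		hook, finished := sources.NewUnitHook(ctx)

		var wg sync.WaitGroup
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Handle finished metrics as they arrive. If this loop falls
			// behind, the buffer (default 1024) fills and the scan blocks
			// until there is room again.
			for m := range finished {
				log.Printf("unit finished: %d chunks, %d bytes, %d errors",
					m.TotalChunks, m.TotalBytes, len(m.Errors))
			}
		}()

		mgr := sources.NewManager(sources.WithReportHook(hook)) // assumed option name
		ref, err := mgr.Run(ctx, name, source)
		if err != nil {
			return err
		}
		<-ref.Done()

		// Wait closes the hook, which closes the finished channel and lets
		// the consumer goroutine above exit.
		err = mgr.Wait()
		wg.Wait()
		return err
	}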
7 changes: 7 additions & 0 deletions pkg/sources/metrics.go
@@ -14,4 +14,11 @@ var (
Help: "Time spent executing hooks (ms)",
Buckets: []float64{5, 50, 500, 1000},
}, nil)

hooksChannelSize = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: common.MetricsNamespace,
Subsystem: common.MetricsSubsystem,
Name: "hooks_channel_size",
Help: "Total number of metrics waiting in the finished channel.",
}, nil)
)
6 changes: 6 additions & 0 deletions pkg/sources/source_manager.go
@@ -2,6 +2,7 @@ package sources

import (
"fmt"
"io"
"runtime"
"sync"
"sync/atomic"
@@ -182,6 +183,11 @@ func (s *SourceManager) Wait() error {
}
close(s.outputChunks)
close(s.firstErr)
for _, hook := range s.hooks {
if hookCloser, ok := hook.(io.Closer); ok {
_ = hookCloser.Close()
}
}
return s.waitErr
}

49 changes: 30 additions & 19 deletions pkg/sources/source_manager_test.go
@@ -5,8 +5,8 @@ import (
"fmt"
"sort"
"testing"
"time"

lru "github.com/hashicorp/golang-lru/v2"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"

@@ -316,7 +316,7 @@ func TestSourceManagerAvailableCapacity(t *testing.T) {
}

func TestSourceManagerUnitHook(t *testing.T) {
hook := NewUnitHook(context.TODO())
hook, ch := NewUnitHook(context.TODO())
Reviewer: What's the difference between context.TODO() and context.Background()?

mcastorina (author): Functionally they're equivalent, but the docs say this: "Code should use context.TODO when it's unclear which Context to use or it is not yet available (because the surrounding function has not yet been extended to accept a Context parameter)."

input := []unitChunk{
{unit: "1 one", output: "bar"},
@@ -333,9 +333,13 @@ func TestSourceManagerUnitHook(t *testing.T) {
ref, err := mgr.Run(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
assert.NoError(t, mgr.Wait())

metrics := hook.UnitMetrics()
assert.Equal(t, 3, len(metrics))
assert.Equal(t, 0, len(hook.InProgressSnapshot()))
var metrics []UnitMetrics
for metric := range ch {
metrics = append(metrics, metric)
}
sort.Slice(metrics, func(i, j int) bool {
return metrics[i].Unit.SourceUnitID() < metrics[j].Unit.SourceUnitID()
})
@@ -366,14 +370,10 @@ func TestSourceManagerUnitHook(t *testing.T) {
assert.Equal(t, 1, len(m2.Errors))
}

// TestSourceManagerUnitHookNoBlock tests that the UnitHook drops metrics if
// they aren't handled fast enough.
func TestSourceManagerUnitHookNoBlock(t *testing.T) {
var evictedKeys []string
cache, _ := lru.NewWithEvict(1, func(key string, _ *UnitMetrics) {
evictedKeys = append(evictedKeys, key)
})
hook := NewUnitHook(context.TODO(), WithUnitHookCache(cache))
// TestSourceManagerUnitHookBackPressure tests that the UnitHook blocks if the
// finished metrics aren't handled fast enough.
Comment on lines +373 to +374

Reviewer: ❤️

func TestSourceManagerUnitHookBackPressure(t *testing.T) {
hook, ch := NewUnitHook(context.TODO(), WithUnitHookFinishBufferSize(0))

input := []unitChunk{
{unit: "one", output: "bar"},
@@ -389,18 +389,25 @@ func TestSourceManagerUnitHookNoBlock(t *testing.T) {
assert.NoError(t, err)
ref, err := mgr.Run(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()

assert.Equal(t, 2, len(evictedKeys))
metrics := hook.UnitMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "three", metrics[0].Unit.SourceUnitID())
var metrics []UnitMetrics
for i := 0; i < len(input); i++ {
select {
case <-ref.Done():
t.Fatal("job should not finish until metrics have been collected")
case <-time.After(1 * time.Millisecond):
}
metrics = append(metrics, <-ch)
}

assert.NoError(t, mgr.Wait())
assert.Equal(t, 3, len(metrics), metrics)
}

// TestSourceManagerUnitHookNoUnits tests whether the UnitHook works for
// sources that don't support units.
func TestSourceManagerUnitHookNoUnits(t *testing.T) {
hook := NewUnitHook(context.TODO())
hook, ch := NewUnitHook(context.TODO())

mgr := NewManager(
WithBufferedOutput(8),
@@ -412,8 +419,12 @@ func TestSourceManagerUnitHookNoUnits(t *testing.T) {
ref, err := mgr.Run(context.Background(), "dummy", source)
assert.NoError(t, err)
<-ref.Done()
assert.NoError(t, mgr.Wait())

metrics := hook.UnitMetrics()
var metrics []UnitMetrics
for metric := range ch {
metrics = append(metrics, metric)
}
assert.Equal(t, 1, len(metrics))

m := metrics[0]