diff --git a/pkg/util/loser/tree.go b/pkg/util/loser/tree.go index e0a80e5307..2bd65c4c93 100644 --- a/pkg/util/loser/tree.go +++ b/pkg/util/loser/tree.go @@ -38,6 +38,11 @@ func New[E any, S Sequence](sequences []S, maxVal E, at func(S) E, less func(E, t.nodes[i+nSequences].items = s if !t.moveNext(i + nSequences) { // Must call Next on each item so that At() has a value. if t.err != nil { + // error during initialize, requires us to close sequences not touched yet and mark nodes as uninitialized + for j := i + 1; j < nSequences; j++ { + t.close(sequences[j]) + t.nodes[j+nSequences].index = -1 + } break } } diff --git a/pkg/util/loser/tree_test.go b/pkg/util/loser/tree_test.go index d02dcb7671..cdf6767be6 100644 --- a/pkg/util/loser/tree_test.go +++ b/pkg/util/loser/tree_test.go @@ -5,6 +5,9 @@ import ( "math" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/grafana/pyroscope/pkg/util/loser" ) @@ -13,6 +16,8 @@ type List struct { cur uint64 err error + + closed int } func NewList(list ...uint64) *List { @@ -26,6 +31,9 @@ func (it *List) At() uint64 { func (it *List) Err() error { return it.err } func (it *List) Next() bool { + if it.err != nil { + return false + } if len(it.list) > 0 { it.cur = it.list[0] it.list = it.list[1:] @@ -35,7 +43,12 @@ func (it *List) Next() bool { return false } +func (it *List) Close() { it.closed += 1 } + func (it *List) Seek(val uint64) bool { + if it.err != nil { + return false + } for it.cur < val && len(it.list) > 0 { it.cur = it.list[0] it.list = it.list[1:] @@ -150,34 +163,74 @@ func TestPush(t *testing.T) { } func TestInitWithErr(t *testing.T) { - l := NewList() - l.err = errors.New("test") - l2 := NewList(5, 6, 7, 8) - tree := loser.New([]*List{l, l2}, math.MaxUint64, func(s *List) uint64 { return s.At() }, func(a, b uint64) bool { return a < b }, func(s *List) {}) - + lists := []*List{ + NewList(), + NewList(5, 6, 7, 8), + } + lists[0].err = testErr + tree := loser.New(lists, math.MaxUint64, func(s *List) uint64 { return s.At() }, func(a, b uint64) bool { return a < b }, func(s *List) { s.Close() }) if tree.Next() { t.Errorf("Next() should have returned false") } - if tree.Err() != l.err { - t.Errorf("Err() should have returned %v, got %v", l.err, tree.Err()) + if tree.Err() != testErr { + t.Errorf("Err() should have returned %v, got %v", testErr, tree.Err()) + } + + tree.Close() + for _, l := range lists { + assert.Equal(t, l.closed, 1, "list %+#v not closed exactly once", l) } + } +var testErr = errors.New("test") + func TestErrDuringNext(t *testing.T) { - l := NewList(5) - l.err = errors.New("test") - tree := loser.New([]*List{l}, math.MaxUint64, func(s *List) uint64 { return s.At() }, func(a, b uint64) bool { return a < b }, func(s *List) {}) + lists := []*List{ + NewList(5, 6), + NewList(11, 12), + } + tree := loser.New(lists, math.MaxUint64, func(s *List) uint64 { return s.At() }, func(a, b uint64) bool { return a < b }, func(s *List) { s.Close() }) + // no error for first element if !tree.Next() { t.Errorf("Next() should have returned true") } + // now error for second + lists[0].err = testErr if tree.Next() { t.Errorf("Next() should have returned false") } - if tree.Err() != l.err { - t.Errorf("Err() should have returned %v, got %v", l.err, tree.Err()) + if tree.Err() != testErr { + t.Errorf("Err() should have returned %v, got %v", testErr, tree.Err()) } if tree.Next() { t.Errorf("Next() should have returned false") } + + tree.Close() + for _, l := range lists { + assert.Equal(t, l.closed, 1, "list %+#v not closed exactly once", l) + } +} + +func TestErrInOneIterator(t *testing.T) { + l := NewList() + l.err = errors.New("test") + + lists := []*List{ + NewList(5, 1), + l, + NewList(2, 4), + } + tree := loser.New(lists, math.MaxUint64, func(s *List) uint64 { return s.At() }, func(a, b uint64) bool { return a < b }, func(s *List) { s.Close() }) + + // error for first element + require.False(t, tree.Next()) + assert.Equal(t, l.err, tree.Err()) + + tree.Close() + for _, l := range lists { + assert.Equal(t, l.closed, 1, "list %+#v not closed exactly once", l) + } }