diff --git a/arena.go b/arena.go
index 4278f9d3..708c581f 100644
--- a/arena.go
+++ b/arena.go
@@ -14,27 +14,25 @@ type Arena interface {
 	// This must not be larger than 1<<32.
 	NumSegments() int64
 
-	// Data loads the data for the segment with the given ID. IDs are in
-	// the range [0, NumSegments()).
-	// must be tightly packed in the range [0, NumSegments()).
-	Data(id SegmentID) ([]byte, error)
+	// Segment returns the segment identified by the specified ID. This
+	// may return nil if the segment with the specified ID does not exist.
+	Segment(id SegmentID) *Segment
 
 	// Allocate selects a segment to place a new object in, creating a
 	// segment or growing the capacity of a previously loaded segment if
-	// necessary. If Allocate does not return an error, then the
-	// difference of the capacity and the length of the returned slice
-	// must be at least minsz. segs is a map of segments keyed by ID
-	// using arrays returned by the Data method (although the length of
-	// these slices may have changed by previous allocations). Allocate
-	// must not modify segs.
+	// necessary. If Allocate does not return an error, then the returned
+	// segment can store at least minsz bytes starting at the returned
+	// address offset.
+	//
+	// Some allocators may specifically choose to grow the passed seg (if
+	// non-nil), but that is not a requirement.
 	//
 	// If Allocate creates a new segment, the ID must be one larger than
 	// the last segment's ID or zero if it is the first segment.
 	//
-	// If Allocate returns an previously loaded segment's ID, then the
-	// arena is responsible for preserving the existing data in the
-	// returned byte slice.
-	Allocate(minsz Size, segs map[SegmentID]*Segment) (SegmentID, []byte, error)
+	// If Allocate returns a previously loaded segment, then the arena is
+	// responsible for preserving the existing data.
+	Allocate(minsz Size, msg *Message, seg *Segment) (*Segment, address, error)
 
 	// Release all resources associated with the Arena. Callers MUST NOT
 	// use the Arena after it has been released.
@@ -47,55 +45,90 @@ type Arena interface {
 	Release()
 }
 
+// singleSegmentPool is a pool of *SingleSegmentArena.
+var singleSegmentPool = sync.Pool{
+	New: func() any {
+		return &SingleSegmentArena{}
+	},
+}
+
 // SingleSegmentArena is an Arena implementation that stores message data
 // in a continguous slice. Allocation is performed by first allocating a
 // new slice and copying existing data. SingleSegment arena does not fail
 // unless the caller attempts to access another segment.
-type SingleSegmentArena []byte
+type SingleSegmentArena struct {
+	seg Segment
+
+	// bp is the bufferpool associated with this arena if it was initialized
+	// for writing.
+	bp *bufferpool.Pool
+
+	// fromPool determines if this should return to the pool when released.
+	fromPool bool
+}
+
+func zeroSlice(b []byte) {
+	for i := range b {
+		b[i] = 0
+	}
+}
 
 // SingleSegment constructs a SingleSegmentArena from b. b MAY be nil.
 // Callers MAY use b to populate the segment for reading, or to reserve
 // memory of a specific size.
-func SingleSegment(b []byte) *SingleSegmentArena {
-	return (*SingleSegmentArena)(&b)
+func SingleSegment(b []byte) Arena {
+	if b == nil {
+		ssa := singleSegmentPool.Get().(*SingleSegmentArena)
+		ssa.fromPool = true
+		ssa.bp = &bufferpool.Default
+		return ssa
+	}
+
+	return &SingleSegmentArena{seg: Segment{data: b}}
 }
 
-func (ssa SingleSegmentArena) NumSegments() int64 {
+func (ssa *SingleSegmentArena) NumSegments() int64 {
 	return 1
 }
 
-func (ssa SingleSegmentArena) Data(id SegmentID) ([]byte, error) {
+func (ssa *SingleSegmentArena) Segment(id SegmentID) *Segment {
 	if id != 0 {
-		return nil, errors.New("segment " + str.Utod(id) + " requested in single segment arena")
+		return nil
 	}
-	return ssa, nil
+	return &ssa.seg
 }
 
-func (ssa *SingleSegmentArena) Allocate(sz Size, segs map[SegmentID]*Segment) (SegmentID, []byte, error) {
-	data := []byte(*ssa)
-	if segs[0] != nil {
-		data = segs[0].data
+func (ssa *SingleSegmentArena) Allocate(sz Size, msg *Message, seg *Segment) (*Segment, address, error) {
+	if seg != nil && seg != &ssa.seg {
+		return nil, 0, errors.New("segment is not associated with arena")
 	}
+	data := ssa.seg.data
 	if len(data)%int(wordSize) != 0 {
-		return 0, nil, errors.New("segment size is not a multiple of word size")
+		return nil, 0, errors.New("segment size is not a multiple of word size")
 	}
+	ssa.seg.BindTo(msg)
 	if hasCapacity(data, sz) {
-		return 0, data, nil
+		addr := address(len(ssa.seg.data))
+		ssa.seg.data = ssa.seg.data[:len(ssa.seg.data)+int(sz)]
+		return &ssa.seg, addr, nil
 	}
 	inc, err := nextAlloc(int64(len(data)), int64(maxAllocSize()), sz)
 	if err != nil {
-		return 0, nil, err
+		return nil, 0, err
+	}
+	if ssa.bp == nil {
+		return nil, 0, errors.New("cannot allocate on read-only SingleSegmentArena")
 	}
-	buf := bufferpool.Default.Get(cap(data) + inc)
-	copied := copy(buf, data)
-	buf = buf[:copied]
-	bufferpool.Default.Put(data)
-	*ssa = buf
-	return 0, *ssa, nil
+	addr := address(len(ssa.seg.data))
+	ssa.seg.data = ssa.bp.Get(cap(data) + inc)[:len(data)+int(sz)]
+	copy(ssa.seg.data, data)
+	zeroSlice(data)
+	ssa.bp.Put(data)
+	return &ssa.seg, addr, nil
 }
 
-func (ssa SingleSegmentArena) String() string {
-	return "single-segment arena [len=" + str.Itod(len(ssa)) + " cap=" + str.Itod(cap(ssa)) + "]"
+func (ssa *SingleSegmentArena) String() string {
+	return "single-segment arena [len=" + str.Itod(len(ssa.seg.data)) + " cap=" + str.Itod(cap(ssa.seg.data)) + "]"
 }
 
 // Return this arena to an internal sync.Pool of arenas that can be
@@ -108,17 +141,31 @@ func (ssa SingleSegmentArena) String() string {
 // Calling Release is optional; if not done the garbage collector
 // will release the memory per usual.
 func (ssa *SingleSegmentArena) Release() {
-	bufferpool.Default.Put(*ssa)
-	*ssa = nil
+	if ssa.bp != nil {
+		zeroSlice(ssa.seg.data)
+		ssa.bp.Put(ssa.seg.data)
+	}
+	ssa.seg.BindTo(nil)
+	ssa.seg.data = nil
+	if ssa.fromPool {
+		ssa.fromPool = false // Prevent double return
+		singleSegmentPool.Put(ssa)
+	}
 }
 
 // MultiSegment is an arena that stores object data across multiple []byte
 // buffers, allocating new buffers of exponentially-increasing size when
 // full. This avoids the potentially-expensive slice copying of SingleSegment.
 type MultiSegmentArena struct {
-	ss    [][]byte
-	delim int    // index of first segment in ss that is NOT in buf
-	buf   []byte // full-sized buffer that was demuxed into ss.
+	segs []Segment
+
+	// bp is the bufferpool associated with this arena's segments if it was
+	// initialized for writing.
+	bp *bufferpool.Pool
+
+	// fromPool is true if this msa instance was obtained from the
+	// multiSegmentPool and should be returned there upon release.
+	fromPool bool
 }
 
 // MultiSegment returns a new arena that allocates new segments when
@@ -126,7 +173,9 @@ type MultiSegmentArena struct {
 // buffer for reading or to reserve memory of a specific size.
 func MultiSegment(b [][]byte) *MultiSegmentArena {
 	if b == nil {
-		return multiSegmentPool.Get().(*MultiSegmentArena)
+		msa := multiSegmentPool.Get().(*MultiSegmentArena)
+		msa.fromPool = true
+		return msa
 	}
 	return multiSegment(b)
 }
@@ -141,28 +190,47 @@ func MultiSegment(b [][]byte) *MultiSegmentArena {
 // Calling Release is optional; if not done the garbage collector
 // will release the memory per usual.
 func (msa *MultiSegmentArena) Release() {
-	for i, v := range msa.ss {
-		msa.ss[i] = nil
-
-		// segment not in buf?
-		if i >= msa.delim {
-			bufferpool.Default.Put(v)
+	for i := range msa.segs {
+		if msa.bp != nil {
+			zeroSlice(msa.segs[i].data)
+			msa.bp.Put(msa.segs[i].data)
 		}
+		msa.segs[i].data = nil
+		msa.segs[i].BindTo(nil)
+	}
+
+	if msa.segs != nil {
+		msa.segs = msa.segs[:0]
 	}
 
-	bufferpool.Default.Put(msa.buf) // nil is ok
-	*msa = MultiSegmentArena{ss: msa.ss[:0]}
-	multiSegmentPool.Put(msa)
+	if msa.fromPool {
+		// Prevent a double return to the pool if the arena is used
+		// after release.
+		msa.fromPool = false
+
+		multiSegmentPool.Put(msa)
+	}
 }
 
 // Like MultiSegment, but doesn't use the pool
 func multiSegment(b [][]byte) *MultiSegmentArena {
-	return &MultiSegmentArena{ss: b}
+	var bp *bufferpool.Pool
+	var segs []Segment
+	if b == nil {
+		bp = &bufferpool.Default
+		segs = make([]Segment, 0, 5) // Typical size.
+	} else {
+		segs = make([]Segment, len(b))
+		for i := range b {
+			segs[i].data = b[i]
+			segs[i].id = SegmentID(i)
+		}
+	}
+	return &MultiSegmentArena{segs: segs, bp: bp}
 }
 
 var multiSegmentPool = sync.Pool{
 	New: func() any {
-		return multiSegment(make([][]byte, 0, 16))
+		return multiSegment(nil)
 	},
 }
 
@@ -174,71 +242,126 @@ func (msa *MultiSegmentArena) demux(hdr streamHeader, data []byte) error {
 		return errors.New("number of segments overflows int")
 	}
 
-	msa.buf = data
-	msa.delim = int(maxSeg + 1)
+	// Grow list of existing segments as needed.
+	numSegs := int(maxSeg + 1)
+	if cap(msa.segs) >= numSegs {
+		msa.segs = msa.segs[:numSegs]
+	} else {
+		inc := numSegs - len(msa.segs)
+		msa.segs = append(msa.segs, make([]Segment, inc)...)
+	}
 
-	// We might be forced to allocate here, but hopefully it won't
-	// happen to often. We assume msa was freshly obtained from a
-	// pool, and that no segments have been allocated yet.
-	var segment []byte
-	for i := 0; i < msa.delim; i++ {
+	for i := SegmentID(0); i <= maxSeg; i++ {
 		sz, err := hdr.segmentSize(SegmentID(i))
 		if err != nil {
 			return err
 		}
-		segment, data = data[:sz:sz], data[sz:]
-		msa.ss = append(msa.ss, segment)
+		msa.segs[i].data, data = data[:sz:sz], data[sz:]
+		msa.segs[i].id = i
 	}
 	return nil
 }
 
 func (msa *MultiSegmentArena) NumSegments() int64 {
-	return int64(len(msa.ss))
+	return int64(len(msa.segs))
 }
 
-func (msa *MultiSegmentArena) Data(id SegmentID) ([]byte, error) {
-	if int64(id) >= int64(len(msa.ss)) {
-		return nil, errors.New("segment " + str.Utod(id) + " requested (arena only has " +
-			str.Itod(len(msa.ss)) + " segments)")
+func (msa *MultiSegmentArena) Segment(id SegmentID) *Segment {
+	if int(id) >= len(msa.segs) {
+		return nil
 	}
-	return msa.ss[id], nil
+	return &msa.segs[id]
 }
 
-func (msa *MultiSegmentArena) Allocate(sz Size, segs map[SegmentID]*Segment) (SegmentID, []byte, error) {
-	var total int64
-	for i, data := range msa.ss {
-		id := SegmentID(i)
-		if s := segs[id]; s != nil {
-			data = s.data
+func (msa *MultiSegmentArena) Allocate(sz Size, msg *Message, seg *Segment) (*Segment, address, error) {
+	// Prefer allocating in seg if it has capacity.
+	if seg != nil && hasCapacity(seg.data, sz) {
+		// Double-check this segment is part of this arena.
+		contains := false
+		for i := range msa.segs {
+			if &msa.segs[i] == seg {
+				contains = true
+				break
+			}
+		}
+
+		if !contains {
+			// This is a usage error.
+			return nil, 0, errors.New("preferred segment is not part of the arena")
+		}
+
+		// Double-check this segment is for this message.
+		if seg.Message() != nil && seg.Message() != msg {
+			return nil, 0, errors.New("attempt to allocate in segment for different message")
 		}
+		addr := address(len(seg.data))
+		newLen := int(addr) + int(sz)
+		seg.data = seg.data[:newLen]
+		seg.BindTo(msg)
+		return seg, addr, nil
+	}
+
+	var total int64
+	for i := range msa.segs {
+		data := msa.segs[i].data
 		if hasCapacity(data, sz) {
-			return id, data, nil
+			// Found a segment with spare capacity.
+			addr := address(len(msa.segs[i].data))
+			newLen := int(addr) + int(sz)
+			msa.segs[i].data = msa.segs[i].data[:newLen]
+			msa.segs[i].BindTo(msg)
+			return &msa.segs[i], addr, nil
 		}
 		if total += int64(cap(data)); total < 0 {
 			// Overflow.
-			return 0, nil, errors.New("alloc " + str.Utod(sz) + " bytes: message too large")
+			return nil, 0, errors.New("alloc " + str.Utod(sz) + " bytes: message too large")
 		}
 	}
 
-	n, err := nextAlloc(total, 1<<63-1, sz)
-	if err != nil {
-		return 0, nil, err
+	// Check for read-only arena.
+	if msa.bp == nil {
+		return nil, 0, errors.New("cannot allocate segment in read-only multi-segment arena")
 	}
 
-	buf := bufferpool.Default.Get(n)
-	buf = buf[:0]
+	// If this is the very first segment and the requested allocation
+	// size is zero, modify the requested size to at least one word.
+	//
+	// FIXME: this is to maintain compatibility with existing behavior and
+	// tests in NewMessage(), which assume this. Remove once arenas
+	// enforce the contract of always having at least one segment.
+	compatFirstSegLenZeroAddSize := Size(0)
+	if len(msa.segs) == 0 && sz == 0 {
+		compatFirstSegLenZeroAddSize = wordSize
+	}
+
+	// Determine the actual allocation size (may be greater than sz).
+	n, err := nextAlloc(total, 1<<63-1, sz+compatFirstSegLenZeroAddSize)
+	if err != nil {
+		return nil, 0, err
+	}
 
-	id := SegmentID(len(msa.ss))
-	msa.ss = append(msa.ss, buf)
-	return id, buf, nil
+	// We have determined this will be a new segment. Get the backing
+	// buffer for it.
+	buf := msa.bp.Get(n)
+	buf = buf[:sz]
+
+	// Set up the segment.
+	id := SegmentID(len(msa.segs))
+	msa.segs = append(msa.segs, Segment{
+		data: buf,
+		id:   id,
+	})
+	res := &msa.segs[int(id)]
+	res.BindTo(msg)
+	return res, 0, nil
 }
 
 func (msa *MultiSegmentArena) String() string {
-	return "multi-segment arena [" + str.Itod(len(msa.ss)) + " segments]"
+	return "multi-segment arena [" + str.Itod(len(msa.segs)) + " segments]"
 }
 
 // nextAlloc computes how much more space to allocate given the number
@@ -287,3 +410,66 @@ func nextAlloc(curr, max int64, req Size) (int, error) {
 func hasCapacity(b []byte, sz Size) bool {
 	return sz <= Size(cap(b)-len(b))
 }
+
+type ReadOnlySingleSegment struct {
+	seg Segment
+}
+
+// NumSegments returns the number of segments in the arena.
+// This must not be larger than 1<<32.
+func (r *ReadOnlySingleSegment) NumSegments() int64 {
+	return 1
+}
+
+// Segment returns the segment identified by the specified ID. This
+// may return nil if the segment with the specified ID does not exist.
+func (r *ReadOnlySingleSegment) Segment(id SegmentID) *Segment {
+	if id == 0 {
+		return &r.seg
+	}
+
+	return nil
+}
+
+// Allocate selects a segment to place a new object in, creating a
+// segment or growing the capacity of a previously loaded segment if
+// necessary. If Allocate does not return an error, then the returned
+// segment can store at least minsz bytes starting at the returned
+// address offset. Some allocators may specifically choose to
+// grow the passed seg (if non-nil).
+//
+// If Allocate creates a new segment, the ID must be one larger than
+// the last segment's ID or zero if it is the first segment.
+//
+// If Allocate returns a previously loaded segment, then the
+// arena is responsible for preserving the existing data.
+func (r *ReadOnlySingleSegment) Allocate(minsz Size, msg *Message, seg *Segment) (*Segment, address, error) {
+	return nil, 0, errors.New("read-only segment cannot allocate data")
+}
+
+// Release all resources associated with the Arena. Callers MUST NOT
+// use the Arena after it has been released.
+//
+// Calling Release() is OPTIONAL, but may reduce allocations.
+//
+// Implementations MAY use Release() as a signal to return resources
+// to free lists, or otherwise reuse the Arena. However, they MUST
+// NOT assume Release() will be called.
+func (r *ReadOnlySingleSegment) Release() {
+	r.seg.data = nil
+}
+
+// ReplaceData replaces the current data of the arena. This should ONLY be
+// called on an empty or released arena, or else it panics.
+func (r *ReadOnlySingleSegment) ReplaceData(b []byte) {
+	if r.seg.data != nil {
+		panic("replacing data on unreleased ReadOnlySingleSegment")
+	}
+
+	r.seg.data = b
+}
+
+// NewReadOnlySingleSegment creates a new read-only arena with the given data.
+func NewReadOnlySingleSegment(b []byte) *ReadOnlySingleSegment {
+	return &ReadOnlySingleSegment{seg: Segment{data: b}}
+}
diff --git a/arena_test.go b/arena_test.go
index 51267251..42bc00b7 100644
--- a/arena_test.go
+++ b/arena_test.go
@@ -1,10 +1,59 @@
 package capnp
 
 import (
-	"bytes"
 	"testing"
+
+	"capnproto.org/go/capnp/v3/exp/bufferpool"
+	"github.com/stretchr/testify/require"
 )
 
+type arenaAllocTest struct {
+	name string
+
+	// Arrange
+	init func() (Arena, map[SegmentID]*Segment)
+	size Size
+
+	// Assert
+	id   SegmentID
+	data []byte
+}
+
+func (test *arenaAllocTest) run(t *testing.T) {
+	arena, _ := test.init()
+	seg, addr, err := arena.Allocate(test.size, nil, nil)
+
+	require.NoError(t, err, "Allocate error")
+	require.Equal(t, test.id, seg.id)
+
+	// Allocate() contract is that segment data starting at addr should
+	// have enough room for test.size bytes.
+	require.Less(t, int(addr), len(seg.data))
+
+	data := seg.data[addr:]
+	require.LessOrEqual(t, test.size, Size(cap(seg.data)))
+
+	data = data[:test.size]
+	require.Equal(t, test.data, data)
+}
+
+func incrementingData(n int) []byte {
+	b := make([]byte, n)
+	for i := range b {
+		b[i] = byte(i % 256)
+	}
+	return b
+}
+
+func segmentData(a Arena, id SegmentID) []byte {
+	seg := a.Segment(id)
+	if seg == nil {
+		return nil
+	}
+
+	return seg.Data()
+}
+
 func TestSingleSegment(t *testing.T) {
 	t.Parallel()
 	t.Helper()
@@ -13,40 +62,22 @@ func TestSingleSegment(t *testing.T) {
 		t.Parallel()
 
 		arena := SingleSegment(nil)
-		if n := arena.NumSegments(); n != 1 {
-			t.Errorf("SingleSegment(nil).NumSegments() = %d; want 1", n)
-		}
-		data, err := arena.Data(0)
-		if len(data) != 0 {
-			t.Errorf("SingleSegment(nil).Data(0) = %#v; want nil", data)
-		}
-		if err != nil {
-			t.Errorf("SingleSegment(nil).Data(0) error: %v", err)
-		}
-		_, err = arena.Data(1)
-		if err == nil {
-			t.Error("SingleSegment(nil).Data(1) succeeded; want error")
-		}
+		require.Equal(t, int64(1), arena.NumSegments())
+		data0 := segmentData(arena, 0)
+		require.Empty(t, data0)
+		data1 := segmentData(arena, 1)
+		require.Empty(t, data1)
 	})
 
 	t.Run("ExistingData", func(t *testing.T) {
 		t.Parallel()
 
 		arena := SingleSegment(incrementingData(8))
-		if n := arena.NumSegments(); n != 1 {
-			t.Errorf("SingleSegment(incrementingData(8)).NumSegments() = %d; want 1", n)
-		}
-		data, err := arena.Data(0)
-		if want := incrementingData(8); !bytes.Equal(data, want) {
-			t.Errorf("SingleSegment(incrementingData(8)).Data(0) = %#v; want %#v", data, want)
-		}
-		if err != nil {
-			t.Errorf("SingleSegment(incrementingData(8)).Data(0) error: %v", err)
-		}
-		_, err = arena.Data(1)
-		if err == nil {
-			t.Error("SingleSegment(incrementingData(8)).Data(1) succeeded; want error")
-		}
+		require.Equal(t, int64(1), arena.NumSegments())
+		data0 := segmentData(arena, 0)
+		require.Equal(t, incrementingData(8), data0)
+		data1 := segmentData(arena, 1)
+		require.Empty(t, data1)
 	})
 }
 
@@ -61,7 +92,7 @@ func TestSingleSegmentAllocate(t *testing.T) {
 			},
 			size: 8,
 			id:   0,
-			data: []byte{},
+			data: []byte{7: 0},
 		},
 		{
 			name: "unloaded",
@@ -71,7 +102,7 @@
 			},
 			size: 8,
 			id:   0,
-			data: incrementingData(16),
+			data: incrementingData(24)[16 : 16+8],
 		},
 		{
 			name: "loaded",
@@ -85,7 +116,7 @@
 			},
 			size: 8,
 			id:   0,
-			data: incrementingData(16),
+			data: incrementingData(24)[16 : 16+8],
 		},
 		{
 			name: "loaded changes length",
@@ -98,7 +129,7 @@
 			},
 			size: 8,
 			id:   0,
-			data: incrementingData(24),
+			data:
incrementingData(32)[16 : 16+8],
 		},
 		{
 			name: "message-filled segment",
@@ -111,11 +142,12 @@
 			},
 			size: 8,
 			id:   0,
-			data: incrementingData(24),
+			data: incrementingData(24)[16 : 16+8],
 		},
 	}
 	for i := range tests {
-		tests[i].run(t, i)
+		tc := tests[i]
+		t.Run(tc.name, tc.run)
 	}
 }
 
@@ -127,40 +159,22 @@ func TestMultiSegment(t *testing.T) {
 		t.Parallel()
 
 		arena := MultiSegment(nil)
-		if n := arena.NumSegments(); n != 0 {
-			t.Errorf("MultiSegment(nil).NumSegments() = %d; want 1", n)
-		}
-		_, err := arena.Data(0)
-		if err == nil {
-			t.Error("MultiSegment(nil).Data(0) succeeded; want error")
-		}
+		require.Equal(t, int64(0), arena.NumSegments())
+		data0 := segmentData(arena, 0)
+		require.Empty(t, data0)
 	})
 
 	t.Run("ExistingData", func(t *testing.T) {
 		t.Parallel()
 
 		arena := MultiSegment([][]byte{incrementingData(8), incrementingData(24)})
-		if n := arena.NumSegments(); n != 2 {
-			t.Errorf("MultiSegment(...).NumSegments() = %d; want 2", n)
-		}
-		data, err := arena.Data(0)
-		if want := incrementingData(8); !bytes.Equal(data, want) {
-			t.Errorf("MultiSegment(...).Data(0) = %#v; want %#v", data, want)
-		}
-		if err != nil {
-			t.Errorf("MultiSegment(...).Data(0) error: %v", err)
-		}
-		data, err = arena.Data(1)
-		if want := incrementingData(24); !bytes.Equal(data, want) {
-			t.Errorf("MultiSegment(...).Data(1) = %#v; want %#v", data, want)
-		}
-		if err != nil {
-			t.Errorf("MultiSegment(...).Data(1) error: %v", err)
-		}
-		_, err = arena.Data(2)
-		if err == nil {
-			t.Error("MultiSegment(...).Data(2) succeeded; want error")
-		}
+		require.Equal(t, int64(2), arena.NumSegments())
+		data0 := segmentData(arena, 0)
+		require.Equal(t, incrementingData(8), data0)
+		data1 := segmentData(arena, 1)
+		require.Equal(t, incrementingData(24), data1)
+		data2 := segmentData(arena, 2)
+		require.Empty(t, data2)
 	})
 }
 
@@ -175,7 +189,7 @@ func TestMultiSegmentAllocate(t *testing.T) {
 			},
 			size: 8,
 			id:   0,
-			data: []byte{},
+			data: []byte{7: 0},
 		},
 		{
 			name: "space in unloaded segment",
@@ -185,7 +199,7 @@
 			},
 			size: 8,
 			id:   0,
-			data: incrementingData(16),
+			data: incrementingData(24)[16 : 16+8],
 		},
 		{
 			name: "space in loaded segment",
@@ -199,7 +213,7 @@
 			},
 			size: 8,
 			id:   0,
-			data: incrementingData(16),
+			data: incrementingData(24)[16 : 16+8],
 		},
 		{
 			name: "space in loaded segment changes length",
@@ -212,24 +226,27 @@
 			},
 			size: 8,
 			id:   0,
-			data: incrementingData(24),
+			data: incrementingData(24)[16 : 16+8],
 		},
 		{
-			name: "message-filled segment",
+			name: "first segment is filled",
 			init: func() (Arena, map[SegmentID]*Segment) {
 				buf := incrementingData(24)
 				segs := map[SegmentID]*Segment{
 					0: {id: 0, data: buf},
 				}
-				return MultiSegment([][]byte{buf[:16]}), segs
+				msa := MultiSegment([][]byte{buf[:16:16]})
+				msa.bp = &bufferpool.Default
+				return msa, segs
 			},
 			size: 8,
 			id:   1,
-			data: []byte{},
+			data: []byte{7: 0},
 		},
 	}
 	for i := range tests {
-		tests[i].run(t, i)
+		tc := tests[i]
+		t.Run(tc.name, tc.run)
 	}
 }
diff --git a/capability.go b/capability.go
index 7eee16e2..00fef962 100644
--- a/capability.go
+++ b/capability.go
@@ -59,7 +59,7 @@ func (i Interface) Message() *Message {
 	if i.seg == nil {
 		return nil
 	}
-	return i.seg.msg
+	return i.seg.Message()
 }
 
 // IsValid returns whether the interface is valid.
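[Editorial note, not part of the patch: a minimal sketch of the reworked Allocate contract, written as in-package code the way arenaAllocTest.run above drives it — the address result type is unexported, so Arena can only be implemented and exercised inside package capnp. The function name exampleAllocate is hypothetical.]

	// exampleAllocate sketches the new Arena.Allocate call shape.
	func exampleAllocate() error {
		arena := MultiSegment(nil) // writable: backed by bufferpool.Default
		defer arena.Release()

		msg := &Message{Arena: arena}
		seg, addr, err := arena.Allocate(8, msg, nil)
		if err != nil {
			return err
		}
		// seg now has room for the requested 8 bytes starting at addr.
		// Allocate itself does not zero that span (the "loaded" test
		// cases above expect stale bytes); alloc() in message.go zeroes
		// it after calling Allocate.
		_ = seg.Data()[addr:]
		return nil
	}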
diff --git a/codec_test.go b/codec_test.go
index 9a66c9d8..60ced825 100644
--- a/codec_test.go
+++ b/codec_test.go
@@ -81,12 +81,12 @@ type tooManySegsArena struct {
 func (t *tooManySegsArena) NumSegments() int64 {
 	return 1<<32 + 1
 }
 
-func (t *tooManySegsArena) Data(id SegmentID) ([]byte, error) {
-	return nil, errors.New("no data")
+func (t *tooManySegsArena) Segment(id SegmentID) *Segment {
+	return nil
 }
 
-func (t *tooManySegsArena) Allocate(minsz Size, segs map[SegmentID]*Segment) (SegmentID, []byte, error) {
-	return 0, nil, errors.New("cannot allocate")
+func (t *tooManySegsArena) Allocate(minsz Size, msg *Message, seg *Segment) (*Segment, address, error) {
+	return nil, 0, errors.New("cannot allocate")
 }
 
 func (t *tooManySegsArena) Release() {}
diff --git a/integration_test.go b/integration_test.go
index 38f375c6..5353930a 100644
--- a/integration_test.go
+++ b/integration_test.go
@@ -1851,13 +1851,13 @@ func BenchmarkUnmarshal_Reuse(b *testing.B) {
 		data[i].data, data[i].a = buf, *a
 	}
 	msg := new(capnp.Message)
-	ta := new(testArena)
+	ta := capnp.NewReadOnlySingleSegment(nil)
 	b.ReportAllocs()
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
 		testIdx := r.Intn(len(data))
-		*ta = testArena(data[testIdx].data[8:]) // [8:] to skip header
 		msg.Release()
+		ta.ReplaceData(data[testIdx].data[8:]) // [8:] to skip header
 		msg.Arena = ta
 		a, err := air.ReadRootBenchmarkA(msg)
 		if err != nil {
@@ -1913,25 +1913,6 @@ func BenchmarkDecode(b *testing.B) {
 	}
 }
 
-type testArena []byte
-
-func (ta testArena) NumSegments() int64 {
-	return 1
-}
-
-func (ta testArena) Data(id capnp.SegmentID) ([]byte, error) {
-	if id != 0 {
-		return nil, errors.New("test arena: requested non-zero segment")
-	}
-	return []byte(ta), nil
-}
-
-func (ta testArena) Allocate(capnp.Size, map[capnp.SegmentID]*capnp.Segment) (capnp.SegmentID, []byte, error) {
-	return 0, nil, errors.New("test arena: can't allocate")
-}
-
-func (ta testArena) Release() {}
-
 func TestPointerTraverseDefense(t *testing.T) {
 	t.Parallel()
 	const limit = 128
diff --git a/list.go b/list.go
index 3099f5f8..211c9c09 100644
--- a/list.go
+++ b/list.go
@@ -100,7 +100,7 @@ func (p List) Message() *Message {
 	if p.seg == nil {
 		return nil
 	}
-	return p.seg.msg
+	return p.seg.Message()
 }
 
 // IsValid returns whether the list is valid.
diff --git a/message.go b/message.go
index db1acc1a..41be6458 100644
--- a/message.go
+++ b/message.go
@@ -54,11 +54,6 @@ type Message struct {
 	// DepthLimit limits how deeply-nested a message structure can be.
 	// If not set, this defaults to 64.
 	DepthLimit uint
-
-	// mu protects the following fields:
-	mu       sync.Mutex
-	segs     map[SegmentID]*Segment
-	firstSeg Segment // Preallocated first segment. msg is non-nil once initialized.
 }
 
 // NewMessage creates a message with a new root and returns the first
@@ -109,14 +104,6 @@ func (m *Message) Release() {
 // not read.
 func (m *Message) Reset(arena Arena) (first *Segment, err error) {
 	m.capTable.Reset()
-	for k := range m.segs {
-		// Optimization: keep the first segment so that the re-used
-		// Message does not have to allocate a new one.
-		if k == 0 && m.segs[k] == &m.firstSeg {
-			continue
-		}
-		delete(m.segs, k)
-	}
 
 	if m.Arena != nil {
 		m.Arena.Release()
@@ -127,20 +114,18 @@ func (m *Message) Reset(arena Arena) (first *Segment, err error) {
 		TraverseLimit: m.TraverseLimit,
 		DepthLimit:    m.DepthLimit,
 		capTable:      m.capTable,
-		segs:          m.segs,
-		firstSeg:      Segment{msg: m},
 	}
 
 	if arena != nil {
 		switch arena.NumSegments() {
 		case 0:
-			if first, err = m.allocSegment(wordSize); err != nil {
+			if first, _, err = arena.Allocate(0, m, nil); err != nil {
 				return nil, exc.WrapError("new message", err)
 			}
 
 		case 1:
 			if first, err = m.Segment(0); err != nil {
-				return nil, exc.WrapError("new message", err)
+				return nil, exc.WrapError("Reset.Segment(0)", err)
 			}
 			if len(first.data) > 0 {
 				return nil, errors.New("new message: arena not empty")
@@ -153,7 +138,6 @@ func (m *Message) Reset(arena Arena) (first *Segment, err error) {
 	if first.ID() != 0 {
 		return nil, errors.New("new message: arena allocated first segment with non-zero ID")
 	}
-
 	seg, _, err := alloc(first, wordSize) // allocate root
 	if err != nil {
 		return nil, exc.WrapError("new message", err)
@@ -268,91 +252,16 @@ func (m *Message) NumSegments() int64 {
 
 // Segment returns the segment with the given ID.
 func (m *Message) Segment(id SegmentID) (*Segment, error) {
-	if int64(id) >= m.Arena.NumSegments() {
-		return nil, errors.New("segment " + str.Utod(id) + ": out of bounds")
-	}
-	m.mu.Lock()
-	seg, err := m.segment(id)
-	m.mu.Unlock()
-	return seg, err
-}
-
-// segment returns the segment with the given ID, with no bounds
-// checking. The caller must be holding m.mu.
-func (m *Message) segment(id SegmentID) (*Segment, error) {
-	if m.segs == nil && id == 0 && m.firstSeg.msg != nil && m.firstSeg.data != nil {
-		return &m.firstSeg, nil
-	}
-	if s := m.segs[id]; s != nil && s.data != nil {
-		return s, nil
-	}
-	if len(m.segs) == maxInt {
-		return nil, errors.New("segment " + str.Utod(id) + ": number of loaded segments exceeds int")
-	}
-	data, err := m.Arena.Data(id)
-	if err != nil {
-		return nil, exc.WrapError("load segment "+str.Utod(id), err)
-	}
-	s := m.setSegment(id, data)
-	return s, nil
-}
-
-// setSegment creates or updates the Segment with the given ID.
-// The caller must be holding m.mu.
-func (m *Message) setSegment(id SegmentID, data []byte) *Segment {
-	if m.segs == nil {
-		if id == 0 {
-			m.firstSeg = Segment{
-				id:   id,
-				msg:  m,
-				data: data,
-			}
-			return &m.firstSeg
-		}
-		m.segs = make(map[SegmentID]*Segment)
-		if m.firstSeg.msg != nil {
-			m.segs[0] = &m.firstSeg
-		}
-	} else if seg := m.segs[id]; seg != nil {
-		seg.data = data
-		return seg
-	}
-	seg := &Segment{
-		id:   id,
-		msg:  m,
-		data: data,
-	}
-	m.segs[id] = seg
-	return seg
-}
-
-// allocSegment creates or resizes an existing segment such that
-// cap(seg.Data) - len(seg.Data) >= sz. The caller must not be holding
-// onto m.mu.
-func (m *Message) allocSegment(sz Size) (*Segment, error) {
-	if sz > maxAllocSize() {
-		return nil, errors.New("allocation: too large")
-	}
-
-	m.mu.Lock()
-	defer m.mu.Unlock()
-
-	if len(m.segs) == maxInt {
-		return nil, errors.New("allocation: number of loaded segments exceeds int")
-	}
-
-	// Transition from sole segment to segment map?
-	if m.segs == nil && m.firstSeg.msg != nil {
-		m.segs = make(map[SegmentID]*Segment)
-		m.segs[0] = &m.firstSeg
+	seg := m.Arena.Segment(id)
+	if seg == nil {
+		return nil, errors.New("segment " + str.Utod(id) + " out of bounds in arena")
 	}
-
-	id, data, err := m.Arena.Allocate(sz, m.segs)
-	if err != nil {
-		return nil, exc.WrapError("allocation", err)
+	segMsg := seg.Message()
+	if segMsg == nil {
+		seg.BindTo(m)
+	} else if segMsg != m {
+		return nil, errors.New("segment " + str.Utod(id) + ": bound to a different message")
 	}
-
-	seg := m.setSegment(id, data)
 	return seg, nil
 }
 
@@ -375,31 +284,25 @@ func (m *Message) Marshal() ([]byte, error) {
 		return nil, errors.New("marshal: header size overflows int")
 	}
 	var dataSize uint64
-	m.mu.Lock()
 	for i := int64(0); i < nsegs; i++ {
-		s, err := m.segment(SegmentID(i))
+		s, err := m.Segment(SegmentID(i))
 		if err != nil {
-			m.mu.Unlock()
 			return nil, exc.WrapError("marshal", err)
 		}
 		n := uint64(len(s.data))
 		if n%uint64(wordSize) != 0 {
-			m.mu.Unlock()
 			return nil, errors.New("marshal: segment " + str.Itod(i) + " not word-aligned")
 		}
 		if n > uint64(maxSegmentSize) {
-			m.mu.Unlock()
 			return nil, errors.New("marshal: segment " + str.Itod(i) + " too large")
 		}
 		dataSize += n
 		if dataSize > uint64(maxInt) {
-			m.mu.Unlock()
 			return nil, errors.New("marshal: message size overflows int")
 		}
 	}
 	total := hdrSize + dataSize
 	if total > uint64(maxInt) {
-		m.mu.Unlock()
 		return nil, errors.New("marshal: message size overflows int")
 	}
 
@@ -407,19 +310,16 @@ func (m *Message) Marshal() ([]byte, error) {
 	buf := make([]byte, int(hdrSize), int(total))
 	binary.LittleEndian.PutUint32(buf, uint32(nsegs-1))
 	for i := int64(0); i < nsegs; i++ {
-		s, err := m.segment(SegmentID(i))
+		s, err := m.Segment(SegmentID(i))
 		if err != nil {
-			m.mu.Unlock()
 			return nil, exc.WrapError("marshal", err)
 		}
 		if len(s.data)%int(wordSize) != 0 {
-			m.mu.Unlock()
 			return nil, errors.New("marshal: segment " + str.Itod(i) + " not word-aligned")
 		}
 		binary.LittleEndian.PutUint32(buf[int(i+1)*4:], uint32(len(s.data)/int(wordSize)))
 		buf = append(buf, s.data...)
 	}
-	m.mu.Unlock()
 	return buf, nil
 }
 
@@ -454,23 +354,27 @@ func alloc(s *Segment, sz Size) (*Segment, address, error) {
 	}
 	sz = sz.padToWord()
 
-	if !hasCapacity(s.data, sz) {
-		var err error
-		s, err = s.msg.allocSegment(sz)
-		if err != nil {
-			return nil, 0, errors.New("allocSegment failed: " + err.Error())
-		}
+	msg := s.Message()
+	if msg == nil {
+		return nil, 0, errors.New("segment does not have a message associated with it")
+	}
+	if msg.Arena == nil {
+		return nil, 0, errors.New("message does not have an arena")
+	}
+
+	// TODO: From this point on, this could be changed to be a requirement
+	// for Arena implementations instead of relying on alloc() to do it.
+
+	s, addr, err := msg.Arena.Allocate(sz, msg, s)
+	if err != nil {
+		return s, addr, err
 	}
-	addr := address(len(s.data))
+
 	end, ok := addr.addSize(sz)
 	if !ok {
 		return nil, 0, errors.New("allocation: address overflow")
 	}
-	space := s.data[len(s.data):end]
-	s.data = s.data[:end]
-	for i := range space {
-		space[i] = 0
-	}
+
+	zeroSlice(s.data[addr:end])
 	return s, addr, nil
 }
diff --git a/message_test.go b/message_test.go
index 4d3289b3..5c51972c 100644
--- a/message_test.go
+++ b/message_test.go
@@ -9,6 +9,7 @@ import (
 	"testing"
 	"testing/quick"
 
+	"capnproto.org/go/capnp/v3/exp/bufferpool"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -23,8 +24,7 @@ func TestNewMessage(t *testing.T) {
 		{arena: SingleSegment(nil)},
 		{arena: MultiSegment(nil)},
 		{arena: readOnlyArena{SingleSegment(make([]byte, 0, 7))}, fails: true},
-		{arena: readOnlyArena{SingleSegment(make([]byte, 0, 8))}},
-		{arena: MultiSegment(nil)},
+		{arena: readOnlyArena{SingleSegment(make([]byte, 0, 8))}, fails: true},
 		{arena: MultiSegment([][]byte{make([]byte, 8)}), fails: true},
 		{arena: MultiSegment([][]byte{incrementingData(8)}), fails: true},
 		// This is somewhat arbitrary, but more than one segment = data.
@@ -141,6 +141,9 @@ func TestAlloc(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
+
+	// Make the arena writable again.
+	msg.Arena.(*MultiSegmentArena).bp = &bufferpool.Default
 	tests = append(tests, allocTest{
 		name: "given segment full and no others available",
 		seg:  seg,
@@ -591,45 +594,6 @@ func TestTotalSize(t *testing.T) {
 	assert.Nil(t, err, "quick.Check returned an error")
 }
 
-type arenaAllocTest struct {
-	name string
-
-	// Arrange
-	init func() (Arena, map[SegmentID]*Segment)
-	size Size
-
-	// Assert
-	id   SegmentID
-	data []byte
-}
-
-func (test *arenaAllocTest) run(t *testing.T, i int) {
-	arena, segs := test.init()
-	id, data, err := arena.Allocate(test.size, segs)
-
-	if err != nil {
-		t.Errorf("tests[%d] - %s: Allocate error: %v", i, test.name, err)
-		return
-	}
-	if id != test.id {
-		t.Errorf("tests[%d] - %s: Allocate id = %d; want %d", i, test.name, id, test.id)
-	}
-	if !bytes.Equal(data, test.data) {
-		t.Errorf("tests[%d] - %s: Allocate data = % 02x; want % 02x", i, test.name, data, test.data)
-	}
-	if Size(cap(data)-len(data)) < test.size {
-		t.Errorf("tests[%d] - %s: Allocate len(data) = %d, cap(data) = %d; cap(data) should be at least %d", i, test.name, len(data), cap(data), Size(len(data))+test.size)
-	}
-}
-
-func incrementingData(n int) []byte {
-	b := make([]byte, n)
-	for i := range b {
-		b[i] = byte(i % 256)
-	}
-	return b
-}
-
 type readOnlyArena struct {
 	Arena
 }
@@ -638,8 +602,12 @@ func (ro readOnlyArena) String() string {
 	return fmt.Sprintf("readOnlyArena{%v}", ro.Arena)
 }
 
-func (readOnlyArena) Allocate(sz Size, segs map[SegmentID]*Segment) (SegmentID, []byte, error) {
-	return 0, nil, errReadOnlyArena
+func (readOnlyArena) Allocate(sz Size, msg *Message, seg *Segment) (*Segment, address, error) {
+	return nil, 0, errReadOnlyArena
+}
+
+func (ro readOnlyArena) Segment(id SegmentID) *Segment {
+	return ro.Arena.Segment(id)
 }
 
 var errReadOnlyArena = errors.New("Allocate called on read-only arena")
diff --git a/pointer.go b/pointer.go
index d5b50483..330acf99 100644
--- a/pointer.go
+++ b/pointer.go
@@ -187,7 +187,7 @@ func (p Ptr) Message() *Message {
 	if p.seg == nil {
 		return nil
 	}
-	return p.seg.msg
+	return p.seg.Message()
 }
 
 // Default returns p if it is valid, otherwise it unmarshals def.
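[Editorial note, not part of the patch: segment.go below introduces BindTo, and message.go's Segment above now enforces the segment-to-message binding rule. The sketch restates that rule as a standalone in-package helper; the name bindSegment is hypothetical.]

	// bindSegment mirrors Message.Segment's binding rule from this diff:
	// an unbound segment is lazily bound on first access, and a segment
	// already bound to another message is rejected.
	func bindSegment(m *Message, id SegmentID) (*Segment, error) {
		seg := m.Arena.Segment(id)
		if seg == nil {
			return nil, errors.New("segment out of bounds in arena")
		}
		if other := seg.Message(); other == nil {
			seg.BindTo(m) // first access binds the segment to m
		} else if other != m {
			return nil, errors.New("segment bound to a different message")
		}
		return seg, nil
	}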
diff --git a/segment.go b/segment.go
index c573f93a..8bca06a0 100644
--- a/segment.go
+++ b/segment.go
@@ -17,7 +17,8 @@ type SegmentID uint32
 type Segment struct {
 	// msg associated with this segment. A Message instance m maintains the
 	// invariant that that all m.segs[].msg == m.
-	msg  *Message
+	msg  *Message
+	id   SegmentID
 	data []byte
 }
 
@@ -27,6 +28,12 @@ func (s *Segment) Message() *Message {
 	return s.msg
 }
 
+// BindTo binds the segment to a given message. This is usually only called by
+// Arena implementations and does not perform any kind of safety check.
+func (s *Segment) BindTo(m *Message) {
+	s.msg = m
+}
+
 // ID returns the segment's ID.
 func (s *Segment) ID() SegmentID {
 	return s.id
 }
@@ -107,7 +114,7 @@ func (s *Segment) root() PointerList {
 		seg:        s,
 		length:     1,
 		size:       sz,
-		depthLimit: s.msg.depthLimit(),
+		depthLimit: s.Message().depthLimit(),
 	}
 }
 
@@ -115,7 +122,7 @@ func (s *Segment) lookupSegment(id SegmentID) (*Segment, error) {
 	if s.id == id {
 		return s, nil
 	}
-	return s.msg.Segment(id)
+	return s.Message().Segment(id)
 }
 
 func (s *Segment) readPtr(paddr address, depthLimit uint) (ptr Ptr, err error) {
@@ -135,7 +142,7 @@ func (s *Segment) readPtr(paddr address, depthLimit uint) (ptr Ptr, err error) {
 		if err != nil {
 			return Ptr{}, exc.WrapError("read pointer", err)
 		}
-		if !s.msg.canRead(sp.readSize()) {
+		if !s.Message().canRead(sp.readSize()) {
 			return Ptr{}, errors.New("read pointer: read traversal limit reached")
 		}
 		sp.depthLimit = depthLimit - 1
@@ -145,7 +152,7 @@ func (s *Segment) readPtr(paddr address, depthLimit uint) (ptr Ptr, err error) {
 		if err != nil {
 			return Ptr{}, exc.WrapError("read pointer", err)
 		}
-		if !s.msg.canRead(lp.readSize()) {
+		if !s.Message().canRead(lp.readSize()) {
 			return Ptr{}, errors.New("read pointer: read traversal limit reached")
 		}
 		lp.depthLimit = depthLimit - 1
@@ -377,8 +384,8 @@ func (s *Segment) writePtr(off address, src Ptr, forceCopy bool) error {
 		srcRaw = l.raw()
 	case interfacePtrType:
 		i := src.Interface()
-		if src.seg.msg != s.msg {
-			c := s.msg.CapTable().Add(i.Client().AddRef())
+		if src.seg.Message() != s.Message() {
+			c := s.Message().CapTable().Add(i.Client().AddRef())
 			i = NewInterface(s, c)
 		}
 		s.writeRawPointer(off, i.value(off))
diff --git a/segment_test.go b/segment_test.go
index 8753909f..95bf74d7 100644
--- a/segment_test.go
+++ b/segment_test.go
@@ -5,6 +5,8 @@ import (
 	"encoding/binary"
 	"fmt"
 	"testing"
+
+	"capnproto.org/go/capnp/v3/exp/bufferpool"
 )
 
 func TestSegmentInBounds(t *testing.T) {
@@ -666,6 +668,9 @@ func TestWriteDoubleFarPointer(t *testing.T) {
 			make([]byte, 0, 16),
 		}),
 	}
+
+	// Make arena writable again.
+	msg.Arena.(*MultiSegmentArena).bp = &bufferpool.Default
 	seg1, err := msg.Segment(1)
 	if err != nil {
 		t.Fatal("msg.Segment(1):", err)
diff --git a/struct.go b/struct.go
index 201c006a..fed3054b 100644
--- a/struct.go
+++ b/struct.go
@@ -46,7 +46,7 @@ func NewRootStruct(s *Segment, sz ObjectSize) (Struct, error) {
 	if err != nil {
 		return st, err
 	}
-	if err := s.msg.SetRoot(st.ToPtr()); err != nil {
+	if err := s.Message().SetRoot(st.ToPtr()); err != nil {
 		return st, err
 	}
 	return st, nil
@@ -75,7 +75,7 @@ func (p Struct) Message() *Message {
 	if p.seg == nil {
 		return nil
 	}
-	return p.seg.msg
+	return p.seg.Message()
 }
 
 // IsValid returns whether the struct is valid.
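[Editorial note, not part of the patch: the reuse pattern the new ReadOnlySingleSegment enables is the one BenchmarkUnmarshal_Reuse in integration_test.go exercises. A condensed sketch, assuming the usual capnproto.org/go/capnp/v3 import; payloads is an illustrative placeholder for pre-framed message bytes, and the [8:] slice skips the single-segment stream header exactly as the benchmark does.]

	// decodeAll re-reads many framed payloads through one arena and one
	// Message, avoiding per-message allocations.
	func decodeAll(payloads [][]byte) {
		arena := capnp.NewReadOnlySingleSegment(nil)
		msg := new(capnp.Message)
		for _, payload := range payloads {
			msg.Release()                  // releases the arena, dropping its segment data
			arena.ReplaceData(payload[8:]) // would panic if the arena still held data
			msg.Arena = arena
			// ... read the root struct from msg ...
		}
	}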