diff --git a/codec.go b/codec.go index 83b42a0e..c68ec375 100644 --- a/codec.go +++ b/codec.go @@ -195,6 +195,10 @@ func MustUnmarshalRoot(data []byte) Ptr { return p } +var ( + errTooManySegments = errors.New("message has too many segments") +) + // An Encoder represents a framer for serializing a particular Cap'n // Proto stream. type Encoder struct { @@ -220,6 +224,9 @@ func (e *Encoder) Encode(m *Message) error { if nsegs == 0 { return errors.New("encode: message has no segments") } + if nsegs > 1<<32 { + return exc.WrapError("encode", errTooManySegments) + } e.bufs = append(e.bufs[:0], nil) // first element is placeholder for header maxSeg := SegmentID(nsegs - 1) hdrSize := streamHeaderSize(maxSeg) diff --git a/codec_test.go b/codec_test.go index 1e194123..9a66c9d8 100644 --- a/codec_test.go +++ b/codec_test.go @@ -2,8 +2,11 @@ package capnp import ( "bytes" + "errors" "io" "testing" + + "github.com/stretchr/testify/require" ) func TestEncoder(t *testing.T) { @@ -72,6 +75,40 @@ func TestDecoder(t *testing.T) { } } +type tooManySegsArena struct { + data []byte +} + +func (t *tooManySegsArena) NumSegments() int64 { return 1<<32 + 1 } + +func (t *tooManySegsArena) Data(id SegmentID) ([]byte, error) { + return nil, errors.New("no data") +} + +func (t *tooManySegsArena) Allocate(minsz Size, segs map[SegmentID]*Segment) (SegmentID, []byte, error) { + return 0, nil, errors.New("cannot allocate") +} + +func (t *tooManySegsArena) Release() {} + +// TestEncoderTooManySegments verifies attempting to encode an arena that has +// more segments than possible. +func TestEncoderTooManySegments(t *testing.T) { + t.Parallel() + zeroWord := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + arena := &tooManySegsArena{data: zeroWord} + + // Setup via field because NewMessage checks arena has > 1 segments. + var msg Message + msg.Arena = arena + var buf bytes.Buffer + enc := NewEncoder(&buf) + err := enc.Encode(&msg) + + // Encoding should error with a specific error. + require.ErrorIs(t, err, errTooManySegments) +} + func TestDecoder_MaxMessageSize(t *testing.T) { t.Parallel()