Skip to content

Commit

Permalink
separate method to keep backwards compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
FZambia committed Oct 30, 2024
1 parent f2e17f3 commit ab8a125
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 15 deletions.
27 changes: 16 additions & 11 deletions decode_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ var (
streamProtobufCommandDecoderPool sync.Pool
)

func GetStreamCommandDecoder(protoType Type, reader io.Reader, messageSizeLimit int) StreamCommandDecoder {
func GetStreamCommandDecoder(protoType Type, reader io.Reader) StreamCommandDecoder {
return GetStreamCommandDecoderLimited(protoType, reader, 0)
}

func GetStreamCommandDecoderLimited(protoType Type, reader io.Reader, messageSizeLimit int64) StreamCommandDecoder {
if protoType == TypeJSON {
e := streamJsonCommandDecoderPool.Get()
if e == nil {
Expand Down Expand Up @@ -45,7 +49,7 @@ func PutStreamCommandDecoder(protoType Type, e StreamCommandDecoder) {

type StreamCommandDecoder interface {
Decode() (*Command, int, error)
Reset(reader io.Reader, messageSizeLimit int)
Reset(reader io.Reader, messageSizeLimit int64)
}

// ErrMessageTooLarge for when the message exceeds the limit.
Expand All @@ -54,14 +58,14 @@ var ErrMessageTooLarge = errors.New("message size exceeds the limit")
type JSONStreamCommandDecoder struct {
reader *bufio.Reader
limitedReader *io.LimitedReader
messageSizeLimit int
messageSizeLimit int64
}

func NewJSONStreamCommandDecoder(reader io.Reader, messageSizeLimit int) *JSONStreamCommandDecoder {
func NewJSONStreamCommandDecoder(reader io.Reader, messageSizeLimit int64) *JSONStreamCommandDecoder {
var limitedReader *io.LimitedReader
var bufioReader *bufio.Reader
if messageSizeLimit > 0 {
limitedReader = &io.LimitedReader{R: reader, N: int64(messageSizeLimit) + 1}
limitedReader = &io.LimitedReader{R: reader, N: messageSizeLimit + 1}
bufioReader = bufio.NewReader(limitedReader)
} else {
bufioReader = bufio.NewReader(reader)
Expand All @@ -79,7 +83,7 @@ func (d *JSONStreamCommandDecoder) Decode() (*Command, int, error) {
}
cmdBytes, err := d.reader.ReadBytes('\n')
if err != nil {
if d.messageSizeLimit > 0 && len(cmdBytes) > d.messageSizeLimit {
if d.messageSizeLimit > 0 && int64(len(cmdBytes)) > d.messageSizeLimit {
return nil, 0, ErrMessageTooLarge
}
if err == io.EOF && len(cmdBytes) > 0 {
Expand All @@ -101,24 +105,25 @@ func (d *JSONStreamCommandDecoder) Decode() (*Command, int, error) {
return &c, len(cmdBytes), nil
}

func (d *JSONStreamCommandDecoder) Reset(reader io.Reader, messageSizeLimit int) {
func (d *JSONStreamCommandDecoder) Reset(reader io.Reader, messageSizeLimit int64) {
d.messageSizeLimit = messageSizeLimit
if messageSizeLimit > 0 {
limitedReader := &io.LimitedReader{R: reader, N: int64(messageSizeLimit) + 1}
limitedReader := &io.LimitedReader{R: reader, N: messageSizeLimit + 1}
bufioReader := bufio.NewReader(limitedReader)
d.limitedReader = limitedReader
d.reader.Reset(bufioReader)
} else {
d.limitedReader = nil
d.reader.Reset(reader)
}
}

type ProtobufStreamCommandDecoder struct {
reader *bufio.Reader
messageSizeLimit int
messageSizeLimit int64
}

func NewProtobufStreamCommandDecoder(reader io.Reader, messageSizeLimit int) *ProtobufStreamCommandDecoder {
func NewProtobufStreamCommandDecoder(reader io.Reader, messageSizeLimit int64) *ProtobufStreamCommandDecoder {
return &ProtobufStreamCommandDecoder{reader: bufio.NewReader(reader), messageSizeLimit: messageSizeLimit}
}

Expand Down Expand Up @@ -150,7 +155,7 @@ func (d *ProtobufStreamCommandDecoder) Decode() (*Command, int, error) {
return &c, int(msgLength) + 8, nil
}

func (d *ProtobufStreamCommandDecoder) Reset(reader io.Reader, messageSizeLimit int) {
func (d *ProtobufStreamCommandDecoder) Reset(reader io.Reader, messageSizeLimit int64) {
d.messageSizeLimit = messageSizeLimit
d.reader.Reset(reader)
}
27 changes: 23 additions & 4 deletions decode_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ func TestStreamingDecode_JSON(t *testing.T) {

func TestStreamingDecode_JSON_MessageLimit(t *testing.T) {
frame := getTestFrame(t, TypeJSON, 10000)
dec := GetStreamCommandDecoder(TypeJSON, bytes.NewReader(frame), 100)
dec := GetStreamCommandDecoderLimited(TypeJSON, bytes.NewReader(frame), 100)
_, _, err := dec.Decode()
require.ErrorIs(t, err, ErrMessageTooLarge)
}

func TestStreamingDecode_Protobuf_MessageLimit(t *testing.T) {
frame := getTestFrame(t, TypeProtobuf, 10000)
dec := GetStreamCommandDecoder(TypeProtobuf, bytes.NewReader(frame), 100)
dec := GetStreamCommandDecoderLimited(TypeProtobuf, bytes.NewReader(frame), 100)
_, _, err := dec.Decode()
require.ErrorIs(t, err, ErrMessageTooLarge)
}
Expand Down Expand Up @@ -95,7 +95,7 @@ func BenchmarkStreamingDecode_JSON(b *testing.B) {
}

func testDecodingFrame(tb testing.TB, frame []byte, protoType Type) {
dec := GetStreamCommandDecoder(protoType, bytes.NewReader(frame), 200000)
dec := GetStreamCommandDecoder(protoType, bytes.NewReader(frame))
_, size, err := dec.Decode()
require.NoError(tb, err)
if protoType == TypeProtobuf {
Expand Down Expand Up @@ -129,7 +129,7 @@ func TestJSONStreamCommandDecoder(t *testing.T) {

testCases := []struct {
name string
messageSizeLimit int
messageSizeLimit int64
}{
{
name: "no limit",
Expand Down Expand Up @@ -169,3 +169,22 @@ func TestJSONStreamCommandDecoder(t *testing.T) {
})
}
}

func TestJSONStreamCommandDecoder_ReuseDifferentLimit(t *testing.T) {
// Sample data emulating a network stream of JSON commands with newlines
data := `{"publish":{"channel":"1","data":{}}}
{"publish":{"channel":"1","data":{}}}`
decoder := GetStreamCommandDecoderLimited(TypeJSON, bytes.NewBufferString(data), 10)
_, _, err := decoder.Decode()
require.ErrorIs(t, err, ErrMessageTooLarge)
PutStreamCommandDecoder(TypeJSON, decoder)
decoder = GetStreamCommandDecoderLimited(TypeJSON, bytes.NewBufferString(data), 0)
cmd, _, err := decoder.Decode()
require.NoError(t, err)
require.NotNil(t, cmd)
require.NotNil(t, cmd.Publish)
cmd, _, err = decoder.Decode()
require.ErrorIs(t, err, io.EOF)
require.NotNil(t, cmd)
require.NotNil(t, cmd.Publish)
}

0 comments on commit ab8a125

Please sign in to comment.