From 75608e1c34941c98e6af5a3c218cd4298a329f8b Mon Sep 17 00:00:00 2001 From: Daichi Tomaru Date: Sun, 15 Jan 2023 10:04:52 +0900 Subject: [PATCH] Check client side maximum packet size #3 Signed-off-by: Daichi Tomaru --- packets/packets.go | 4 ++++ packets/packets_test.go | 8 ++++++++ paho/client.go | 13 +++++++++++++ 3 files changed, 25 insertions(+) diff --git a/packets/packets.go b/packets/packets.go index 4965940..c0b0696 100644 --- a/packets/packets.go +++ b/packets/packets.go @@ -109,6 +109,10 @@ func (c *ControlPacket) PacketID() uint16 { } } +func (c *ControlPacket) GetPacketSize() int { + return 1 + len(encodeVBI(c.FixedHeader.remainingLength)) + c.remainingLength +} + func (c *ControlPacket) PacketType() string { return [...]string{ "", diff --git a/packets/packets_test.go b/packets/packets_test.go index b0f1bd4..2c3d542 100644 --- a/packets/packets_test.go +++ b/packets/packets_test.go @@ -138,6 +138,14 @@ func TestReadPacketConnect(t *testing.T) { assert.Equal(t, uint32(30), *c.Content.(*Connect).Properties.SessionExpiryInterval) } +func TestGetPacketSize(t *testing.T) { + // PUBLISH packet (topic: test, message: test) + p := []byte{0x30, 0x0b, 0x00, 0x04, 0x74, 0x65, 0x73, 0x74, 0x00, 0x74, 0x65, 0x73, 0x74} + c, err := ReadPacket(bufio.NewReader(bytes.NewReader(p))) + require.Nil(t, err) + assert.Equal(t, len(p), c.GetPacketSize()) +} + func TestReadStringWriteString(t *testing.T) { var b bytes.Buffer writeString("Test string", &b) diff --git a/paho/client.go b/paho/client.go index f41e3d0..52f80e8 100644 --- a/paho/client.go +++ b/paho/client.go @@ -414,6 +414,13 @@ func (c *Client) incoming() { go c.error(err) return } + if c.clientProps.MaximumPacketSize != 0 && recv.GetPacketSize() > int(c.clientProps.MaximumPacketSize) { + go c.errorWithDisconnect( + errors.New("received a packet whose size exceeds the client's maximum packet size limit."), + &Disconnect{ReasonCode: 0x95}, + ) + return + } switch recv.Type { case packets.CONNACK: c.debug.Println("received CONNACK") @@ -558,6 +565,12 @@ func (c *Client) error(e error) { go c.OnClientError(e) } +func (c *Client) errorWithDisconnect(e error, d *Disconnect) { + c.debug.Println("error called:", e) + c.Disconnect(d) + go c.OnClientError(e) +} + func (c *Client) serverDisconnect(d *Disconnect) { c.close() c.workers.Wait()