From 6c2d5980007c9c1a850bec441a68385371654bcc Mon Sep 17 00:00:00 2001 From: Daichi Tomaru Date: Sun, 15 Jan 2023 11:21:51 +0900 Subject: [PATCH] Check client side maximum packet size eclipse#3 Signed-off-by: Daichi Tomaru --- packets/packets.go | 19 ++++++++++++++++++- paho/client.go | 18 ++++++++++++++++-- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/packets/packets.go b/packets/packets.go index 4965940..d6b783a 100644 --- a/packets/packets.go +++ b/packets/packets.go @@ -6,6 +6,7 @@ import ( "io" "net" "sync" + "errors" ) // PacketType is a type alias to byte representing the different @@ -33,6 +34,10 @@ const ( AUTH ) +var ( + ErrPacketTooLarge = errors.New("Received packet whose size exceeds client's maximum packet size.") +) + type ( // Packet is the interface defining the unique parts of a controlpacket Packet interface { @@ -185,7 +190,8 @@ func NewControlPacket(t byte) *ControlPacket { // ReadPacket reads a control packet from a io.Reader and returns a completed // struct with the appropriate data -func ReadPacket(r io.Reader) (*ControlPacket, error) { +// If maximumPacketSize == 0, packet size will be not checked +func ReadPacketByMaximum(r io.Reader, maximumPacketSize uint32) (*ControlPacket, error) { t := [1]byte{} _, err := io.ReadFull(r, t[:]) if err != nil { @@ -252,7 +258,11 @@ func ReadPacket(r io.Reader) (*ControlPacket, error) { if err != nil { return nil, err } + vbiLen := vbi.Len() cp.remainingLength, err = decodeVBI(vbi) + if maximumPacketSize != 0 && uint32(1 + vbiLen + cp.remainingLength) > maximumPacketSize { + return nil, ErrPacketTooLarge + } if err != nil { return nil, err } @@ -275,6 +285,13 @@ func ReadPacket(r io.Reader) (*ControlPacket, error) { return cp, nil } +// ReadPacketByMaximum reads a control packet from a io.Reader without considering client's maximum packet size limit +// and returns a completed struct with the appropriate data +// This implementation leaves for compatibility considerations. +func ReadPacket (r io.Reader) (*ControlPacket, error) { + return ReadPacketByMaximum(r, 0) +} + // WriteTo writes a packet to an io.Writer, handling packing all the parts of // a control packet. func (c *ControlPacket) WriteTo(w io.Writer) (int64, error) { diff --git a/paho/client.go b/paho/client.go index f41e3d0..484605b 100644 --- a/paho/client.go +++ b/paho/client.go @@ -409,7 +409,11 @@ func (c *Client) incoming() { case <-c.stop: return default: - recv, err := packets.ReadPacket(c.Conn) + recv, err := packets.ReadPacketByMaximum(c.Conn, c.clientProps.MaximumPacketSize) + if err == packets.ErrPacketTooLarge { + go c.errorWithDisconnect(err, &Disconnect{ReasonCode: 0x95}) + return + } if err != nil { go c.error(err) return @@ -558,6 +562,12 @@ func (c *Client) error(e error) { go c.OnClientError(e) } +func (c *Client) errorWithDisconnect(e error, d *Disconnect) { + c.debug.Println("error called:", e) + c.Disconnect(d) + go c.OnClientError(e) +} + func (c *Client) serverDisconnect(d *Disconnect) { c.close() c.workers.Wait() @@ -863,7 +873,11 @@ func (c *Client) publishQoS12(ctx context.Context, pb *packets.Publish) (*Publis } func (c *Client) expectConnack(packet chan<- *packets.Connack, errs chan<- error) { - recv, err := packets.ReadPacket(c.Conn) + recv, err := packets.ReadPacketByMaximum(c.Conn, c.clientProps.MaximumPacketSize) + if err == packets.ErrPacketTooLarge { + go c.errorWithDisconnect(err, &Disconnect{ReasonCode: 0x95}) + return + } if err != nil { errs <- err return