Skip to content

Commit

Permalink
Check client side maximum packet size eclipse#3
Browse files Browse the repository at this point in the history
Signed-off-by: Daichi Tomaru <[email protected]>
  • Loading branch information
tomatod committed Jan 17, 2023
1 parent d63b3b2 commit 6c2d598
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
19 changes: 18 additions & 1 deletion packets/packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"net"
"sync"
"errors"
)

// PacketType is a type alias to byte representing the different
Expand Down Expand Up @@ -33,6 +34,10 @@ const (
AUTH
)

var (
ErrPacketTooLarge = errors.New("Received packet whose size exceeds client's maximum packet size.")
)

type (
// Packet is the interface defining the unique parts of a controlpacket
Packet interface {
Expand Down Expand Up @@ -185,7 +190,8 @@ func NewControlPacket(t byte) *ControlPacket {

// ReadPacket reads a control packet from a io.Reader and returns a completed
// struct with the appropriate data
func ReadPacket(r io.Reader) (*ControlPacket, error) {
// If maximumPacketSize == 0, packet size will be not checked
func ReadPacketByMaximum(r io.Reader, maximumPacketSize uint32) (*ControlPacket, error) {
t := [1]byte{}
_, err := io.ReadFull(r, t[:])
if err != nil {
Expand Down Expand Up @@ -252,7 +258,11 @@ func ReadPacket(r io.Reader) (*ControlPacket, error) {
if err != nil {
return nil, err
}
vbiLen := vbi.Len()
cp.remainingLength, err = decodeVBI(vbi)
if maximumPacketSize != 0 && uint32(1 + vbiLen + cp.remainingLength) > maximumPacketSize {
return nil, ErrPacketTooLarge
}
if err != nil {
return nil, err
}
Expand All @@ -275,6 +285,13 @@ func ReadPacket(r io.Reader) (*ControlPacket, error) {
return cp, nil
}

// ReadPacketByMaximum reads a control packet from a io.Reader without considering client's maximum packet size limit
// and returns a completed struct with the appropriate data
// This implementation leaves for compatibility considerations.
func ReadPacket (r io.Reader) (*ControlPacket, error) {
return ReadPacketByMaximum(r, 0)
}

// WriteTo writes a packet to an io.Writer, handling packing all the parts of
// a control packet.
func (c *ControlPacket) WriteTo(w io.Writer) (int64, error) {
Expand Down
18 changes: 16 additions & 2 deletions paho/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,11 @@ func (c *Client) incoming() {
case <-c.stop:
return
default:
recv, err := packets.ReadPacket(c.Conn)
recv, err := packets.ReadPacketByMaximum(c.Conn, c.clientProps.MaximumPacketSize)
if err == packets.ErrPacketTooLarge {
go c.errorWithDisconnect(err, &Disconnect{ReasonCode: 0x95})
return
}
if err != nil {
go c.error(err)
return
Expand Down Expand Up @@ -558,6 +562,12 @@ func (c *Client) error(e error) {
go c.OnClientError(e)
}

func (c *Client) errorWithDisconnect(e error, d *Disconnect) {
c.debug.Println("error called:", e)
c.Disconnect(d)
go c.OnClientError(e)
}

func (c *Client) serverDisconnect(d *Disconnect) {
c.close()
c.workers.Wait()
Expand Down Expand Up @@ -863,7 +873,11 @@ func (c *Client) publishQoS12(ctx context.Context, pb *packets.Publish) (*Publis
}

func (c *Client) expectConnack(packet chan<- *packets.Connack, errs chan<- error) {
recv, err := packets.ReadPacket(c.Conn)
recv, err := packets.ReadPacketByMaximum(c.Conn, c.clientProps.MaximumPacketSize)
if err == packets.ErrPacketTooLarge {
go c.errorWithDisconnect(err, &Disconnect{ReasonCode: 0x95})
return
}
if err != nil {
errs <- err
return
Expand Down

0 comments on commit 6c2d598

Please sign in to comment.