From 89b49ecb84a21855245490278f2c1f85c895e2e2 Mon Sep 17 00:00:00 2001 From: alsm Date: Wed, 30 Jan 2019 11:09:25 +0000 Subject: [PATCH] Start to add tests and fix issues they're throwing up Adding tests are overdue and it's showing up the areas that I haven't got right, especially in the packets library eclipse/paho.golang#5 --- packets/connack.go | 14 ++- packets/packets.go | 1 + packets/packets_test.go | 252 ++++++++++++++++++++++++++++++++++++++++ packets/subscribe.go | 26 +++++ packets/unsubscribe.go | 18 ++- paho/client.go | 20 ++-- paho/client_test.go | 137 ++++++++++++++++++++++ paho/cp_unsubscribe.go | 11 +- paho/server_test.go | 89 ++++++++++++++ 9 files changed, 551 insertions(+), 17 deletions(-) create mode 100644 paho/client_test.go create mode 100644 paho/server_test.go diff --git a/packets/connack.go b/packets/connack.go index 4eb30c5..d115bb5 100644 --- a/packets/connack.go +++ b/packets/connack.go @@ -36,7 +36,19 @@ func (c *Connack) Unpack(r *bytes.Buffer) error { // Buffers is the implementation of the interface required function for a packet func (c *Connack) Buffers() net.Buffers { - return nil + var header bytes.Buffer + + if c.SessionPresent { + header.WriteByte(1) + } else { + header.WriteByte(0) + } + header.WriteByte(c.ReasonCode) + + idvp := c.Properties.Pack(CONNACK) + propLen := encodeVBI(len(idvp)) + + return net.Buffers{header.Bytes(), propLen, idvp} } // WriteTo is the implementation of the interface required function for a packet diff --git a/packets/packets.go b/packets/packets.go index bb4fda6..6f48dd2 100644 --- a/packets/packets.go +++ b/packets/packets.go @@ -176,6 +176,7 @@ func ReadPacket(r io.Reader) (*ControlPacket, error) { if err != nil { return nil, err } + if n != int64(cp.remainingLength) { return nil, fmt.Errorf("failed to read packet, expected %d bytes, read %d", cp.remainingLength, n) } diff --git a/packets/packets_test.go b/packets/packets_test.go index 710318f..d523916 100644 --- a/packets/packets_test.go +++ b/packets/packets_test.go @@ -3,6 +3,7 @@ package packets import ( "bufio" "bytes" + "reflect" "testing" "github.com/stretchr/testify/assert" @@ -142,3 +143,254 @@ func TestReadStringWriteString(t *testing.T) { require.Nil(t, err) assert.Equal(t, "Test string", s) } + +func TestNewControlPacket(t *testing.T) { + tests := []struct { + name string + args PacketType + want *ControlPacket + }{ + { + name: "connect", + args: CONNECT, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: CONNECT}, + Content: &Connect{ + ProtocolName: "MQTT", + ProtocolVersion: 5, + Properties: &Properties{User: make(map[string]string)}, + }, + }, + }, + { + name: "connack", + args: CONNACK, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: CONNACK}, + Content: &Connack{Properties: &Properties{User: make(map[string]string)}}, + }, + }, + { + name: "publish", + args: PUBLISH, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: PUBLISH}, + Content: &Publish{Properties: &Properties{User: make(map[string]string)}}, + }, + }, + { + name: "puback", + args: PUBACK, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: PUBACK}, + Content: &Puback{Properties: &Properties{User: make(map[string]string)}}, + }, + }, + { + name: "pubrec", + args: PUBREC, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: PUBREC}, + Content: &Pubrec{Properties: &Properties{User: make(map[string]string)}}, + }, + }, + { + name: "pubrel", + args: PUBREL, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: PUBREL, Flags: 2}, + Content: &Pubrel{Properties: &Properties{User: make(map[string]string)}}, + }, + }, + { + name: "pubcomp", + args: PUBCOMP, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: PUBCOMP}, + Content: &Pubcomp{Properties: &Properties{User: make(map[string]string)}}, + }, + }, + { + name: "subscribe", + args: SUBSCRIBE, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: SUBSCRIBE, Flags: 2}, + Content: &Subscribe{ + Properties: &Properties{User: make(map[string]string)}, + Subscriptions: make(map[string]SubOptions), + }, + }, + }, + { + name: "suback", + args: SUBACK, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: SUBACK}, + Content: &Suback{Properties: &Properties{User: make(map[string]string)}}, + }, + }, + { + name: "unsubscribe", + args: UNSUBSCRIBE, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: UNSUBSCRIBE, Flags: 2}, + Content: &Unsubscribe{Properties: &Properties{User: make(map[string]string)}}, + }, + }, + { + name: "unsuback", + args: UNSUBACK, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: UNSUBACK}, + Content: &Unsuback{Properties: &Properties{User: make(map[string]string)}}, + }, + }, + { + name: "pingreq", + args: PINGREQ, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: PINGREQ}, + Content: &Pingreq{}, + }, + }, + { + name: "pingresp", + args: PINGRESP, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: PINGRESP}, + Content: &Pingresp{}, + }, + }, + { + name: "disconnect", + args: DISCONNECT, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: DISCONNECT}, + Content: &Disconnect{Properties: &Properties{User: make(map[string]string)}}, + }, + }, + { + name: "auth", + args: AUTH, + want: &ControlPacket{ + FixedHeader: FixedHeader{Type: AUTH, Flags: 1}, + Content: &Auth{Properties: &Properties{User: make(map[string]string)}}, + }, + }, + { + name: "dummy", + args: 20, + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewControlPacket(tt.args); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewControlPacket() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestControlPacket_PacketID(t *testing.T) { + type fields struct { + Content Packet + FixedHeader FixedHeader + } + tests := []struct { + name string + fields fields + want uint16 + }{ + { + name: "publish", + fields: fields{ + FixedHeader: FixedHeader{Type: PUBLISH}, + Content: &Publish{PacketID: 123}, + }, + want: 123, + }, + { + name: "puback", + fields: fields{ + FixedHeader: FixedHeader{Type: PUBACK}, + Content: &Puback{PacketID: 123}, + }, + want: 123, + }, + { + name: "pubrel", + fields: fields{ + FixedHeader: FixedHeader{Type: PUBREL}, + Content: &Pubrel{PacketID: 123}, + }, + want: 123, + }, + { + name: "pubrec", + fields: fields{ + FixedHeader: FixedHeader{Type: PUBREC}, + Content: &Pubrec{PacketID: 123}, + }, + want: 123, + }, + { + name: "pubcomp", + fields: fields{ + FixedHeader: FixedHeader{Type: PUBCOMP}, + Content: &Pubcomp{PacketID: 123}, + }, + want: 123, + }, + { + name: "subscribe", + fields: fields{ + FixedHeader: FixedHeader{Type: SUBSCRIBE}, + Content: &Subscribe{PacketID: 123}, + }, + want: 123, + }, + { + name: "suback", + fields: fields{ + FixedHeader: FixedHeader{Type: SUBACK}, + Content: &Suback{PacketID: 123}, + }, + want: 123, + }, + { + name: "unsubscribe", + fields: fields{ + FixedHeader: FixedHeader{Type: UNSUBSCRIBE}, + Content: &Unsubscribe{PacketID: 123}, + }, + want: 123, + }, + { + name: "unsuback", + fields: fields{ + FixedHeader: FixedHeader{Type: UNSUBACK}, + Content: &Unsuback{PacketID: 123}, + }, + want: 123, + }, { + name: "connect", + fields: fields{ + FixedHeader: FixedHeader{Type: CONNECT}, + Content: &Connect{}, + }, + want: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &ControlPacket{ + Content: tt.fields.Content, + FixedHeader: tt.fields.FixedHeader, + } + if got := c.PacketID(); got != tt.want { + t.Errorf("ControlPacket.PacketID() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/packets/subscribe.go b/packets/subscribe.go index 992b85e..6e611f4 100644 --- a/packets/subscribe.go +++ b/packets/subscribe.go @@ -36,6 +36,20 @@ func (s *SubOptions) Pack() byte { return ret } +func (s *SubOptions) Unpack(r *bytes.Buffer) error { + b, err := r.ReadByte() + if err != nil { + return err + } + + s.QoS = b & 0x03 + s.NoLocal = (b & 1 << 2) == 1 + s.RetainAsPublished = (b & 1 << 3) == 1 + s.RetainHandling = b & 0x30 + + return nil +} + // Unpack is the implementation of the interface required function for a packet func (s *Subscribe) Unpack(r *bytes.Buffer) error { var err error @@ -49,6 +63,18 @@ func (s *Subscribe) Unpack(r *bytes.Buffer) error { return err } + for r.Len() > 0 { + var so SubOptions + t, err := readString(r) + if err != nil { + return err + } + if err = so.Unpack(r); err != nil { + return err + } + s.Subscriptions[t] = so + } + return nil } diff --git a/packets/unsubscribe.go b/packets/unsubscribe.go index 471203e..24f1fed 100644 --- a/packets/unsubscribe.go +++ b/packets/unsubscribe.go @@ -15,6 +15,17 @@ type Unsubscribe struct { // Unpack is the implementation of the interface required function for a packet func (u *Unsubscribe) Unpack(r *bytes.Buffer) error { + var err error + u.PacketID, err = readUint16(r) + if err != nil { + return err + } + + err = u.Properties.Unpack(r, UNSUBSCRIBE) + if err != nil { + return err + } + for { t, err := readString(r) if err != nil && err != io.EOF { @@ -33,10 +44,13 @@ func (u *Unsubscribe) Unpack(r *bytes.Buffer) error { func (u *Unsubscribe) Buffers() net.Buffers { var b bytes.Buffer writeUint16(u.PacketID, &b) + var topics bytes.Buffer for _, t := range u.Topics { - writeString(t, &b) + writeString(t, &topics) } - return net.Buffers{b.Bytes()} + idvp := u.Properties.Pack(UNSUBSCRIBE) + propLen := encodeVBI(len(idvp)) + return net.Buffers{b.Bytes(), propLen, idvp, topics.Bytes()} } // WriteTo is the implementation of the interface required function for a packet diff --git a/paho/client.go b/paho/client.go index ad1d747..9af0f3f 100644 --- a/paho/client.go +++ b/paho/client.go @@ -67,7 +67,7 @@ type ( // client.Conn *MUST* be set to an already connected net.Conn before // Connect() is called. func NewClient() *Client { - debug.Println("Creating new client") + debug.Println("creating new client") c := &Client{ stop: make(chan struct{}), serverProps: CommsProperties{ @@ -113,7 +113,7 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) { return nil, fmt.Errorf("client connection is nil") } - debug.Println("Connecting") + debug.Println("connecting") c.Lock() defer c.Unlock() @@ -134,7 +134,7 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) { } } - debug.Println("Starting Incoming") + debug.Println("starting Incoming") c.workers.Add(1) go func() { defer c.workers.Done() @@ -208,7 +208,7 @@ func (c *Client) Connect(ctx context.Context, cp *Connect) (*Connack, error) { c.serverInflight = semaphore.NewWeighted(int64(c.serverProps.ReceiveMaximum)) c.clientInflight = semaphore.NewWeighted(int64(c.clientProps.ReceiveMaximum)) - debug.Println("Received CONNACK, starting PingHandler") + debug.Println("received CONNACK, starting PingHandler") c.workers.Add(1) go func() { defer c.workers.Done() @@ -227,7 +227,7 @@ func (c *Client) Incoming() { for { select { case <-c.stop: - debug.Println("Client stopping, Incoming stopping") + debug.Println("client stopping, Incoming stopping") return default: recv, err := packets.ReadPacket(c.Conn) @@ -235,7 +235,6 @@ func (c *Client) Incoming() { c.Error(err) return } - debug.Println("Received a control packet:", recv.Type) switch recv.Type { case packets.CONNACK: cap := recv.Content.(*packets.Connack) @@ -278,14 +277,15 @@ func (c *Client) Incoming() { pr.WriteTo(c.Conn) } case packets.PUBACK, packets.PUBCOMP, packets.SUBACK, packets.UNSUBACK: + debug.Println("received packet with id", recv.PacketID()) if cpCtx := c.MIDs.Get(recv.PacketID()); cpCtx != nil { cpCtx.Return <- *recv } else { - debug.Println("Received a response for a message ID we don't know:", recv.PacketID()) + debug.Println("received a response for a message ID we don't know:", recv.PacketID()) } case packets.PUBREC: if cpCtx := c.MIDs.Get(recv.PacketID()); cpCtx == nil { - debug.Println("Received a PUBREC for a message ID we don't know:", recv.PacketID()) + debug.Println("received a PUBREC for a message ID we don't know:", recv.PacketID()) pl := packets.Pubrel{ PacketID: recv.Content.(*packets.Pubrec).PacketID, ReasonCode: 0x92, @@ -333,7 +333,7 @@ func (c *Client) Incoming() { // which results in the other client goroutines terminating. // It also closes the client network connection. func (c *Client) Error(e error) { - debug.Println("Error called:", e) + debug.Println("error called:", e) c.Lock() select { case <-c.stop: @@ -440,7 +440,7 @@ func (c *Client) Subscribe(ctx context.Context, s *Subscribe) (*Suback, error) { if sap.Type != packets.SUBACK { return nil, fmt.Errorf("received %d instead of Suback", sap.Type) } - debug.Println("Received SUBACK") + debug.Println("received SUBACK") sa := SubackFromPacketSuback(sap.Content.(*packets.Suback)) switch { diff --git a/paho/client_test.go b/paho/client_test.go new file mode 100644 index 0000000..0015273 --- /dev/null +++ b/paho/client_test.go @@ -0,0 +1,137 @@ +package paho + +import ( + "context" + "log" + "os" + "testing" + "time" + + "github.com/eclipse/paho.golang/packets" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewClient(t *testing.T) { + c := NewClient() + + require.NotNil(t, c) + require.NotNil(t, c.stop) + require.NotNil(t, c.Persistence) + require.NotNil(t, c.MIDs) + require.NotNil(t, c.Router) + require.NotNil(t, c.PingHandler) + + assert.Equal(t, uint16(65535), c.serverProps.ReceiveMaximum) + assert.Equal(t, uint8(2), c.serverProps.MaximumQoS) + assert.Equal(t, uint32(0), c.serverProps.MaximumPacketSize) + assert.Equal(t, uint16(0), c.serverProps.TopicAliasMaximum) + assert.True(t, c.serverProps.RetainAvailable) + assert.True(t, c.serverProps.WildcardSubAvailable) + assert.True(t, c.serverProps.SubIDAvailable) + assert.True(t, c.serverProps.SharedSubAvailable) + + assert.Equal(t, uint16(65535), c.clientProps.ReceiveMaximum) + assert.Equal(t, uint8(2), c.clientProps.MaximumQoS) + assert.Equal(t, uint32(0), c.clientProps.MaximumPacketSize) + assert.Equal(t, uint16(0), c.clientProps.TopicAliasMaximum) + + assert.Equal(t, 10*time.Second, c.PacketTimeout) +} + +func TestClientConnect(t *testing.T) { + SetDebugLogger(log.New(os.Stderr, "CONNECT: ", log.LstdFlags)) + ts := newTestServer() + ts.SetResponse(packets.CONNACK, &packets.Connack{ + ReasonCode: 0, + SessionPresent: false, + Properties: &packets.Properties{ + MaximumPacketSize: Uint32(12345), + MaximumQOS: Byte(1), + ReceiveMaximum: Uint16(12345), + TopicAliasMaximum: Uint16(200), + }, + }) + go ts.Run() + defer ts.Stop() + + c := NewClient() + require.NotNil(t, c) + + c.Conn = ts.ClientConn() + + cp := &Connect{ + KeepAlive: 30, + ClientID: "testClient", + CleanStart: true, + } + + ca, err := c.Connect(context.Background(), cp) + require.Nil(t, err) + assert.Equal(t, uint8(0), ca.ReasonCode) + + time.Sleep(10 * time.Millisecond) +} + +func TestClientSubscribe(t *testing.T) { + SetDebugLogger(log.New(os.Stderr, "SUBSCRIBE: ", log.LstdFlags)) + ts := newTestServer() + ts.SetResponse(packets.SUBACK, &packets.Suback{ + Reasons: []byte{1, 2, 0}, + Properties: &packets.Properties{}, + }) + go ts.Run() + defer ts.Stop() + + c := NewClient() + require.NotNil(t, c) + + c.Conn = ts.ClientConn() + go c.Incoming() + go c.PingHandler.Start(c.Conn, 30*time.Second) + + s := &Subscribe{ + Subscriptions: map[string]SubscribeOptions{ + "test/1": SubscribeOptions{QoS: 1}, + "test/2": SubscribeOptions{QoS: 2}, + "test/3": SubscribeOptions{QoS: 0}, + }, + } + + sa, err := c.Subscribe(context.Background(), s) + require.Nil(t, err) + assert.Equal(t, []byte{1, 2, 0}, sa.Reasons) + + time.Sleep(10 * time.Millisecond) +} + +func TestClientUnsubscribe(t *testing.T) { + SetDebugLogger(log.New(os.Stderr, "UNSUBSCRIBE: ", log.LstdFlags)) + ts := newTestServer() + ts.SetResponse(packets.UNSUBACK, &packets.Unsuback{ + Reasons: []byte{0, 17}, + Properties: &packets.Properties{}, + }) + go ts.Run() + defer ts.Stop() + + c := NewClient() + require.NotNil(t, c) + + c.Conn = ts.ClientConn() + go c.Incoming() + go c.PingHandler.Start(c.Conn, 30*time.Second) + + u := &Unsubscribe{ + Topics: []string{ + "test/1", + "test/2", + }, + } + + ua, err := c.Unsubscribe(context.Background(), u) + require.Nil(t, err) + assert.Equal(t, []byte{0, 17}, ua.Reasons) + + time.Sleep(10 * time.Millisecond) +} diff --git a/paho/cp_unsubscribe.go b/paho/cp_unsubscribe.go index df1fc33..79d799a 100644 --- a/paho/cp_unsubscribe.go +++ b/paho/cp_unsubscribe.go @@ -19,10 +19,13 @@ type ( // Packet returns a packets library Unsubscribe from the paho Unsubscribe // on which it is called func (u *Unsubscribe) Packet() *packets.Unsubscribe { - return &packets.Unsubscribe{ - Topics: u.Topics, - Properties: &packets.Properties{ + v := &packets.Unsubscribe{Topics: u.Topics} + + if u.Properties != nil { + v.Properties = &packets.Properties{ User: u.Properties.User, - }, + } } + + return v } diff --git a/paho/server_test.go b/paho/server_test.go new file mode 100644 index 0000000..2258bdb --- /dev/null +++ b/paho/server_test.go @@ -0,0 +1,89 @@ +package paho + +import ( + "log" + "net" + + "github.com/eclipse/paho.golang/packets" +) + +type testServer struct { + conn net.Conn + clientConn net.Conn + stop chan struct{} + responses map[packets.PacketType]packets.Packet +} + +func newTestServer() *testServer { + t := &testServer{ + stop: make(chan struct{}), + responses: make(map[packets.PacketType]packets.Packet), + } + t.conn, t.clientConn = net.Pipe() + + return t +} + +func (t *testServer) ClientConn() net.Conn { + return t.clientConn +} + +func (t *testServer) SetResponse(pt packets.PacketType, p packets.Packet) { + t.responses[pt] = p +} + +func (t *testServer) Stop() { + t.conn.Close() + close(t.stop) +} + +func (t *testServer) Run() { + for { + select { + case <-t.stop: + return + default: + recv, err := packets.ReadPacket(t.conn) + if err != nil { + log.Println("error in test server reading packet", err) + return + } + log.Println("test server received a control packet:", recv.Type) + switch recv.Type { + case packets.CONNECT: + log.Println("received connect", recv.Content.(*packets.Connect)) + if p, ok := t.responses[packets.CONNACK]; ok { + if _, err := p.WriteTo(t.conn); err != nil { + log.Println(err) + } + } + case packets.SUBSCRIBE: + log.Println("received subscribe", recv.Content.(*packets.Subscribe)) + if p, ok := t.responses[packets.SUBACK]; ok { + p.(*packets.Suback).PacketID = recv.PacketID() + if _, err := p.WriteTo(t.conn); err != nil { + log.Println(err) + } + } + case packets.UNSUBSCRIBE: + log.Println("received unsubscribe", recv.Content.(*packets.Unsubscribe)) + if p, ok := t.responses[packets.UNSUBACK]; ok { + p.(*packets.Unsuback).PacketID = recv.PacketID() + if _, err := p.WriteTo(t.conn); err != nil { + log.Println(err) + } + } + case packets.AUTH: + case packets.PUBLISH: + case packets.PUBACK, packets.PUBCOMP, packets.SUBACK, packets.UNSUBACK: + case packets.PUBREC: + case packets.PUBREL: + case packets.DISCONNECT: + case packets.PINGREQ: + log.Println("test server sending pingresp") + pr := packets.NewControlPacket(packets.PINGRESP) + pr.WriteTo(t.conn) + } + } + } +}