diff --git a/reader.go b/reader.go index dd1c936d..a10ab527 100644 --- a/reader.go +++ b/reader.go @@ -36,11 +36,7 @@ func (c *Conn) readControl() error { if err := internal.ReadN(c.rbuf, payload, int(n)); err != nil { return err } - maskEnabled := c.fh.GetMask() - if err := c.checkMask(maskEnabled); err != nil { - return err - } - if maskEnabled { + if maskEnabled := c.fh.GetMask(); maskEnabled { internal.MaskXOR(payload, c.fh.GetMaskKey()) } } diff --git a/reader_test.go b/reader_test.go index 34dc02fb..beb63032 100644 --- a/reader_test.go +++ b/reader_test.go @@ -2,6 +2,7 @@ package gws import ( "bytes" + "compress/flate" _ "embed" "encoding/hex" "encoding/json" @@ -229,6 +230,60 @@ func TestSegments(t *testing.T) { }() wg.Wait() }) + + t.Run("illegal compression", func(t *testing.T) { + var wg = &sync.WaitGroup{} + wg.Add(1) + + var serverHandler = new(webSocketMocker) + var clientHandler = new(webSocketMocker) + var serverOption = &ServerOption{} + var clientOption = &ClientOption{} + + var s1 = internal.AlphabetNumeric.Generate(1024) + serverHandler.onClose = func(socket *Conn, err error) { + as.Error(err) + wg.Done() + } + + server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption) + go server.ReadLoop() + go client.ReadLoop() + + go func() { + client.compressEnabled = true + client.config.compressors = new(compressors).initialize(16, flate.BestSpeed) + testWrite(client, true, OpcodeText, testCloneBytes(s1)) + }() + wg.Wait() + }) + + t.Run("decompress error", func(t *testing.T) { + var wg = &sync.WaitGroup{} + wg.Add(1) + + var serverHandler = new(webSocketMocker) + var clientHandler = new(webSocketMocker) + var serverOption = &ServerOption{CompressEnabled: true} + var clientOption = &ClientOption{CompressEnabled: true} + + serverHandler.onClose = func(socket *Conn, err error) { + as.Error(err) + wg.Done() + } + + server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption) + go server.ReadLoop() + go client.ReadLoop() + + go func() { + frame, _, _ := client.genFrame(OpcodeText, testdata) + data := frame.Bytes() + data[20] = 'x' + client.conn.Write(data) + }() + wg.Wait() + }) } func TestMessage(t *testing.T) {