diff --git a/README.md b/README.md index 45f36267..f34e5d0c 100755 --- a/README.md +++ b/README.md @@ -77,8 +77,8 @@ ok github.com/lxzan/gws 17.231s - [Introduction](#introduction) - [Why GWS](#why-gws) - [Benchmark](#benchmark) - - [IOPS (Echo Server)](#iops-echo-server) - - [GoBench](#gobench) + - [IOPS (Echo Server)](#iops-echo-server) + - [GoBench](#gobench) - [Index](#index) - [Feature](#feature) - [Attention](#attention) @@ -87,13 +87,14 @@ ok github.com/lxzan/gws 17.231s - [Quick Start](#quick-start) - [Best Practice](#best-practice) - [More Examples](#more-examples) - - [KCP](#kcp) - - [Proxy](#proxy) - - [Broadcast](#broadcast) - - [WriteWithTimeout](#writewithtimeout) - - [Pub / Sub](#pub--sub) + - [KCP](#kcp) + - [Proxy](#proxy) + - [Broadcast](#broadcast) + - [WriteWithTimeout](#writewithtimeout) + - [Pub / Sub](#pub--sub) - [Autobahn Test](#autobahn-test) - [Communication](#communication) +- [Buy me a coffee](#buy-me-a-coffee) - [Acknowledgments](#acknowledgments) ### Feature @@ -388,6 +389,10 @@ docker run -it --rm \ QQ +### Buy me a coffee + +WeChat + ### Acknowledgments The following project had particular influence on gws's design. diff --git a/README_CN.md b/README_CN.md index e05395fd..cee253b8 100755 --- a/README_CN.md +++ b/README_CN.md @@ -67,8 +67,8 @@ ok github.com/lxzan/gws 17.231s - [介绍](#介绍) - [为什么选择 GWS](#为什么选择-gws) - [基准测试](#基准测试) - - [IOPS (Echo Server)](#iops-echo-server) - - [GoBench](#gobench) + - [IOPS (Echo Server)](#iops-echo-server) + - [GoBench](#gobench) - [Index](#index) - [特性](#特性) - [注意](#注意) @@ -77,13 +77,14 @@ ok github.com/lxzan/gws 17.231s - [快速上手](#快速上手) - [最佳实践](#最佳实践) - [更多用例](#更多用例) - - [KCP](#kcp) - - [代理](#代理) - - [广播](#广播) - - [写入超时](#写入超时) - - [发布/订阅](#发布订阅) + - [KCP](#kcp) + - [代理](#代理) + - [广播](#广播) + - [写入超时](#写入超时) + - [发布/订阅](#发布订阅) - [Autobahn 测试](#autobahn-测试) - [交流](#交流) +- [赞赏](#赞赏) - [致谢](#致谢) ### 特性 @@ -375,6 +376,10 @@ docker run -it --rm \ QQ +### 赞赏 + +WeChat + ### 致谢 - [crossbario/autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) diff --git a/assets/alipay.jpg b/assets/alipay.jpg new file mode 100644 index 00000000..be8245aa Binary files /dev/null and b/assets/alipay.jpg differ diff --git a/compress.go b/compress.go index 9d4b8750..431c2673 100644 --- a/compress.go +++ b/compress.go @@ -111,6 +111,14 @@ func (c *deflater) Compress(src internal.Payload, dst *bytes.Buffer, dict []byte return nil } +func compressTo(cpsWriter *flate.Writer, r io.WriterTo, w io.Writer, dict []byte) error { + cpsWriter.ResetDict(w, dict) + if _, err := r.WriteTo(cpsWriter); err != nil { + return err + } + return cpsWriter.Flush() +} + // 滑动窗口 // Sliding window type slideWindow struct { diff --git a/conn.go b/conn.go index 60842db4..92dd0dee 100644 --- a/conn.go +++ b/conn.go @@ -100,7 +100,7 @@ func (c *Conn) ReadLoop() { // Infinite loop to read messages, if an error occurs, trigger the error event and exit the loop for { if err := c.readMessage(); err != nil { - c.emitError(err) + c.emitError(true, err) break } } @@ -125,106 +125,36 @@ func (c *Conn) ReadLoop() { } } -// 获取压缩字典 -// Get compressed dictionary -func (c *Conn) getCpsDict(isBroadcast bool) []byte { - // 广播模式必须保证每一帧都是相同的内容, 所以不使用上下文接管优化压缩率 - // In broadcast mode, each frame must be the same content, so context takeover is not used to optimize compression ratio - if isBroadcast { - return nil - } - - // 如果是服务器并且服务器上下文接管启用,返回压缩字典 - // If it is a server and server context takeover is enabled, return the compression dictionary - if c.isServer && c.pd.ServerContextTakeover { - return c.cpsWindow.dict - } - - // 如果是客户端并且客户端上下文接管启用,返回压缩字典 - // If client-side and client context takeover is enabled, return the compression dictionary - if !c.isServer && c.pd.ClientContextTakeover { - return c.cpsWindow.dict - } - - return nil -} - -// 获取解压字典 -// Get decompression dictionary -func (c *Conn) getDpsDict() []byte { - // 如果是服务器并且客户端上下文接管启用,返回解压字典 - // If it is a server and client context takeover is enabled, return the decompression dictionary - if c.isServer && c.pd.ClientContextTakeover { - return c.dpsWindow.dict - } - - // 如果是客户端并且服务器上下文接管启用,返回解压字典 - // If it is a client and server context takeover is enabled, return the decompressed dictionary - if !c.isServer && c.pd.ServerContextTakeover { - return c.dpsWindow.dict - } - - return nil -} - -// UTF8编码检查 -// UTF8 encoding check -func (c *Conn) isTextValid(opcode Opcode, payload []byte) bool { - if c.config.CheckUtf8Enabled { - return internal.CheckEncoding(uint8(opcode), payload) - } - return true -} - // 检查连接是否已关闭 // Checks if the connection is closed func (c *Conn) isClosed() bool { return atomic.LoadUint32(&c.closed) == 1 } -// 关闭连接并存储错误信息 -// Closes the connection and stores the error information -func (c *Conn) close(reason []byte, err error) { - c.ev.Store(err) - _ = c.doWrite(OpcodeCloseConnection, internal.Bytes(reason)) - _ = c.conn.Close() -} - // 处理错误事件 // Handle the error event -func (c *Conn) emitError(err error) { +func (c *Conn) emitError(reading bool, err error) { if err == nil { return } - // 使用原子操作检查并设置连接的关闭状态 - // Use atomic operation to check and set the closed state of the connection if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { - var responseCode = internal.CloseNormalClosure - var responseErr error = internal.CloseNormalClosure - - // 根据错误类型设置响应代码和响应错误 - // Set response code and response error based on the error type - switch v := err.(type) { - case internal.StatusCode: - responseCode = v - case *internal.Error: - responseCode = v.Code - responseErr = v.Err - default: - responseErr = err - } - - var content = responseCode.Bytes() - content = append(content, err.Error()...) - - // 如果内容长度超过阈值,截断内容 - // If the content length exceeds the threshold, truncate the content - if len(content) > internal.ThresholdV1 { - content = content[:internal.ThresholdV1] + // 待发送的错误码和错误原因 + // Error code to be sent and cause of error + var sendCode, sendErr = internal.CloseGoingAway, error(internal.CloseGoingAway) + if reading { + switch v := err.(type) { + case internal.StatusCode: + sendCode, sendErr = v, v + case *internal.Error: + sendCode, sendErr, err = v.Code, v.Err, v.Err + default: + sendCode, sendErr = internal.CloseNormalClosure, err + } } - c.close(content, responseErr) + var reason = append(sendCode.Bytes(), sendErr.Error()...) + _ = c.writeClose(err, reason) } } @@ -257,12 +187,12 @@ func (c *Conn) emitClose(buf *bytes.Buffer) error { responseCode = internal.StatusCode(realCode) } } - if !c.isTextValid(OpcodeCloseConnection, buf.Bytes()) { + if !internal.CheckEncoding(c.config.CheckUtf8Enabled, uint8(OpcodeCloseConnection), buf.Bytes()) { responseCode = internal.CloseUnsupportedData } } if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { - c.close(responseCode.Bytes(), &CloseError{Code: realCode, Reason: buf.Bytes()}) + _ = c.writeClose(&CloseError{Code: realCode, Reason: buf.Bytes()}, responseCode.Bytes()) } return internal.CloseNormalClosure } @@ -271,7 +201,7 @@ func (c *Conn) emitClose(buf *bytes.Buffer) error { // Sets the deadline for the connection func (c *Conn) SetDeadline(t time.Time) error { err := c.conn.SetDeadline(t) - c.emitError(err) + c.emitError(false, err) return err } @@ -279,7 +209,7 @@ func (c *Conn) SetDeadline(t time.Time) error { // Sets the deadline for read operations func (c *Conn) SetReadDeadline(t time.Time) error { err := c.conn.SetReadDeadline(t) - c.emitError(err) + c.emitError(false, err) return err } @@ -287,7 +217,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error { // Sets the deadline for write operations func (c *Conn) SetWriteDeadline(t time.Time) error { err := c.conn.SetWriteDeadline(t) - c.emitError(err) + c.emitError(false, err) return err } diff --git a/conn_test.go b/conn_test.go index 680181b2..d37882cd 100644 --- a/conn_test.go +++ b/conn_test.go @@ -151,6 +151,6 @@ func TestConn_EmitError(t *testing.T) { server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption) go client.ReadLoop() err := errors.New(string(internal.AlphabetNumeric.Generate(500))) - server.emitError(err) + server.emitError(false, err) wg.Wait() } diff --git a/internal/error.go b/internal/error.go index 518fd023..35bbc9f8 100644 --- a/internal/error.go +++ b/internal/error.go @@ -3,21 +3,21 @@ package internal // closeErrorMap 将状态码映射到错误信息 // map status codes to error messages var closeErrorMap = map[StatusCode]string{ - 0: "empty code", - CloseNormalClosure: "close normal", - CloseGoingAway: "client going away", - CloseProtocolError: "protocol error", - CloseUnsupported: "unsupported data", - CloseNoStatusReceived: "no status", - CloseAbnormalClosure: "abnormal closure", - CloseUnsupportedData: "invalid payload data", - ClosePolicyViolation: "policy violation", - CloseMessageTooLarge: "message too large", - CloseMissingExtension: "mandatory extension missing", - CloseInternalServerErr: "internal server error", - CloseServiceRestart: "server restarting", - CloseTryAgainLater: "try again later", - CloseTLSHandshake: "TLS handshake error", + 0: "empty code", + CloseNormalClosure: "close normal", + CloseGoingAway: "client going away", + CloseProtocolError: "protocol error", + CloseUnsupported: "unsupported data", + CloseNoStatusReceived: "no status", + CloseAbnormalClosure: "abnormal closure", + CloseUnsupportedData: "invalid payload data", + ClosePolicyViolation: "policy violation", + CloseMessageTooLarge: "message too large", + CloseMissingExtension: "mandatory extension missing", + CloseInternalErr: "internal error", + CloseServiceRestart: "server restarting", + CloseTryAgainLater: "try again later", + CloseTLSHandshake: "TLS handshake error", } // StatusCode WebSocket错误码 @@ -55,8 +55,8 @@ const ( // CloseMissingExtension 客户端期望服务器商定一个或多个拓展, 但服务器没有处理, 因此客户端断开连接. CloseMissingExtension StatusCode = 1010 - // CloseInternalServerErr 客户端由于遇到没有预料的情况阻止其完成请求, 因此服务端断开连接. - CloseInternalServerErr StatusCode = 1011 + // CloseInternalErr 客户端由于遇到没有预料的情况阻止其完成请求, 因此服务端断开连接. + CloseInternalErr StatusCode = 1011 // CloseServiceRestart 服务器由于重启而断开连接. [Ref] CloseServiceRestart StatusCode = 1012 diff --git a/internal/io.go b/internal/io.go index 32aefa7d..45674176 100644 --- a/internal/io.go +++ b/internal/io.go @@ -21,13 +21,11 @@ func WriteN(writer io.Writer, content []byte) error { // CheckEncoding 检查 payload 的编码是否有效 // checks if the encoding of the payload is valid -func CheckEncoding(opcode uint8, payload []byte) bool { - switch opcode { - case 1, 8: +func CheckEncoding(enabled bool, opcode uint8, payload []byte) bool { + if enabled && (opcode == 1 || opcode == 8) { return utf8.Valid(payload) - default: - return true } + return true } type Payload interface { @@ -39,11 +37,9 @@ type Payload interface { type Buffers [][]byte func (b Buffers) CheckEncoding(enabled bool, opcode uint8) bool { - if enabled { - for i, _ := range b { - if !CheckEncoding(opcode, b[i]) { - return false - } + for i, _ := range b { + if !CheckEncoding(enabled, opcode, b[i]) { + return false } } return true @@ -73,10 +69,7 @@ func (b Buffers) WriteTo(w io.Writer) (int64, error) { type Bytes []byte func (b Bytes) CheckEncoding(enabled bool, opcode uint8) bool { - if enabled { - return CheckEncoding(opcode, b) - } - return true + return CheckEncoding(enabled, opcode, b) } func (b Bytes) Len() int { diff --git a/reader.go b/reader.go index 7235ef19..ea2037fe 100644 --- a/reader.go +++ b/reader.go @@ -161,13 +161,13 @@ func (c *Conn) dispatch(msg *Message) error { // Emit onmessage event func (c *Conn) emitMessage(msg *Message) (err error) { if msg.compressed { - msg.Data, err = c.deflater.Decompress(msg.Data, c.getDpsDict()) + msg.Data, err = c.deflater.Decompress(msg.Data, c.dpsWindow.dict) if err != nil { - return internal.NewError(internal.CloseInternalServerErr, err) + return internal.NewError(internal.CloseInternalErr, err) } _, _ = c.dpsWindow.Write(msg.Bytes()) } - if !c.isTextValid(msg.Opcode, msg.Bytes()) { + if !internal.CheckEncoding(c.config.CheckUtf8Enabled, uint8(msg.Opcode), msg.Bytes()) { return internal.NewError(internal.CloseUnsupportedData, ErrTextEncoding) } if c.config.ParallelEnabled { diff --git a/types.go b/types.go index 9acc0d90..6088504d 100644 --- a/types.go +++ b/types.go @@ -73,6 +73,10 @@ var ( // Text message encoding error (must be utf8) ErrTextEncoding = errors.New("invalid text encoding") + // ErrMessageTooLarge 消息体积过大 + // message is too large + ErrMessageTooLarge = errors.New("message too large") + // ErrConnClosed 连接已关闭 // Connection closed ErrConnClosed = net.ErrClosed diff --git a/upgrader.go b/upgrader.go index def7b67a..6b465094 100644 --- a/upgrader.go +++ b/upgrader.go @@ -116,7 +116,7 @@ func NewUpgrader(eventHandler Event, option *ServerOption) *Upgrader { func (c *Upgrader) hijack(w http.ResponseWriter) (net.Conn, *bufio.Reader, error) { hj, ok := w.(http.Hijacker) if !ok { - return nil, nil, internal.CloseInternalServerErr + return nil, nil, internal.CloseInternalErr } netConn, _, err := hj.Hijack() if err != nil { @@ -249,11 +249,11 @@ func (c *Upgrader) doUpgradeFromConn(netConn net.Conn, br *bufio.Reader, r *http // Compressing and decompressing dictionaries has a large memory overhead, so use lazy loading. if pd.Enabled { socket.deflater = c.deflaterPool.Select() - if c.option.PermessageDeflate.ServerContextTakeover { - socket.cpsWindow.initialize(config.cswPool, c.option.PermessageDeflate.ServerMaxWindowBits) + if pd.ServerContextTakeover { + socket.cpsWindow.initialize(config.cswPool, pd.ServerMaxWindowBits) } - if c.option.PermessageDeflate.ClientContextTakeover { - socket.dpsWindow.initialize(config.dswPool, c.option.PermessageDeflate.ClientMaxWindowBits) + if pd.ClientContextTakeover { + socket.dpsWindow.initialize(config.dswPool, pd.ClientMaxWindowBits) } } return socket, nil diff --git a/bigfile.go b/writefile.go similarity index 90% rename from bigfile.go rename to writefile.go index 468bcf40..2c7a2911 100644 --- a/bigfile.go +++ b/writefile.go @@ -57,7 +57,7 @@ func (c *Conn) splitReader(r io.Reader, f func(index int, eof bool, p []byte) er // Segmented write technology to reduce memory usage during write process func (c *Conn) WriteFile(opcode Opcode, payload io.Reader) error { err := c.doWriteFile(opcode, payload) - c.emitError(err) + c.emitError(false, err) return err } @@ -92,7 +92,8 @@ func (c *Conn) doWriteFile(opcode Opcode, payload io.Reader) error { if c.pd.Enabled { var deflater = c.getBigDeflater() var fw = &flateWriter{cb: cb} - err := deflater.Compress(payload, fw, c.getCpsDict(false), &c.cpsWindow) + var reader = &readerWrapper{r: payload, sw: &c.cpsWindow} + err := deflater.Compress(reader, fw, c.cpsWindow.dict) c.putBigDeflater(deflater) return err } else { @@ -119,11 +120,11 @@ func newBigDeflater(isServer bool, options PermessageDeflate) *bigDeflater { func (c *bigDeflater) FlateWriter() *flate.Writer { return (*flate.Writer)(c) } // Compress 压缩 -func (c *bigDeflater) Compress(src io.Reader, dst *flateWriter, dict []byte, sw *slideWindow) error { - if err := compressTo(c.FlateWriter(), &readerWrapper{r: src, sw: sw}, dst, dict); err != nil { +func (c *bigDeflater) Compress(r io.WriterTo, w *flateWriter, dict []byte) error { + if err := compressTo(c.FlateWriter(), r, w, dict); err != nil { return err } - return dst.Flush() + return w.Flush() } // 写入代理 @@ -223,11 +224,3 @@ func (c *readerWrapper) WriteTo(w io.Writer) (int64, error) { } return int64(sum), err } - -func compressTo(cpsWriter *flate.Writer, r io.WriterTo, w io.Writer, dict []byte) error { - cpsWriter.ResetDict(w, dict) - if _, err := r.WriteTo(cpsWriter); err != nil { - return err - } - return cpsWriter.Flush() -} diff --git a/writer.go b/writer.go index 7ec569b9..a534bbd9 100644 --- a/writer.go +++ b/writer.go @@ -2,7 +2,6 @@ package gws import ( "bytes" - "errors" "math" "sync" "sync/atomic" @@ -15,12 +14,29 @@ import ( // Send shutdown frame, active disconnection // If you don't have any special needs, we recommend code=1000, reason=nil // https://developer.mozilla.org/zh-CN/docs/Web/API/CloseEvent#status_codes -func (c *Conn) WriteClose(code uint16, reason []byte) { - var err = internal.NewError(internal.StatusCode(code), errEmpty) - if len(reason) > 0 { - err.Err = errors.New(string(reason)) +func (c *Conn) WriteClose(code uint16, reason []byte) error { + if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + var buf = binaryPool.Get(128) + code = internal.SelectValue(code < 1000, 1000, code) + buf.Write(internal.StatusCode(code).Bytes()) + buf.Write(reason) + err := c.writeClose(internal.StatusCode(code), buf.Bytes()) + binaryPool.Put(buf) + return err } - c.emitError(err) + return ErrConnClosed +} + +// 关闭连接并存储错误信息 +// Closes the connection and stores the error information +func (c *Conn) writeClose(ev error, reason []byte) error { + if len(reason) > internal.ThresholdV1 { + reason = reason[:internal.ThresholdV1] + } + c.ev.Store(ev) + err := c.doWrite(OpcodeCloseConnection, internal.Bytes(reason)) + _ = c.conn.Close() + return err } // WritePing @@ -49,7 +65,7 @@ func (c *Conn) WriteString(s string) error { // Writes text/binary messages, text messages should be encoded in UTF8. func (c *Conn) WriteMessage(opcode Opcode, payload []byte) error { err := c.doWrite(opcode, internal.Bytes(payload)) - c.emitError(err) + c.emitError(false, err) return err } @@ -59,7 +75,7 @@ func (c *Conn) WriteMessage(opcode Opcode, payload []byte) error { // Write messages to the task queue asynchronously and non-blockingly, // allowing payload memory to be recycled only after receiving the callback func (c *Conn) WriteAsync(opcode Opcode, payload []byte, callback func(error)) { - c.writeQueue.Push(func() { + c.Async(func() { if err := c.WriteMessage(opcode, payload); callback != nil { callback(err) } @@ -71,14 +87,14 @@ func (c *Conn) WriteAsync(opcode Opcode, payload []byte, callback func(error)) { // Writev is similar to WriteMessage, except that you can write multiple slices at once. func (c *Conn) Writev(opcode Opcode, payloads ...[]byte) error { var err = c.doWrite(opcode, internal.Buffers(payloads)) - c.emitError(err) + c.emitError(false, err) return err } // WritevAsync 类似 WriteAsync, 区别是可以一次写入多个切片 // It's similar to WriteAsync, except that you can write multiple slices at once. func (c *Conn) WritevAsync(opcode Opcode, payloads [][]byte, callback func(error)) { - c.writeQueue.Push(func() { + c.Async(func() { if err := c.Writev(opcode, payloads...); callback != nil { callback(err) } @@ -146,20 +162,19 @@ type frameConfig struct { // 生成帧数据 // Generates the frame data func (c *Conn) genFrame(opcode Opcode, payload internal.Payload, cfg frameConfig) (*bytes.Buffer, error) { + var n = payload.Len() if opcode == OpcodeText && !payload.CheckEncoding(cfg.checkEncoding, uint8(opcode)) { - return nil, internal.NewError(internal.CloseUnsupportedData, ErrTextEncoding) + return nil, ErrTextEncoding } - - var n = payload.Len() if n > c.config.WriteMaxPayloadSize { - return nil, internal.CloseMessageTooLarge + return nil, ErrMessageTooLarge } var buf = binaryPool.Get(n + frameHeaderSize) buf.Write(framePadding[0:]) if cfg.compress && opcode.isDataFrame() && n >= c.pd.Threshold { - return c.compressData(buf, opcode, cfg.fin, payload, cfg.broadcast) + return c.compressData(opcode, payload, buf, cfg) } var header = frameHeader{} @@ -177,15 +192,18 @@ func (c *Conn) genFrame(opcode Opcode, payload internal.Payload, cfg frameConfig // 压缩数据并生成帧 // Compresses the data and generates the frame -func (c *Conn) compressData(buf *bytes.Buffer, opcode Opcode, fin bool, payload internal.Payload, isBroadcast bool) (*bytes.Buffer, error) { - err := c.deflater.Compress(payload, buf, c.getCpsDict(isBroadcast)) - if err != nil { +func (c *Conn) compressData(opcode Opcode, payload internal.Payload, buf *bytes.Buffer, cfg frameConfig) (*bytes.Buffer, error) { + // 广播模式必须保证每一帧都是相同的内容, 所以不能使用字典优化压缩率 + // Broadcast mode must ensure that every frame is the same, so you can't use a dictionary to optimize the compression rate. + var dict = internal.SelectValue(cfg.broadcast, nil, c.cpsWindow.dict) + if err := c.deflater.Compress(payload, buf, dict); err != nil { return nil, err } + var contents = buf.Bytes() var payloadSize = buf.Len() - frameHeaderSize var header = frameHeader{} - headerLength, maskBytes := header.GenerateHeader(c.isServer, fin, true, opcode, payloadSize) + headerLength, maskBytes := header.GenerateHeader(c.isServer, cfg.fin, true, opcode, payloadSize) if !c.isServer { internal.MaskXOR(contents[frameHeaderSize:], maskBytes) } @@ -232,7 +250,7 @@ func (c *Broadcaster) writeFrame(socket *Conn, frame *bytes.Buffer) error { } socket.mu.Lock() var err = internal.WriteN(socket.conn, frame.Bytes()) - socket.cpsWindow.Write(c.payload) + _, _ = socket.cpsWindow.Write(c.payload) socket.mu.Unlock() return err } @@ -259,7 +277,7 @@ func (c *Broadcaster) Broadcast(socket *Conn) error { atomic.AddInt64(&c.state, 1) socket.writeQueue.Push(func() { var err = c.writeFrame(socket, msg.frame) - socket.emitError(err) + socket.emitError(false, err) if atomic.AddInt64(&c.state, -1) == 0 { c.doClose() } diff --git a/writer_test.go b/writer_test.go index 66e45d33..46530d0d 100644 --- a/writer_test.go +++ b/writer_test.go @@ -21,7 +21,7 @@ func testWrite(c *Conn, fin bool, opcode Opcode, payload []byte) error { var buf = bytes.NewBufferString("") err := c.deflater.Compress(internal.Bytes(payload), buf, c.cpsWindow.dict) if err != nil { - return internal.NewError(internal.CloseInternalServerErr, err) + return internal.NewError(internal.CloseInternalErr, err) } payload = buf.Bytes() } @@ -114,28 +114,90 @@ func TestWriteBigMessage(t *testing.T) { func TestWriteClose(t *testing.T) { var as = assert.New(t) - var serverHandler = new(webSocketMocker) - var clientHandler = new(webSocketMocker) - var serverOption = &ServerOption{} - var clientOption = &ClientOption{} - - var wg = sync.WaitGroup{} - wg.Add(1) - serverHandler.onClose = func(socket *Conn, err error) { - as.Error(err) - wg.Done() - } - server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption) - go server.ReadLoop() - go client.ReadLoop() - server.WriteClose(1000, []byte("goodbye")) - wg.Wait() t.Run("", func(t *testing.T) { + var serverHandler = new(webSocketMocker) + var clientHandler = new(webSocketMocker) + var serverOption = &ServerOption{} + var clientOption = &ClientOption{} + + var wg = sync.WaitGroup{} + wg.Add(1) + serverHandler.onClose = func(socket *Conn, err error) { + as.Error(err) + wg.Done() + } + server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption) + go server.ReadLoop() + go client.ReadLoop() + server.WriteClose(1000, []byte("goodbye")) + wg.Wait() var socket = &Conn{closed: 1, config: server.config} socket.WriteMessage(OpcodeText, nil) socket.WriteAsync(OpcodeText, nil, nil) }) + + t.Run("", func(t *testing.T) { + var pd = PermessageDeflate{ + Enabled: true, + } + var serverHandler = new(webSocketMocker) + var clientHandler = new(webSocketMocker) + var serverOption = &ServerOption{ + PermessageDeflate: pd, + } + var clientOption = &ClientOption{ + PermessageDeflate: pd, + } + var wg = &sync.WaitGroup{} + wg.Add(1) + + serverHandler.onClose = func(socket *Conn, err error) { + if v, ok := err.(*CloseError); ok && string(v.Reason) == "goodbye" { + wg.Done() + } + } + + server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption) + go server.ReadLoop() + go client.ReadLoop() + + var err = client.WriteClose(1006, []byte("goodbye")) + assert.NoError(t, err) + err = client.WriteClose(1006, []byte("goodbye")) + assert.True(t, errors.Is(err, ErrConnClosed)) + wg.Wait() + }) + + t.Run("", func(t *testing.T) { + var pd = PermessageDeflate{ + Enabled: true, + } + var serverHandler = new(webSocketMocker) + var clientHandler = new(webSocketMocker) + var serverOption = &ServerOption{ + PermessageDeflate: pd, + } + var clientOption = &ClientOption{ + PermessageDeflate: pd, + } + var wg = &sync.WaitGroup{} + wg.Add(1) + + serverHandler.onClose = func(socket *Conn, err error) { + if v, ok := err.(*CloseError); ok && len(v.Reason) == 123 { + wg.Done() + } + } + + server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption) + go server.ReadLoop() + go client.ReadLoop() + + var err = client.WriteClose(1006, internal.AlphabetNumeric.Generate(1024)) + assert.NoError(t, err) + wg.Wait() + }) } func TestConn_WriteAsyncError(t *testing.T) { @@ -755,7 +817,8 @@ func TestConn_WriteFile(t *testing.T) { var fw = &flateWriter{cb: func(index int, eof bool, p []byte) error { return nil }} - err := deflater.Compress(new(writerTo), fw, nil, new(slideWindow)) + var reader = &readerWrapper{r: new(writerTo), sw: new(slideWindow)} + err := deflater.Compress(reader, fw, nil) assert.Error(t, err) }) @@ -768,8 +831,8 @@ func TestConn_WriteFile(t *testing.T) { var fw = &flateWriter{cb: func(index int, eof bool, p []byte) error { return errors.New("2") }} - src := bytes.NewBufferString("hello") - err := deflater.Compress(src, fw, nil, new(slideWindow)) + var reader = &readerWrapper{r: new(writerTo), sw: new(slideWindow)} + err := deflater.Compress(reader, fw, nil) assert.Error(t, err) })