diff --git a/README.md b/README.md
index 45f36267..f34e5d0c 100755
--- a/README.md
+++ b/README.md
@@ -77,8 +77,8 @@ ok github.com/lxzan/gws 17.231s
- [Introduction](#introduction)
- [Why GWS](#why-gws)
- [Benchmark](#benchmark)
- - [IOPS (Echo Server)](#iops-echo-server)
- - [GoBench](#gobench)
+ - [IOPS (Echo Server)](#iops-echo-server)
+ - [GoBench](#gobench)
- [Index](#index)
- [Feature](#feature)
- [Attention](#attention)
@@ -87,13 +87,14 @@ ok github.com/lxzan/gws 17.231s
- [Quick Start](#quick-start)
- [Best Practice](#best-practice)
- [More Examples](#more-examples)
- - [KCP](#kcp)
- - [Proxy](#proxy)
- - [Broadcast](#broadcast)
- - [WriteWithTimeout](#writewithtimeout)
- - [Pub / Sub](#pub--sub)
+ - [KCP](#kcp)
+ - [Proxy](#proxy)
+ - [Broadcast](#broadcast)
+ - [WriteWithTimeout](#writewithtimeout)
+ - [Pub / Sub](#pub--sub)
- [Autobahn Test](#autobahn-test)
- [Communication](#communication)
+- [Buy me a coffee](#buy-me-a-coffee)
- [Acknowledgments](#acknowledgments)
### Feature
@@ -388,6 +389,10 @@ docker run -it --rm \
+### Buy me a coffee
+
+
+
### Acknowledgments
The following project had particular influence on gws's design.
diff --git a/README_CN.md b/README_CN.md
index e05395fd..cee253b8 100755
--- a/README_CN.md
+++ b/README_CN.md
@@ -67,8 +67,8 @@ ok github.com/lxzan/gws 17.231s
- [介绍](#介绍)
- [为什么选择 GWS](#为什么选择-gws)
- [基准测试](#基准测试)
- - [IOPS (Echo Server)](#iops-echo-server)
- - [GoBench](#gobench)
+ - [IOPS (Echo Server)](#iops-echo-server)
+ - [GoBench](#gobench)
- [Index](#index)
- [特性](#特性)
- [注意](#注意)
@@ -77,13 +77,14 @@ ok github.com/lxzan/gws 17.231s
- [快速上手](#快速上手)
- [最佳实践](#最佳实践)
- [更多用例](#更多用例)
- - [KCP](#kcp)
- - [代理](#代理)
- - [广播](#广播)
- - [写入超时](#写入超时)
- - [发布/订阅](#发布订阅)
+ - [KCP](#kcp)
+ - [代理](#代理)
+ - [广播](#广播)
+ - [写入超时](#写入超时)
+ - [发布/订阅](#发布订阅)
- [Autobahn 测试](#autobahn-测试)
- [交流](#交流)
+- [赞赏](#赞赏)
- [致谢](#致谢)
### 特性
@@ -375,6 +376,10 @@ docker run -it --rm \
+### 赞赏
+
+
+
### 致谢
- [crossbario/autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite)
diff --git a/assets/alipay.jpg b/assets/alipay.jpg
new file mode 100644
index 00000000..be8245aa
Binary files /dev/null and b/assets/alipay.jpg differ
diff --git a/compress.go b/compress.go
index 9d4b8750..431c2673 100644
--- a/compress.go
+++ b/compress.go
@@ -111,6 +111,14 @@ func (c *deflater) Compress(src internal.Payload, dst *bytes.Buffer, dict []byte
return nil
}
+func compressTo(cpsWriter *flate.Writer, r io.WriterTo, w io.Writer, dict []byte) error {
+ cpsWriter.ResetDict(w, dict)
+ if _, err := r.WriteTo(cpsWriter); err != nil {
+ return err
+ }
+ return cpsWriter.Flush()
+}
+
// 滑动窗口
// Sliding window
type slideWindow struct {
diff --git a/conn.go b/conn.go
index 60842db4..92dd0dee 100644
--- a/conn.go
+++ b/conn.go
@@ -100,7 +100,7 @@ func (c *Conn) ReadLoop() {
// Infinite loop to read messages, if an error occurs, trigger the error event and exit the loop
for {
if err := c.readMessage(); err != nil {
- c.emitError(err)
+ c.emitError(true, err)
break
}
}
@@ -125,106 +125,36 @@ func (c *Conn) ReadLoop() {
}
}
-// 获取压缩字典
-// Get compressed dictionary
-func (c *Conn) getCpsDict(isBroadcast bool) []byte {
- // 广播模式必须保证每一帧都是相同的内容, 所以不使用上下文接管优化压缩率
- // In broadcast mode, each frame must be the same content, so context takeover is not used to optimize compression ratio
- if isBroadcast {
- return nil
- }
-
- // 如果是服务器并且服务器上下文接管启用,返回压缩字典
- // If it is a server and server context takeover is enabled, return the compression dictionary
- if c.isServer && c.pd.ServerContextTakeover {
- return c.cpsWindow.dict
- }
-
- // 如果是客户端并且客户端上下文接管启用,返回压缩字典
- // If client-side and client context takeover is enabled, return the compression dictionary
- if !c.isServer && c.pd.ClientContextTakeover {
- return c.cpsWindow.dict
- }
-
- return nil
-}
-
-// 获取解压字典
-// Get decompression dictionary
-func (c *Conn) getDpsDict() []byte {
- // 如果是服务器并且客户端上下文接管启用,返回解压字典
- // If it is a server and client context takeover is enabled, return the decompression dictionary
- if c.isServer && c.pd.ClientContextTakeover {
- return c.dpsWindow.dict
- }
-
- // 如果是客户端并且服务器上下文接管启用,返回解压字典
- // If it is a client and server context takeover is enabled, return the decompressed dictionary
- if !c.isServer && c.pd.ServerContextTakeover {
- return c.dpsWindow.dict
- }
-
- return nil
-}
-
-// UTF8编码检查
-// UTF8 encoding check
-func (c *Conn) isTextValid(opcode Opcode, payload []byte) bool {
- if c.config.CheckUtf8Enabled {
- return internal.CheckEncoding(uint8(opcode), payload)
- }
- return true
-}
-
// 检查连接是否已关闭
// Checks if the connection is closed
func (c *Conn) isClosed() bool {
return atomic.LoadUint32(&c.closed) == 1
}
-// 关闭连接并存储错误信息
-// Closes the connection and stores the error information
-func (c *Conn) close(reason []byte, err error) {
- c.ev.Store(err)
- _ = c.doWrite(OpcodeCloseConnection, internal.Bytes(reason))
- _ = c.conn.Close()
-}
-
// 处理错误事件
// Handle the error event
-func (c *Conn) emitError(err error) {
+func (c *Conn) emitError(reading bool, err error) {
if err == nil {
return
}
- // 使用原子操作检查并设置连接的关闭状态
- // Use atomic operation to check and set the closed state of the connection
if atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
- var responseCode = internal.CloseNormalClosure
- var responseErr error = internal.CloseNormalClosure
-
- // 根据错误类型设置响应代码和响应错误
- // Set response code and response error based on the error type
- switch v := err.(type) {
- case internal.StatusCode:
- responseCode = v
- case *internal.Error:
- responseCode = v.Code
- responseErr = v.Err
- default:
- responseErr = err
- }
-
- var content = responseCode.Bytes()
- content = append(content, err.Error()...)
-
- // 如果内容长度超过阈值,截断内容
- // If the content length exceeds the threshold, truncate the content
- if len(content) > internal.ThresholdV1 {
- content = content[:internal.ThresholdV1]
+ // 待发送的错误码和错误原因
+ // Error code to be sent and cause of error
+ var sendCode, sendErr = internal.CloseGoingAway, error(internal.CloseGoingAway)
+ if reading {
+ switch v := err.(type) {
+ case internal.StatusCode:
+ sendCode, sendErr = v, v
+ case *internal.Error:
+ sendCode, sendErr, err = v.Code, v.Err, v.Err
+ default:
+ sendCode, sendErr = internal.CloseNormalClosure, err
+ }
}
- c.close(content, responseErr)
+ var reason = append(sendCode.Bytes(), sendErr.Error()...)
+ _ = c.writeClose(err, reason)
}
}
@@ -257,12 +187,12 @@ func (c *Conn) emitClose(buf *bytes.Buffer) error {
responseCode = internal.StatusCode(realCode)
}
}
- if !c.isTextValid(OpcodeCloseConnection, buf.Bytes()) {
+ if !internal.CheckEncoding(c.config.CheckUtf8Enabled, uint8(OpcodeCloseConnection), buf.Bytes()) {
responseCode = internal.CloseUnsupportedData
}
}
if atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
- c.close(responseCode.Bytes(), &CloseError{Code: realCode, Reason: buf.Bytes()})
+ _ = c.writeClose(&CloseError{Code: realCode, Reason: buf.Bytes()}, responseCode.Bytes())
}
return internal.CloseNormalClosure
}
@@ -271,7 +201,7 @@ func (c *Conn) emitClose(buf *bytes.Buffer) error {
// Sets the deadline for the connection
func (c *Conn) SetDeadline(t time.Time) error {
err := c.conn.SetDeadline(t)
- c.emitError(err)
+ c.emitError(false, err)
return err
}
@@ -279,7 +209,7 @@ func (c *Conn) SetDeadline(t time.Time) error {
// Sets the deadline for read operations
func (c *Conn) SetReadDeadline(t time.Time) error {
err := c.conn.SetReadDeadline(t)
- c.emitError(err)
+ c.emitError(false, err)
return err
}
@@ -287,7 +217,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
// Sets the deadline for write operations
func (c *Conn) SetWriteDeadline(t time.Time) error {
err := c.conn.SetWriteDeadline(t)
- c.emitError(err)
+ c.emitError(false, err)
return err
}
diff --git a/conn_test.go b/conn_test.go
index 680181b2..d37882cd 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -151,6 +151,6 @@ func TestConn_EmitError(t *testing.T) {
server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption)
go client.ReadLoop()
err := errors.New(string(internal.AlphabetNumeric.Generate(500)))
- server.emitError(err)
+ server.emitError(false, err)
wg.Wait()
}
diff --git a/internal/error.go b/internal/error.go
index 518fd023..35bbc9f8 100644
--- a/internal/error.go
+++ b/internal/error.go
@@ -3,21 +3,21 @@ package internal
// closeErrorMap 将状态码映射到错误信息
// map status codes to error messages
var closeErrorMap = map[StatusCode]string{
- 0: "empty code",
- CloseNormalClosure: "close normal",
- CloseGoingAway: "client going away",
- CloseProtocolError: "protocol error",
- CloseUnsupported: "unsupported data",
- CloseNoStatusReceived: "no status",
- CloseAbnormalClosure: "abnormal closure",
- CloseUnsupportedData: "invalid payload data",
- ClosePolicyViolation: "policy violation",
- CloseMessageTooLarge: "message too large",
- CloseMissingExtension: "mandatory extension missing",
- CloseInternalServerErr: "internal server error",
- CloseServiceRestart: "server restarting",
- CloseTryAgainLater: "try again later",
- CloseTLSHandshake: "TLS handshake error",
+ 0: "empty code",
+ CloseNormalClosure: "close normal",
+ CloseGoingAway: "client going away",
+ CloseProtocolError: "protocol error",
+ CloseUnsupported: "unsupported data",
+ CloseNoStatusReceived: "no status",
+ CloseAbnormalClosure: "abnormal closure",
+ CloseUnsupportedData: "invalid payload data",
+ ClosePolicyViolation: "policy violation",
+ CloseMessageTooLarge: "message too large",
+ CloseMissingExtension: "mandatory extension missing",
+ CloseInternalErr: "internal error",
+ CloseServiceRestart: "server restarting",
+ CloseTryAgainLater: "try again later",
+ CloseTLSHandshake: "TLS handshake error",
}
// StatusCode WebSocket错误码
@@ -55,8 +55,8 @@ const (
// CloseMissingExtension 客户端期望服务器商定一个或多个拓展, 但服务器没有处理, 因此客户端断开连接.
CloseMissingExtension StatusCode = 1010
- // CloseInternalServerErr 客户端由于遇到没有预料的情况阻止其完成请求, 因此服务端断开连接.
- CloseInternalServerErr StatusCode = 1011
+ // CloseInternalErr 客户端由于遇到没有预料的情况阻止其完成请求, 因此服务端断开连接.
+ CloseInternalErr StatusCode = 1011
// CloseServiceRestart 服务器由于重启而断开连接. [Ref]
CloseServiceRestart StatusCode = 1012
diff --git a/internal/io.go b/internal/io.go
index 32aefa7d..45674176 100644
--- a/internal/io.go
+++ b/internal/io.go
@@ -21,13 +21,11 @@ func WriteN(writer io.Writer, content []byte) error {
// CheckEncoding 检查 payload 的编码是否有效
// checks if the encoding of the payload is valid
-func CheckEncoding(opcode uint8, payload []byte) bool {
- switch opcode {
- case 1, 8:
+func CheckEncoding(enabled bool, opcode uint8, payload []byte) bool {
+ if enabled && (opcode == 1 || opcode == 8) {
return utf8.Valid(payload)
- default:
- return true
}
+ return true
}
type Payload interface {
@@ -39,11 +37,9 @@ type Payload interface {
type Buffers [][]byte
func (b Buffers) CheckEncoding(enabled bool, opcode uint8) bool {
- if enabled {
- for i, _ := range b {
- if !CheckEncoding(opcode, b[i]) {
- return false
- }
+ for i, _ := range b {
+ if !CheckEncoding(enabled, opcode, b[i]) {
+ return false
}
}
return true
@@ -73,10 +69,7 @@ func (b Buffers) WriteTo(w io.Writer) (int64, error) {
type Bytes []byte
func (b Bytes) CheckEncoding(enabled bool, opcode uint8) bool {
- if enabled {
- return CheckEncoding(opcode, b)
- }
- return true
+ return CheckEncoding(enabled, opcode, b)
}
func (b Bytes) Len() int {
diff --git a/reader.go b/reader.go
index 7235ef19..ea2037fe 100644
--- a/reader.go
+++ b/reader.go
@@ -161,13 +161,13 @@ func (c *Conn) dispatch(msg *Message) error {
// Emit onmessage event
func (c *Conn) emitMessage(msg *Message) (err error) {
if msg.compressed {
- msg.Data, err = c.deflater.Decompress(msg.Data, c.getDpsDict())
+ msg.Data, err = c.deflater.Decompress(msg.Data, c.dpsWindow.dict)
if err != nil {
- return internal.NewError(internal.CloseInternalServerErr, err)
+ return internal.NewError(internal.CloseInternalErr, err)
}
_, _ = c.dpsWindow.Write(msg.Bytes())
}
- if !c.isTextValid(msg.Opcode, msg.Bytes()) {
+ if !internal.CheckEncoding(c.config.CheckUtf8Enabled, uint8(msg.Opcode), msg.Bytes()) {
return internal.NewError(internal.CloseUnsupportedData, ErrTextEncoding)
}
if c.config.ParallelEnabled {
diff --git a/types.go b/types.go
index 9acc0d90..6088504d 100644
--- a/types.go
+++ b/types.go
@@ -73,6 +73,10 @@ var (
// Text message encoding error (must be utf8)
ErrTextEncoding = errors.New("invalid text encoding")
+ // ErrMessageTooLarge 消息体积过大
+ // message is too large
+ ErrMessageTooLarge = errors.New("message too large")
+
// ErrConnClosed 连接已关闭
// Connection closed
ErrConnClosed = net.ErrClosed
diff --git a/upgrader.go b/upgrader.go
index def7b67a..6b465094 100644
--- a/upgrader.go
+++ b/upgrader.go
@@ -116,7 +116,7 @@ func NewUpgrader(eventHandler Event, option *ServerOption) *Upgrader {
func (c *Upgrader) hijack(w http.ResponseWriter) (net.Conn, *bufio.Reader, error) {
hj, ok := w.(http.Hijacker)
if !ok {
- return nil, nil, internal.CloseInternalServerErr
+ return nil, nil, internal.CloseInternalErr
}
netConn, _, err := hj.Hijack()
if err != nil {
@@ -249,11 +249,11 @@ func (c *Upgrader) doUpgradeFromConn(netConn net.Conn, br *bufio.Reader, r *http
// Compressing and decompressing dictionaries has a large memory overhead, so use lazy loading.
if pd.Enabled {
socket.deflater = c.deflaterPool.Select()
- if c.option.PermessageDeflate.ServerContextTakeover {
- socket.cpsWindow.initialize(config.cswPool, c.option.PermessageDeflate.ServerMaxWindowBits)
+ if pd.ServerContextTakeover {
+ socket.cpsWindow.initialize(config.cswPool, pd.ServerMaxWindowBits)
}
- if c.option.PermessageDeflate.ClientContextTakeover {
- socket.dpsWindow.initialize(config.dswPool, c.option.PermessageDeflate.ClientMaxWindowBits)
+ if pd.ClientContextTakeover {
+ socket.dpsWindow.initialize(config.dswPool, pd.ClientMaxWindowBits)
}
}
return socket, nil
diff --git a/bigfile.go b/writefile.go
similarity index 90%
rename from bigfile.go
rename to writefile.go
index 468bcf40..2c7a2911 100644
--- a/bigfile.go
+++ b/writefile.go
@@ -57,7 +57,7 @@ func (c *Conn) splitReader(r io.Reader, f func(index int, eof bool, p []byte) er
// Segmented write technology to reduce memory usage during write process
func (c *Conn) WriteFile(opcode Opcode, payload io.Reader) error {
err := c.doWriteFile(opcode, payload)
- c.emitError(err)
+ c.emitError(false, err)
return err
}
@@ -92,7 +92,8 @@ func (c *Conn) doWriteFile(opcode Opcode, payload io.Reader) error {
if c.pd.Enabled {
var deflater = c.getBigDeflater()
var fw = &flateWriter{cb: cb}
- err := deflater.Compress(payload, fw, c.getCpsDict(false), &c.cpsWindow)
+ var reader = &readerWrapper{r: payload, sw: &c.cpsWindow}
+ err := deflater.Compress(reader, fw, c.cpsWindow.dict)
c.putBigDeflater(deflater)
return err
} else {
@@ -119,11 +120,11 @@ func newBigDeflater(isServer bool, options PermessageDeflate) *bigDeflater {
func (c *bigDeflater) FlateWriter() *flate.Writer { return (*flate.Writer)(c) }
// Compress 压缩
-func (c *bigDeflater) Compress(src io.Reader, dst *flateWriter, dict []byte, sw *slideWindow) error {
- if err := compressTo(c.FlateWriter(), &readerWrapper{r: src, sw: sw}, dst, dict); err != nil {
+func (c *bigDeflater) Compress(r io.WriterTo, w *flateWriter, dict []byte) error {
+ if err := compressTo(c.FlateWriter(), r, w, dict); err != nil {
return err
}
- return dst.Flush()
+ return w.Flush()
}
// 写入代理
@@ -223,11 +224,3 @@ func (c *readerWrapper) WriteTo(w io.Writer) (int64, error) {
}
return int64(sum), err
}
-
-func compressTo(cpsWriter *flate.Writer, r io.WriterTo, w io.Writer, dict []byte) error {
- cpsWriter.ResetDict(w, dict)
- if _, err := r.WriteTo(cpsWriter); err != nil {
- return err
- }
- return cpsWriter.Flush()
-}
diff --git a/writer.go b/writer.go
index 7ec569b9..a534bbd9 100644
--- a/writer.go
+++ b/writer.go
@@ -2,7 +2,6 @@ package gws
import (
"bytes"
- "errors"
"math"
"sync"
"sync/atomic"
@@ -15,12 +14,29 @@ import (
// Send shutdown frame, active disconnection
// If you don't have any special needs, we recommend code=1000, reason=nil
// https://developer.mozilla.org/zh-CN/docs/Web/API/CloseEvent#status_codes
-func (c *Conn) WriteClose(code uint16, reason []byte) {
- var err = internal.NewError(internal.StatusCode(code), errEmpty)
- if len(reason) > 0 {
- err.Err = errors.New(string(reason))
+func (c *Conn) WriteClose(code uint16, reason []byte) error {
+ if atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
+ var buf = binaryPool.Get(128)
+ code = internal.SelectValue(code < 1000, 1000, code)
+ buf.Write(internal.StatusCode(code).Bytes())
+ buf.Write(reason)
+ err := c.writeClose(internal.StatusCode(code), buf.Bytes())
+ binaryPool.Put(buf)
+ return err
}
- c.emitError(err)
+ return ErrConnClosed
+}
+
+// 关闭连接并存储错误信息
+// Closes the connection and stores the error information
+func (c *Conn) writeClose(ev error, reason []byte) error {
+ if len(reason) > internal.ThresholdV1 {
+ reason = reason[:internal.ThresholdV1]
+ }
+ c.ev.Store(ev)
+ err := c.doWrite(OpcodeCloseConnection, internal.Bytes(reason))
+ _ = c.conn.Close()
+ return err
}
// WritePing
@@ -49,7 +65,7 @@ func (c *Conn) WriteString(s string) error {
// Writes text/binary messages, text messages should be encoded in UTF8.
func (c *Conn) WriteMessage(opcode Opcode, payload []byte) error {
err := c.doWrite(opcode, internal.Bytes(payload))
- c.emitError(err)
+ c.emitError(false, err)
return err
}
@@ -59,7 +75,7 @@ func (c *Conn) WriteMessage(opcode Opcode, payload []byte) error {
// Write messages to the task queue asynchronously and non-blockingly,
// allowing payload memory to be recycled only after receiving the callback
func (c *Conn) WriteAsync(opcode Opcode, payload []byte, callback func(error)) {
- c.writeQueue.Push(func() {
+ c.Async(func() {
if err := c.WriteMessage(opcode, payload); callback != nil {
callback(err)
}
@@ -71,14 +87,14 @@ func (c *Conn) WriteAsync(opcode Opcode, payload []byte, callback func(error)) {
// Writev is similar to WriteMessage, except that you can write multiple slices at once.
func (c *Conn) Writev(opcode Opcode, payloads ...[]byte) error {
var err = c.doWrite(opcode, internal.Buffers(payloads))
- c.emitError(err)
+ c.emitError(false, err)
return err
}
// WritevAsync 类似 WriteAsync, 区别是可以一次写入多个切片
// It's similar to WriteAsync, except that you can write multiple slices at once.
func (c *Conn) WritevAsync(opcode Opcode, payloads [][]byte, callback func(error)) {
- c.writeQueue.Push(func() {
+ c.Async(func() {
if err := c.Writev(opcode, payloads...); callback != nil {
callback(err)
}
@@ -146,20 +162,19 @@ type frameConfig struct {
// 生成帧数据
// Generates the frame data
func (c *Conn) genFrame(opcode Opcode, payload internal.Payload, cfg frameConfig) (*bytes.Buffer, error) {
+ var n = payload.Len()
if opcode == OpcodeText && !payload.CheckEncoding(cfg.checkEncoding, uint8(opcode)) {
- return nil, internal.NewError(internal.CloseUnsupportedData, ErrTextEncoding)
+ return nil, ErrTextEncoding
}
-
- var n = payload.Len()
if n > c.config.WriteMaxPayloadSize {
- return nil, internal.CloseMessageTooLarge
+ return nil, ErrMessageTooLarge
}
var buf = binaryPool.Get(n + frameHeaderSize)
buf.Write(framePadding[0:])
if cfg.compress && opcode.isDataFrame() && n >= c.pd.Threshold {
- return c.compressData(buf, opcode, cfg.fin, payload, cfg.broadcast)
+ return c.compressData(opcode, payload, buf, cfg)
}
var header = frameHeader{}
@@ -177,15 +192,18 @@ func (c *Conn) genFrame(opcode Opcode, payload internal.Payload, cfg frameConfig
// 压缩数据并生成帧
// Compresses the data and generates the frame
-func (c *Conn) compressData(buf *bytes.Buffer, opcode Opcode, fin bool, payload internal.Payload, isBroadcast bool) (*bytes.Buffer, error) {
- err := c.deflater.Compress(payload, buf, c.getCpsDict(isBroadcast))
- if err != nil {
+func (c *Conn) compressData(opcode Opcode, payload internal.Payload, buf *bytes.Buffer, cfg frameConfig) (*bytes.Buffer, error) {
+ // 广播模式必须保证每一帧都是相同的内容, 所以不能使用字典优化压缩率
+ // Broadcast mode must ensure that every frame is the same, so you can't use a dictionary to optimize the compression rate.
+ var dict = internal.SelectValue(cfg.broadcast, nil, c.cpsWindow.dict)
+ if err := c.deflater.Compress(payload, buf, dict); err != nil {
return nil, err
}
+
var contents = buf.Bytes()
var payloadSize = buf.Len() - frameHeaderSize
var header = frameHeader{}
- headerLength, maskBytes := header.GenerateHeader(c.isServer, fin, true, opcode, payloadSize)
+ headerLength, maskBytes := header.GenerateHeader(c.isServer, cfg.fin, true, opcode, payloadSize)
if !c.isServer {
internal.MaskXOR(contents[frameHeaderSize:], maskBytes)
}
@@ -232,7 +250,7 @@ func (c *Broadcaster) writeFrame(socket *Conn, frame *bytes.Buffer) error {
}
socket.mu.Lock()
var err = internal.WriteN(socket.conn, frame.Bytes())
- socket.cpsWindow.Write(c.payload)
+ _, _ = socket.cpsWindow.Write(c.payload)
socket.mu.Unlock()
return err
}
@@ -259,7 +277,7 @@ func (c *Broadcaster) Broadcast(socket *Conn) error {
atomic.AddInt64(&c.state, 1)
socket.writeQueue.Push(func() {
var err = c.writeFrame(socket, msg.frame)
- socket.emitError(err)
+ socket.emitError(false, err)
if atomic.AddInt64(&c.state, -1) == 0 {
c.doClose()
}
diff --git a/writer_test.go b/writer_test.go
index 66e45d33..46530d0d 100644
--- a/writer_test.go
+++ b/writer_test.go
@@ -21,7 +21,7 @@ func testWrite(c *Conn, fin bool, opcode Opcode, payload []byte) error {
var buf = bytes.NewBufferString("")
err := c.deflater.Compress(internal.Bytes(payload), buf, c.cpsWindow.dict)
if err != nil {
- return internal.NewError(internal.CloseInternalServerErr, err)
+ return internal.NewError(internal.CloseInternalErr, err)
}
payload = buf.Bytes()
}
@@ -114,28 +114,90 @@ func TestWriteBigMessage(t *testing.T) {
func TestWriteClose(t *testing.T) {
var as = assert.New(t)
- var serverHandler = new(webSocketMocker)
- var clientHandler = new(webSocketMocker)
- var serverOption = &ServerOption{}
- var clientOption = &ClientOption{}
-
- var wg = sync.WaitGroup{}
- wg.Add(1)
- serverHandler.onClose = func(socket *Conn, err error) {
- as.Error(err)
- wg.Done()
- }
- server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption)
- go server.ReadLoop()
- go client.ReadLoop()
- server.WriteClose(1000, []byte("goodbye"))
- wg.Wait()
t.Run("", func(t *testing.T) {
+ var serverHandler = new(webSocketMocker)
+ var clientHandler = new(webSocketMocker)
+ var serverOption = &ServerOption{}
+ var clientOption = &ClientOption{}
+
+ var wg = sync.WaitGroup{}
+ wg.Add(1)
+ serverHandler.onClose = func(socket *Conn, err error) {
+ as.Error(err)
+ wg.Done()
+ }
+ server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption)
+ go server.ReadLoop()
+ go client.ReadLoop()
+ server.WriteClose(1000, []byte("goodbye"))
+ wg.Wait()
var socket = &Conn{closed: 1, config: server.config}
socket.WriteMessage(OpcodeText, nil)
socket.WriteAsync(OpcodeText, nil, nil)
})
+
+ t.Run("", func(t *testing.T) {
+ var pd = PermessageDeflate{
+ Enabled: true,
+ }
+ var serverHandler = new(webSocketMocker)
+ var clientHandler = new(webSocketMocker)
+ var serverOption = &ServerOption{
+ PermessageDeflate: pd,
+ }
+ var clientOption = &ClientOption{
+ PermessageDeflate: pd,
+ }
+ var wg = &sync.WaitGroup{}
+ wg.Add(1)
+
+ serverHandler.onClose = func(socket *Conn, err error) {
+ if v, ok := err.(*CloseError); ok && string(v.Reason) == "goodbye" {
+ wg.Done()
+ }
+ }
+
+ server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption)
+ go server.ReadLoop()
+ go client.ReadLoop()
+
+ var err = client.WriteClose(1006, []byte("goodbye"))
+ assert.NoError(t, err)
+ err = client.WriteClose(1006, []byte("goodbye"))
+ assert.True(t, errors.Is(err, ErrConnClosed))
+ wg.Wait()
+ })
+
+ t.Run("", func(t *testing.T) {
+ var pd = PermessageDeflate{
+ Enabled: true,
+ }
+ var serverHandler = new(webSocketMocker)
+ var clientHandler = new(webSocketMocker)
+ var serverOption = &ServerOption{
+ PermessageDeflate: pd,
+ }
+ var clientOption = &ClientOption{
+ PermessageDeflate: pd,
+ }
+ var wg = &sync.WaitGroup{}
+ wg.Add(1)
+
+ serverHandler.onClose = func(socket *Conn, err error) {
+ if v, ok := err.(*CloseError); ok && len(v.Reason) == 123 {
+ wg.Done()
+ }
+ }
+
+ server, client := newPeer(serverHandler, serverOption, clientHandler, clientOption)
+ go server.ReadLoop()
+ go client.ReadLoop()
+
+ var err = client.WriteClose(1006, internal.AlphabetNumeric.Generate(1024))
+ assert.NoError(t, err)
+ wg.Wait()
+ })
}
func TestConn_WriteAsyncError(t *testing.T) {
@@ -755,7 +817,8 @@ func TestConn_WriteFile(t *testing.T) {
var fw = &flateWriter{cb: func(index int, eof bool, p []byte) error {
return nil
}}
- err := deflater.Compress(new(writerTo), fw, nil, new(slideWindow))
+ var reader = &readerWrapper{r: new(writerTo), sw: new(slideWindow)}
+ err := deflater.Compress(reader, fw, nil)
assert.Error(t, err)
})
@@ -768,8 +831,8 @@ func TestConn_WriteFile(t *testing.T) {
var fw = &flateWriter{cb: func(index int, eof bool, p []byte) error {
return errors.New("2")
}}
- src := bytes.NewBufferString("hello")
- err := deflater.Compress(src, fw, nil, new(slideWindow))
+ var reader = &readerWrapper{r: new(writerTo), sw: new(slideWindow)}
+ err := deflater.Compress(reader, fw, nil)
assert.Error(t, err)
})