diff --git a/client.go b/client.go index d8c1c606..606af882 100644 --- a/client.go +++ b/client.go @@ -15,343 +15,165 @@ import ( "github.com/lxzan/gws/internal" ) -// Dialer 接口定义了拨号方法 -// Dialer interface defines the dial method +// Dialer 拨号器接口 +// dialer interface type Dialer interface { - // Dial 方法用于建立网络连接 - // Dial method is used to establish a network connection + // Dial 连接到指定网络上的地址 + // connects to the address on the named network Dial(network, addr string) (c net.Conn, err error) } -// connector 结构体用于管理客户端连接 -// connector struct is used to manage client connections type connector struct { - // option 表示客户端选项 - // option represents client options - option *ClientOption - - // conn 表示网络连接 - // conn represents the network connection - conn net.Conn - - // eventHandler 表示事件处理器 - // eventHandler represents the event handler - eventHandler Event - - // secWebsocketKey 表示 WebSocket 安全密钥 - // secWebsocketKey represents the WebSocket security key + option *ClientOption + conn net.Conn + eventHandler Event secWebsocketKey string } // NewClient 创建一个新的 WebSocket 客户端连接 -// NewClient creates a new WebSocket client connection +// creates a new WebSocket client connection func NewClient(handler Event, option *ClientOption) (*Conn, *http.Response, error) { - // 初始化客户端选项 - // Initialize client options option = initClientOption(option) - - // 创建一个新的连接器实例 - // Create a new connector instance c := &connector{option: option, eventHandler: handler} - - // 解析 WebSocket 地址 - // Parse the WebSocket address URL, err := url.Parse(option.Addr) if err != nil { return nil, nil, err } - - // 检查协议是否为 ws 或 wss - // Check if the protocol is ws or wss if URL.Scheme != "ws" && URL.Scheme != "wss" { return nil, nil, ErrUnsupportedProtocol } - // 判断是否启用 TLS - // Determine if TLS is enabled var tlsEnabled = URL.Scheme == "wss" - - // 创建拨号器 - // Create a dialer dialer, err := option.NewDialer() if err != nil { return nil, nil, err } - // 选择端口号,默认情况下 wss 使用 443 端口,ws 使用 80 端口 - // Select the port number, default to 443 for wss and 80 for ws port := internal.SelectValue(URL.Port() == "", internal.SelectValue(tlsEnabled, "443", "80"), URL.Port()) - - // 选择主机名,默认情况下使用 127.0.0.1 - // Select the hostname, default to 127.0.0.1 hp := internal.SelectValue(URL.Hostname() == "", "127.0.0.1", URL.Hostname()) + ":" + port - - // 通过拨号器拨号连接到服务器 - // Dial the server using the dialer c.conn, err = dialer.Dial("tcp", hp) if err != nil { return nil, nil, err } - - // 如果启用了 TLS,配置 TLS 设置 - // If TLS is enabled, configure TLS settings if tlsEnabled { - // 如果没有提供 TlsConfig,则创建一个新的 tls.Config 实例 - // If TlsConfig is not provided, create a new tls.Config instance if option.TlsConfig == nil { option.TlsConfig = &tls.Config{} } - - // 如果 TlsConfig 中没有设置 ServerName,则使用 URL 的主机名 - // If ServerName is not set in TlsConfig, use the hostname from the URL if option.TlsConfig.ServerName == "" { option.TlsConfig.ServerName = URL.Hostname() } - - // 使用配置的 TlsConfig 创建一个新的 TLS 客户端连接 - // Create a new TLS client connection using the configured TlsConfig c.conn = tls.Client(c.conn, option.TlsConfig) } - // 执行握手操作 - // Perform the handshake operation client, resp, err := c.handshake() if err != nil { _ = c.conn.Close() } - - // 返回客户端连接、HTTP 响应和错误信息 - // Return the client connection, HTTP response, and error information return client, resp, err } // NewClientFromConn 通过外部连接创建客户端, 支持 TCP/KCP/Unix Domain Socket -// Create New client via external connection, supports TCP/KCP/Unix Domain Socket. +// Create new client via external connection, supports TCP/KCP/Unix Domain Socket. func NewClientFromConn(handler Event, option *ClientOption, conn net.Conn) (*Conn, *http.Response, error) { - // 初始化客户端选项 - // Initialize client options option = initClientOption(option) - - // 创建一个新的 connector 实例 - // Create a new connector instance c := &connector{option: option, conn: conn, eventHandler: handler} - - // 执行握手操作 - // Perform the handshake operation client, resp, err := c.handshake() - - // 如果握手失败,关闭连接 - // If the handshake fails, close the connection if err != nil { _ = c.conn.Close() } - - // 返回客户端连接、HTTP 响应和错误信息 - // Return the client connection, HTTP response, and error information return client, resp, err } -// request 发送 HTTP 请求以发起 WebSocket 握手 -// request sends an HTTP request to initiate a WebSocket handshake +// request 发送HTTP请求, 即WebSocket握手 +// Sends an http request, i.e., websocket handshake func (c *connector) request() (*http.Response, *bufio.Reader, error) { - // 设置连接的超时时间 - // Set the connection timeout _ = c.conn.SetDeadline(time.Now().Add(c.option.HandshakeTimeout)) - - // 创建一个带有超时的上下文 - // Create a context with a timeout ctx, cancel := context.WithTimeout(context.Background(), c.option.HandshakeTimeout) defer cancel() - // 创建一个新的 HTTP GET 请求 - // Create a new HTTP GET request + // 构建HTTP请求 + // building a http request r, err := http.NewRequestWithContext(ctx, http.MethodGet, c.option.Addr, nil) if err != nil { return nil, nil, err } - - // 将客户端选项中的请求头复制到 HTTP 请求头中 - // Copy the request headers from client options to the HTTP request headers for k, v := range c.option.RequestHeader { r.Header[k] = v } - - // 设置 Connection 头为 "Upgrade" - // Set the Connection header to "Upgrade" r.Header.Set(internal.Connection.Key, internal.Connection.Val) - - // 设置 Upgrade 头为 "websocket" - // Set the Upgrade header to "websocket" r.Header.Set(internal.Upgrade.Key, internal.Upgrade.Val) - - // 设置 Sec-WebSocket-Version 头为 "13" - // Set the Sec-WebSocket-Version header to "13" r.Header.Set(internal.SecWebSocketVersion.Key, internal.SecWebSocketVersion.Val) - - // 如果启用了每消息压缩扩展,则设置 Sec-WebSocket-Extensions 头 - // If per-message deflate extension is enabled, set the Sec-WebSocket-Extensions header if c.option.PermessageDeflate.Enabled { r.Header.Set(internal.SecWebSocketExtensions.Key, c.option.PermessageDeflate.genRequestHeader()) } - - // 如果没有安全 WebSocket 密钥,则生成一个 - // Generate a security WebSocket key if not already set if c.secWebsocketKey == "" { - // 创建一个 16 字节的数组用于存储密钥 - // Create a 16-byte array to store the key var key [16]byte - - // 使用内部方法生成前 8 字节的随机数并存储在 key 数组中 - // Use an internal method to generate a random number for the first 8 bytes and store it in the key array binary.BigEndian.PutUint64(key[0:8], internal.AlphabetNumeric.Uint64()) - - // 使用内部方法生成后 8 字节的随机数并存储在 key 数组中 - // Use an internal method to generate a random number for the last 8 bytes and store it in the key array binary.BigEndian.PutUint64(key[8:16], internal.AlphabetNumeric.Uint64()) - - // 将生成的密钥编码为 base64 字符串并赋值给 secWebsocketKey - // Encode the generated key as a base64 string and assign it to secWebsocketKey c.secWebsocketKey = base64.StdEncoding.EncodeToString(key[0:]) - - // 将生成的密钥设置到请求头中 - // Set the generated key in the request header r.Header.Set(internal.SecWebSocketKey.Key, c.secWebsocketKey) } - // 创建一个用于接收错误的通道 - // Create a channel to receive errors var ch = make(chan error) - // 启动一个 goroutine 发送请求 - // Start a goroutine to send the request + // 发送http请求 + // send http request go func() { ch <- r.Write(c.conn) }() - // 等待请求完成或上下文超时 - // Wait for the request to complete or the context to timeout + // 同步等待请求是否发送成功 + // Synchronized waiting for the request to be sent successfully select { case err = <-ch: - // 如果请求完成,将错误赋值给 err - // If the request completes, assign the error to err case <-ctx.Done(): - // 如果上下文超时或取消,将上下文的错误赋值给 err - // If the context times out or is canceled, assign the context's error to err err = ctx.Err() } - - // 如果发生错误,返回错误信息 - // If an error occurs, return the error information if err != nil { return nil, nil, err } - // 创建一个带有指定缓冲区大小的 bufio.Reader - // Create a bufio.Reader with the specified buffer size + // 读取响应结果 + // Read the response result br := bufio.NewReaderSize(c.conn, c.option.ReadBufferSize) - - // 读取 HTTP 响应 - // Read the HTTP response resp, err := http.ReadResponse(br, r) - - // 返回 HTTP 响应、缓冲读取器和错误信息 - // Return the HTTP response, buffered reader, and error information return resp, br, err } -// getPermessageDeflate 获取每消息压缩扩展的配置 -// getPermessageDeflate retrieves the configuration for per-message deflate extension +// getPermessageDeflate 获取压缩拓展结果 +// Get compression expansion results func (c *connector) getPermessageDeflate(extensions string) PermessageDeflate { - // 解析服务器端的每消息压缩扩展配置 - // Parse the server-side per-message deflate extension configuration serverPD := permessageNegotiation(extensions) - - // 获取客户端的每消息压缩配置 - // Get the client-side per-message deflate configuration clientPD := c.option.PermessageDeflate - - // 创建一个新的每消息压缩配置实例 - // Create a new instance of per-message deflate configuration pd := PermessageDeflate{ - // 启用状态取决于客户端配置和服务器扩展是否包含每消息压缩 - // Enabled status depends on client configuration and whether the server extensions include per-message deflate - Enabled: clientPD.Enabled && strings.Contains(extensions, internal.PermessageDeflate), - - // 设置压缩阈值 - // Set the compression threshold - Threshold: clientPD.Threshold, - - // 设置压缩级别 - // Set the compression level - Level: clientPD.Level, - - // 设置缓冲池大小 - // Set the buffer pool size - PoolSize: clientPD.PoolSize, - - // 设置服务器上下文接管配置 - // Set the server context takeover configuration + Enabled: clientPD.Enabled && strings.Contains(extensions, internal.PermessageDeflate), + Threshold: clientPD.Threshold, + Level: clientPD.Level, + PoolSize: clientPD.PoolSize, ServerContextTakeover: serverPD.ServerContextTakeover, - - // 设置客户端上下文接管配置 - // Set the client context takeover configuration ClientContextTakeover: serverPD.ClientContextTakeover, - - // 设置服务器最大窗口位数 - // Set the server max window bits - ServerMaxWindowBits: serverPD.ServerMaxWindowBits, - - // 设置客户端最大窗口位数 - // Set the client max window bits - ClientMaxWindowBits: serverPD.ClientMaxWindowBits, + ServerMaxWindowBits: serverPD.ServerMaxWindowBits, + ClientMaxWindowBits: serverPD.ClientMaxWindowBits, } - - // 设置压缩阈值 - // Set the compression threshold pd.setThreshold(false) - - // 返回每消息压缩配置 - // Return the per-message deflate configuration return pd } // handshake 执行 WebSocket 握手操作 -// handshake performs the WebSocket handshake operation +// performs the WebSocket handshake operation func (c *connector) handshake() (*Conn, *http.Response, error) { - // 发送握手请求并读取响应 - // Send the handshake request and read the response resp, br, err := c.request() if err != nil { - // 如果请求失败,返回错误信息 - // If the request fails, return the error information return nil, resp, err } - - // 检查响应头以验证握手是否成功 - // Check the response headers to verify if the handshake was successful if err = c.checkHeaders(resp); err != nil { - // 如果握手失败,返回错误信息 - // If the handshake fails, return the error information return nil, resp, err } - - // 获取协商的子协议 - // Get the negotiated subprotocol subprotocol, err := c.getSubProtocol(resp) if err != nil { - // 如果获取子协议失败,返回错误信息 - // If getting the subprotocol fails, return the error information return nil, resp, err } - // 获取响应头中的扩展字段 - // Get the extensions field from the response header var extensions = resp.Header.Get(internal.SecWebSocketExtensions.Key) - - // 获取每消息压缩扩展的配置 - // Get the per-message deflate configuration var pd = c.getPermessageDeflate(extensions) - - // 创建 WebSocket 连接对象 - // Create the WebSocket connection object socket := &Conn{ ss: c.option.NewSession(), isServer: false, @@ -368,86 +190,44 @@ func (c *connector) handshake() (*Conn, *http.Response, error) { writeQueue: workerQueue{maxConcurrency: 1}, readQueue: make(channel, c.option.ParallelGolimit), } - - // 如果启用了每消息压缩扩展,初始化压缩器和窗口 - // If per-message deflate is enabled, initialize the deflater and windows if pd.Enabled { - // 初始化压缩器,传入是否为服务器端、压缩配置和最大负载大小 - // Initialize the deflater, passing whether it is server-side, compression configuration, and max payload size socket.deflater.initialize(false, pd, c.option.ReadMaxPayloadSize) - - // 如果服务器上下文接管启用,初始化服务器端窗口 - // If server context takeover is enabled, initialize the server-side window if pd.ServerContextTakeover { socket.dpsWindow.initialize(nil, pd.ServerMaxWindowBits) } - - // 如果客户端上下文接管启用,初始化客户端窗口 - // If client context takeover is enabled, initialize the client-side window if pd.ClientContextTakeover { socket.cpsWindow.initialize(nil, pd.ClientMaxWindowBits) } } - - // 返回 WebSocket 连接对象、HTTP 响应和错误信息 - // Return the WebSocket connection object, HTTP response, and error information return socket, resp, c.conn.SetDeadline(time.Time{}) } // getSubProtocol 从响应中获取子协议 -// getSubProtocol retrieves the subprotocol from the response +// retrieves the subprotocol from the response func (c *connector) getSubProtocol(resp *http.Response) (string, error) { - // 从请求头中获取客户端支持的子协议列表 - // Get the list of subprotocols supported by the client from the request header a := internal.Split(c.option.RequestHeader.Get(internal.SecWebSocketProtocol.Key), ",") - - // 从响应头中获取服务器支持的子协议列表 - // Get the list of subprotocols supported by the server from the response header b := internal.Split(resp.Header.Get(internal.SecWebSocketProtocol.Key), ",") - - // 获取客户端和服务器支持的子协议的交集 - // Get the intersection of subprotocols supported by both client and server subprotocol := internal.GetIntersectionElem(a, b) - - // 如果客户端支持子协议但未协商出共同的子协议,返回子协议协商错误 - // If the client supports subprotocols but no common subprotocol is negotiated, return subprotocol negotiation error if len(a) > 0 && subprotocol == "" { return "", ErrSubprotocolNegotiation } - - // 返回协商出的子协议 - // Return the negotiated subprotocol return subprotocol, nil } // checkHeaders 检查响应头以验证握手是否成功 -// checkHeaders checks the response headers to verify if the handshake was successful +// checks the response headers to verify if the handshake was successful func (c *connector) checkHeaders(resp *http.Response) error { - // 检查状态码是否为 101 Switching Protocols - // Check if the status code is 101 Switching Protocols if resp.StatusCode != http.StatusSwitchingProtocols { return ErrHandshake } - - // 检查响应头中的 Connection 字段是否包含 "Upgrade" - // Check if the Connection field in the response header contains "Upgrade" if !internal.HttpHeaderContains(resp.Header.Get(internal.Connection.Key), internal.Connection.Val) { return ErrHandshake } - - // 检查响应头中的 Upgrade 字段是否为 "websocket" - // Check if the Upgrade field in the response header is "websocket" if !strings.EqualFold(resp.Header.Get(internal.Upgrade.Key), internal.Upgrade.Val) { return ErrHandshake } - - // 检查 Sec-WebSocket-Accept 字段的值是否正确 - // Check if the Sec-WebSocket-Accept field value is correct if resp.Header.Get(internal.SecWebSocketAccept.Key) != internal.ComputeAcceptKey(c.secWebsocketKey) { return ErrHandshake } - - // 如果所有检查都通过,返回 nil 表示成功 - // If all checks pass, return nil to indicate success return nil } diff --git a/compress.go b/compress.go index f1573b86..1a4a73e2 100644 --- a/compress.go +++ b/compress.go @@ -14,610 +14,257 @@ import ( "github.com/lxzan/gws/internal" ) -// flateTail 是一个字节切片,用于表示 deflate 压缩算法的尾部标记。 -// flateTail is a byte slice, used to represent the tail marker of the deflate compression algorithm. +// flateTail deflate压缩算法的尾部标记 +// the tail marker of the deflate compression algorithm var flateTail = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff} -// deflaterPool 是一个结构体,用于表示一个 deflater 对象的池。 -// deflaterPool is a struct, used to represent a pool of deflater objects. type deflaterPool struct { - // serial 是一个无符号64位整数,用于表示 deflaterPool 的序列号。 - // serial is an unsigned 64-bit integer, used to represent the serial number of the deflaterPool. serial uint64 - - // num 是一个无符号64位整数,用于表示 deflaterPool 中 deflater 对象的数量。 - // num is an unsigned 64-bit integer, used to represent the number of deflater objects in the deflaterPool. - num uint64 - - // pool 是一个 deflater 对象的切片,用于存储 deflaterPool 中的 deflater 对象。 - // pool is a slice of deflater objects, used to store the deflater objects in the deflaterPool. - pool []*deflater + num uint64 + pool []*deflater } -// initialize 是 deflaterPool 结构体的一个方法,用于初始化 deflaterPool。 -// initialize is a method of the deflaterPool struct, used to initialize the deflaterPool. +// initialize 初始化deflaterPool +// initialize the deflaterPool func (c *deflaterPool) initialize(options PermessageDeflate, limit int) *deflaterPool { - // 设置 deflaterPool 的大小为 options 的 PoolSize 属性值 - // Set the size of the deflaterPool to the PoolSize property value of options c.num = uint64(options.PoolSize) - - // 循环创建 deflater 对象并添加到 deflaterPool 中 - // Loop to create deflater objects and add them to the deflaterPool for i := uint64(0); i < c.num; i++ { - // 创建一个新的 deflater 对象并初始化,然后添加到 deflaterPool 的 pool 切片中 - // Create a new deflater object and initialize it, then add it to the pool slice of the deflaterPool c.pool = append(c.pool, new(deflater).initialize(true, options, limit)) } - - // 返回初始化后的 deflaterPool 对象 - // Return the initialized deflaterPool object return c } -// Select 是 deflaterPool 结构体的一个方法,用于从 deflaterPool 中选择一个 deflater 对象。 -// Select is a method of the deflaterPool struct, used to select a deflater object from the deflaterPool. +// Select 从deflaterPool中选择一个deflater对象 +// select a deflater object from the deflaterPool func (c *deflaterPool) Select() *deflater { - // 使用原子操作增加 deflaterPool 的序列号,并与 deflaterPool 的大小减一进行按位与运算,得到一个索引值 - // Use atomic operation to increase the serial number of the deflaterPool, and perform bitwise AND operation with the size of the deflaterPool minus one to get an index value var j = atomic.AddUint64(&c.serial, 1) & (c.num - 1) - - // 返回 deflaterPool 中索引为 j 的 deflater 对象 - // Return the deflater object at index j in the deflaterPool return c.pool[j] } -// deflater 是一个结构体,用于表示一个 deflate 压缩器。 -// deflater is a struct, used to represent a deflate compressor. type deflater struct { - // dpsLocker 是一个互斥锁,用于保护 dpsBuffer 和 dpsReader 的并发访问。 - // dpsLocker is a mutex, used to protect concurrent access to dpsBuffer and dpsReader. dpsLocker sync.Mutex - - // buf 是一个字节切片,用于存储压缩数据。 - // buf is a byte slice, used to store compressed data. - buf []byte - - // limit 是一个整数,用于表示压缩数据的大小限制。 - // limit is an integer, used to represent the size limit of compressed data. - limit int - - // dpsBuffer 是一个字节缓冲,用于存储待压缩的数据。 - // dpsBuffer is a byte buffer, used to store data to be compressed. + buf []byte + limit int dpsBuffer *bytes.Buffer - - // dpsReader 是一个读取器,用于从 dpsBuffer 中读取数据。 - // dpsReader is a reader, used to read data from dpsBuffer. dpsReader io.ReadCloser - - // cpsLocker 是一个互斥锁,用于保护 cpsWriter 的并发访问。 - // cpsLocker is a mutex, used to protect concurrent access to cpsWriter. cpsLocker sync.Mutex - - // cpsWriter 是一个写入器,用于向目标写入压缩后的数据。 - // cpsWriter is a writer, used to write compressed data to the target. cpsWriter *flate.Writer } -// initialize 是 deflater 结构体的一个方法,用于初始化 deflater。 -// initialize is a method of the deflater struct, used to initialize the deflater. +// initialize 初始化deflater +// initialize the deflater func (c *deflater) initialize(isServer bool, options PermessageDeflate, limit int) *deflater { - // 创建一个新的 deflate 读取器 - // Create a new deflate reader c.dpsReader = flate.NewReader(nil) - - // 创建一个新的字节缓冲 - // Create a new byte buffer c.dpsBuffer = bytes.NewBuffer(nil) - - // 创建一个大小为 32*1024 的字节切片 - // Create a byte slice of size 32*1024 c.buf = make([]byte, 32*1024) - - // 设置压缩数据的大小限制 - // Set the size limit of compressed data c.limit = limit - - // 根据是否是服务器,选择服务器或客户端的最大窗口位数 - // Select the maximum window bits of the server or client depending on whether it is a server windowBits := internal.SelectValue(isServer, options.ServerMaxWindowBits, options.ClientMaxWindowBits) - - // 如果窗口位数为 15 - // If the window bits is 15 if windowBits == 15 { - // 创建一个新的 deflate 写入器,压缩级别为 options.Level - // Create a new deflate writer with a compression level of options.Level c.cpsWriter, _ = flate.NewWriter(nil, options.Level) } else { - // 创建一个新的 deflate 写入器,窗口大小为 2 的 windowBits 次方 - // Create a new deflate writer with a window size of 2 to the power of windowBits c.cpsWriter, _ = flate.NewWriterWindow(nil, internal.BinaryPow(windowBits)) } - - // 返回初始化后的 deflater 对象 - // Return the initialized deflater object return c } -// resetFR 是 deflater 结构体的一个方法,用于重置 deflate 读取器和字节缓冲。 -// resetFR is a method of the deflater struct, used to reset the deflate reader and byte buffer. +// resetFR 重置deflate reader +// reset the deflate reader func (c *deflater) resetFR(r io.Reader, dict []byte) { - // 获取 deflate 读取器的 Resetter 接口 - // Get the Resetter interface of the deflate reader resetter := c.dpsReader.(flate.Resetter) - - // 使用新的读取器和字典重置 deflate 读取器 - // Reset the deflate reader with a new reader and dictionary _ = resetter.Reset(r, dict) // must return a null pointer - - // 如果字节缓冲的容量大于 256*1024,则创建一个新的字节缓冲 - // If the capacity of the byte buffer is greater than 256*1024, create a new byte buffer if c.dpsBuffer.Cap() > 256*1024 { c.dpsBuffer = bytes.NewBuffer(nil) } - - // 重置字节缓冲 - // Reset the byte buffer c.dpsBuffer.Reset() } -// Decompress 是 deflater 结构体的一个方法,用于解压缩数据。 -// Decompress is a method of the deflater struct, used to decompress data. +// Decompress 解压 +// decompress data func (c *deflater) Decompress(src *bytes.Buffer, dict []byte) (*bytes.Buffer, error) { - // 加锁,保护 dpsBuffer 和 dpsReader 的并发访问 - // Lock to protect concurrent access to dpsBuffer and dpsReader c.dpsLocker.Lock() - - // 函数返回时解锁 - // Unlock when the function returns defer c.dpsLocker.Unlock() - // 将 deflate 压缩算法的尾部标记写入源数据 - // Write the tail marker of the deflate compression algorithm into the source data _, _ = src.Write(flateTail) - - // 重置 deflate 读取器和字节缓冲 - // Reset the deflate reader and byte buffer c.resetFR(src, dict) - - // 创建一个限制读取器,限制读取的数据大小不超过 c.limit - // Create a limit reader, limiting the size of the data read to not exceed c.limit reader := limitReader(c.dpsReader, c.limit) - - // 将 reader 中的数据复制到 dpsBuffer 中,使用 c.buf 作为缓冲 - // Copy the data in reader to dpsBuffer, using c.buf as a buffer if _, err := io.CopyBuffer(c.dpsBuffer, reader, c.buf); err != nil { - // 如果复制过程中出现错误,返回 nil 和错误信息 - // If an error occurs during the copy, return nil and the error message return nil, err } - - // 从二进制池中获取一个新的字节缓冲,大小为 dpsBuffer 的长度 - // Get a new byte buffer from the binary pool, the size is the length of dpsBuffer var dst = binaryPool.Get(c.dpsBuffer.Len()) - - // 将 dpsBuffer 中的数据写入 dst - // Write the data in dpsBuffer to dst _, _ = c.dpsBuffer.WriteTo(dst) - - // 返回 dst 和 nil - // Return dst and nil return dst, nil } -// Compress 是 deflater 结构体的一个方法,用于压缩数据。 -// Compress is a method of the deflater struct, used to compress data. +// Compress 压缩 +// compress data func (c *deflater) Compress(src internal.Payload, dst *bytes.Buffer, dict []byte) error { - // 加锁,保护 cpsWriter 的并发访问 - // Lock to protect concurrent access to cpsWriter c.cpsLocker.Lock() - - // 函数返回时解锁 - // Unlock when the function returns defer c.cpsLocker.Unlock() - // 使用新的字节缓冲和字典重置 cpsWriter - // Reset cpsWriter with a new byte buffer and dictionary c.cpsWriter.ResetDict(dst, dict) - - // 将源数据写入 cpsWriter - // Write the source data to cpsWriter if _, err := src.WriteTo(c.cpsWriter); err != nil { - // 如果写入过程中出现错误,返回错误信息 - // If an error occurs during the write, return the error message return err } - - // 刷新 cpsWriter,将所有未写入的数据写入字节缓冲 - // Flush cpsWriter, write all unwritten data to the byte buffer if err := c.cpsWriter.Flush(); err != nil { - // 如果刷新过程中出现错误,返回错误信息 - // If an error occurs during the flush, return the error message return err } - - // 如果字节缓冲的长度大于等于 4 - // If the length of the byte buffer is greater than or equal to 4 if n := dst.Len(); n >= 4 { - // 获取字节缓冲的字节切片 - // Get the byte slice of the byte buffer compressedContent := dst.Bytes() - - // 如果字节切片的尾部 4 个字节表示的无符号整数等于最大的 16 位无符号整数 - // If the unsigned integer represented by the last 4 bytes of the byte slice is equal to the maximum 16-bit unsigned integer if tail := compressedContent[n-4:]; binary.BigEndian.Uint32(tail) == math.MaxUint16 { - // 截断字节缓冲,去掉尾部的 4 个字节 - // Truncate the byte buffer, remove the last 4 bytes dst.Truncate(n - 4) } } - - // 返回 nil,表示压缩成功 - // Return nil, indicating that the compression was successful return nil } -// slideWindow 是一个结构体,用于表示一个滑动窗口。 -// slideWindow is a struct, used to represent a sliding window. +// slideWindow 滑动窗口 +// sliding window type slideWindow struct { - // enabled 是一个布尔值,表示滑动窗口是否启用。 - // enabled is a boolean value, indicating whether the sliding window is enabled. enabled bool - - // dict 是一个字节切片,用于存储滑动窗口的数据。 - // dict is a byte slice, used to store the data of the sliding window. - dict []byte - - // size 是一个整数,表示滑动窗口的大小。 - // size is an integer, representing the size of the sliding window. - size int + dict []byte + size int } -// initialize 是 slideWindow 结构体的一个方法,用于初始化滑动窗口。 -// initialize is a method of the slideWindow struct, used to initialize the sliding window. +// initialize 初始化滑动窗口 +// initialize the sliding window func (c *slideWindow) initialize(pool *internal.Pool[[]byte], windowBits int) *slideWindow { - // 启用滑动窗口 - // Enable the sliding window c.enabled = true - - // 设置滑动窗口的大小为 2 的 windowBits 次方 - // Set the size of the sliding window to 2 to the power of windowBits c.size = internal.BinaryPow(windowBits) - if pool != nil { - // 如果池不为空,从池中获取一个字节切片,并设置其长度为 0 - // If the pool is not empty, get a byte slice from the pool and set its length to 0 c.dict = pool.Get()[:0] } else { - // 如果池为空,创建一个新的字节切片,长度为 0,容量为滑动窗口的大小 - // If the pool is empty, create a new byte slice with a length of 0 and a capacity of the size of the sliding window c.dict = make([]byte, 0, c.size) } - - // 返回初始化后的滑动窗口对象 - // Return the initialized sliding window object return c } -// Write 是 slideWindow 结构体的一个方法,用于将数据写入滑动窗口。 -// Write is a method of the slideWindow struct, used to write data to the sliding window. +// Write 将数据写入滑动窗口 +// write data to the sliding window func (c *slideWindow) Write(p []byte) (int, error) { - // 如果滑动窗口未启用,返回 0 和 nil - // If the sliding window is not enabled, return 0 and nil if !c.enabled { return 0, nil } - // 获取 p 的长度 - // Get the length of p var total = len(p) - - // n 是待写入的数据长度 - // n is the length of the data to be written var n = total - - // 获取滑动窗口的长度 - // Get the length of the sliding window var length = len(c.dict) - - // 如果待写入的数据长度加上滑动窗口的长度小于等于滑动窗口的大小 - // If the length of the data to be written plus the length of the sliding window is less than or equal to the size of the sliding window if n+length <= c.size { - // 将 p 添加到滑动窗口的末尾 - // Add p to the end of the sliding window c.dict = append(c.dict, p...) - - // 返回 p 的长度和 nil - // Return the length of p and nil return total, nil } - // 如果滑动窗口的大小减去滑动窗口的长度大于 0 - // If the size of the sliding window minus the length of the sliding window is greater than 0 if m := c.size - length; m > 0 { - // 将 p 的前 m 个元素添加到滑动窗口的末尾 - // Add the first m elements of p to the end of the sliding window c.dict = append(c.dict, p[:m]...) - - // 将 p 的前 m 个元素删除 - // Delete the first m elements of p p = p[m:] - - // 更新待写入的数据长度 - // Update the length of the data to be written n = len(p) } - // 如果待写入的数据长度大于等于滑动窗口的大小 - // If the length of the data to be written is greater than or equal to the size of the sliding window if n >= c.size { - // 将 p 的后 c.size 个元素复制到滑动窗口 - // Copy the last c.size elements of p to the sliding window copy(c.dict, p[n-c.size:]) - - // 返回 p 的长度和 nil - // Return the length of p and nil return total, nil } - // 将滑动窗口的后 n 个元素复制到滑动窗口的前面 - // Copy the last n elements of the sliding window to the front of the sliding window copy(c.dict, c.dict[n:]) - - // 将 p 复制到滑动窗口的后面 - // Copy p to the back of the sliding window copy(c.dict[c.size-n:], p) - - // 返回 p 的长度和 nil - // Return the length of p and nil return total, nil } -// genRequestHeader 是 PermessageDeflate 结构体的一个方法,用于生成请求头。 -// genRequestHeader is a method of the PermessageDeflate struct, used to generate request headers. +// genRequestHeader 生成请求头 +// generate request headers func (c *PermessageDeflate) genRequestHeader() string { - // 创建一个字符串切片,长度为 0,容量为 5 - // Create a string slice with a length of 0 and a capacity of 5 var options = make([]string, 0, 5) - - // 将 PermessageDeflate 添加到 options - // Add PermessageDeflate to options options = append(options, internal.PermessageDeflate) - - // 如果 ServerContextTakeover 为 false - // If ServerContextTakeover is false if !c.ServerContextTakeover { - // 将 ServerNoContextTakeover 添加到 options - // Add ServerNoContextTakeover to options options = append(options, internal.ServerNoContextTakeover) } - - // 如果 ClientContextTakeover 为 false - // If ClientContextTakeover is false if !c.ClientContextTakeover { - // 将 ClientNoContextTakeover 添加到 options - // Add ClientNoContextTakeover to options options = append(options, internal.ClientNoContextTakeover) } - - // 如果 ServerMaxWindowBits 不等于 15 - // If ServerMaxWindowBits is not equal to 15 if c.ServerMaxWindowBits != 15 { - // 将 ServerMaxWindowBits 和其值添加到 options - // Add ServerMaxWindowBits and its value to options options = append(options, internal.ServerMaxWindowBits+internal.EQ+strconv.Itoa(c.ServerMaxWindowBits)) } - - // 如果 ClientMaxWindowBits 不等于 15 - // If ClientMaxWindowBits is not equal to 15 if c.ClientMaxWindowBits != 15 { - // 将 ClientMaxWindowBits 和其值添加到 options - // Add ClientMaxWindowBits and its value to options options = append(options, internal.ClientMaxWindowBits+internal.EQ+strconv.Itoa(c.ClientMaxWindowBits)) } else if c.ClientContextTakeover { - // 如果 ClientContextTakeover 为 true - // If ClientContextTakeover is true - // 将 ClientMaxWindowBits 添加到 options - // Add ClientMaxWindowBits to options options = append(options, internal.ClientMaxWindowBits) } - - // 使用 "; " 将 options 中的所有元素连接成一个字符串,并返回 - // Join all elements in options into a string using "; " and return return strings.Join(options, "; ") } -// genResponseHeader 是 PermessageDeflate 结构体的一个方法,用于生成响应头。 -// genResponseHeader is a method of the PermessageDeflate struct, used to generate response headers. +// genResponseHeader 生成响应头 +// generate response headers func (c *PermessageDeflate) genResponseHeader() string { - // 创建一个字符串切片,长度为 0,容量为 5 - // Create a string slice with a length of 0 and a capacity of 5 var options = make([]string, 0, 5) - - // 将 PermessageDeflate 添加到 options - // Add PermessageDeflate to options options = append(options, internal.PermessageDeflate) - - // 如果 ServerContextTakeover 为 false - // If ServerContextTakeover is false if !c.ServerContextTakeover { - // 将 ServerNoContextTakeover 添加到 options - // Add ServerNoContextTakeover to options options = append(options, internal.ServerNoContextTakeover) } - - // 如果 ClientContextTakeover 为 false - // If ClientContextTakeover is false if !c.ClientContextTakeover { - // 将 ClientNoContextTakeover 添加到 options - // Add ClientNoContextTakeover to options options = append(options, internal.ClientNoContextTakeover) } - - // 如果 ServerMaxWindowBits 不等于 15 - // If ServerMaxWindowBits is not equal to 15 if c.ServerMaxWindowBits != 15 { - // 将 ServerMaxWindowBits 和其值添加到 options - // Add ServerMaxWindowBits and its value to options options = append(options, internal.ServerMaxWindowBits+internal.EQ+strconv.Itoa(c.ServerMaxWindowBits)) } - - // 如果 ClientMaxWindowBits 不等于 15 - // If ClientMaxWindowBits is not equal to 15 if c.ClientMaxWindowBits != 15 { - // 将 ClientMaxWindowBits 和其值添加到 options - // Add ClientMaxWindowBits and its value to options options = append(options, internal.ClientMaxWindowBits+internal.EQ+strconv.Itoa(c.ClientMaxWindowBits)) } - - // 使用 "; " 将 options 中的所有元素连接成一个字符串,并返回 - // Join all elements in options into a string using "; " and return return strings.Join(options, "; ") } -// permessageNegotiation 是一个函数,用于解析 permessage-deflate 扩展头。 -// permessageNegotiation is a function used to parse the permessage-deflate extension header. +// permessageNegotiation 压缩拓展协商 +// Negotiation of compression parameters func permessageNegotiation(str string) PermessageDeflate { - // 创建一个 PermessageDeflate 结构体 options,并初始化其属性。 - // Create a PermessageDeflate struct options and initialize its properties. var options = PermessageDeflate{ - - // ServerContextTakeover 属性设置为 true,表示服务器可以接管上下文。 - // The ServerContextTakeover property is set to true, indicating that the server can take over the context. ServerContextTakeover: true, - - // ClientContextTakeover 属性设置为 true,表示客户端可以接管上下文。 - // The ClientContextTakeover property is set to true, indicating that the client can take over the context. ClientContextTakeover: true, - - // ServerMaxWindowBits 属性设置为 15,表示服务器的最大窗口位数为 15。 - // The ServerMaxWindowBits property is set to 15, indicating that the maximum window bits for the server is 15. - ServerMaxWindowBits: 15, - - // ClientMaxWindowBits 属性设置为 15,表示客户端的最大窗口位数为 15。 - // The ClientMaxWindowBits property is set to 15, indicating that the maximum window bits for the client is 15. - ClientMaxWindowBits: 15, + ServerMaxWindowBits: 15, + ClientMaxWindowBits: 15, } - // 将 str 以 ";" 为分隔符进行分割,得到一个字符串切片 ss - // Split the string str by ";" to get a string slice ss var ss = internal.Split(str, ";") - - // 遍历 ss 中的每一个字符串 s - // Iterate over each string s in ss for _, s := range ss { - - // 将 s 以 "=" 为分隔符进行分割,得到一个字符串切片 pair - // Split the string s by "=" to get a string slice pair var pair = strings.SplitN(s, "=", 2) - - // 根据 pair[0] 的值进行判断 - // Judge based on the value of pair[0] switch pair[0] { - - // 如果 pair[0] 的值为 PermessageDeflate 或者 ServerNoContextTakeover,则将 options 的 ServerContextTakeover 属性设置为 false - // If the value of pair[0] is PermessageDeflate or ServerNoContextTakeover, set the ServerContextTakeover property of options to false case internal.PermessageDeflate: case internal.ServerNoContextTakeover: options.ServerContextTakeover = false - - // 如果 pair[0] 的值为 ClientNoContextTakeover,则将 options 的 ClientContextTakeover 属性设置为 false - // If the value of pair[0] is ClientNoContextTakeover, set the ClientContextTakeover property of options to false case internal.ClientNoContextTakeover: options.ClientContextTakeover = false - - // 如果 pair[0] 的值为 ServerMaxWindowBits - // If the value of pair[0] is ServerMaxWindowBits case internal.ServerMaxWindowBits: - // 如果 pair 的长度为 2 - // If the length of pair is 2 if len(pair) == 2 { - // 将 pair[1] 转换为整数 x - // Convert pair[1] to integer x x, _ := strconv.Atoi(pair[1]) - - // 如果 x 为 0,则将 x 设置为 15 - // If x is 0, set x to 15 x = internal.WithDefault(x, 15) - - // 将 options 的 ServerMaxWindowBits 属性设置为 options 的 ServerMaxWindowBits 属性和 x 中的较小值 - // Set the ServerMaxWindowBits property of options to the smaller of the ServerMaxWindowBits property of options and x options.ServerMaxWindowBits = internal.Min(options.ServerMaxWindowBits, x) } - - // 如果 pair[0] 的值为 ClientMaxWindowBits - // If the value of pair[0] is ClientMaxWindowBits case internal.ClientMaxWindowBits: - // 如果 pair 的长度为 2 - // If the length of pair is 2 if len(pair) == 2 { - // 将 pair[1] 转换为整数 x - // Convert pair[1] to integer x x, _ := strconv.Atoi(pair[1]) - - // 如果 x 为 0,则将 x 设置为 15 - // If x is 0, set x to 15 x = internal.WithDefault(x, 15) - - // 将 options 的 ClientMaxWindowBits 属性设置为 options 的 ClientMaxWindowBits 属性和 x 中的较小值 - // Set the ClientMaxWindowBits property of options to the smaller of the ClientMaxWindowBits property of options and x options.ClientMaxWindowBits = internal.Min(options.ClientMaxWindowBits, x) } } } - // 如果 options.ClientMaxWindowBits 小于 8,那么将 options.ClientMaxWindowBits 设置为 8,否则保持不变。 - // If options.ClientMaxWindowBits is less than 8, then set options.ClientMaxWindowBits to 8, otherwise keep it unchanged. options.ClientMaxWindowBits = internal.SelectValue(options.ClientMaxWindowBits < 8, 8, options.ClientMaxWindowBits) - - // 如果 options.ServerMaxWindowBits 小于 8,那么将 options.ServerMaxWindowBits 设置为 8,否则保持不变。 - // If options.ServerMaxWindowBits is less than 8, then set options.ServerMaxWindowBits to 8, otherwise keep it unchanged. options.ServerMaxWindowBits = internal.SelectValue(options.ServerMaxWindowBits < 8, 8, options.ServerMaxWindowBits) - - // 返回 options 结构体 - // Return the options struct return options } -// limitReader 是一个函数,接收一个 io.Reader 和一个限制值,返回一个新的 limitedReader -// limitReader is a function that takes an io.Reader and a limit, and returns a new limitedReader -func limitReader(r io.Reader, limit int) io.Reader { return &limitedReader{R: r, M: limit} } +// limitReader 限制从io.Reader中最多读取m个字节 +// Limit reading up to m bytes from io.Reader +func limitReader(r io.Reader, m int) io.Reader { return &limitedReader{R: r, M: m} } -// limitedReader 是一个结构体,包含一个 io.Reader 和两个整数 N 和 M。 -// limitedReader is a struct that contains an io.Reader and two integers N and M. type limitedReader struct { - // R 是一个 io.Reader,它是一个接口,用于读取数据。 - // R is an io.Reader, which is an interface used for reading data. R io.Reader - - // N 是一个整数,用于表示限制读取的字节数。 - // N is an integer, used to represent the number of bytes to limit the read. N int - - // M 是一个整数,用于表示读取的最大字节数。 - // M is an integer, used to represent the maximum number of bytes to read. M int } -// Read 是 limitedReader 的一个方法,用于读取数据 -// Read is a method of limitedReader, used to read data func (c *limitedReader) Read(p []byte) (n int, err error) { - // 从 c.R 中读取数据到 p,返回读取的数据量 n 和可能的错误 err - // Read data from c.R into p, return the amount of data read n and possible error err n, err = c.R.Read(p) - - // 将读取的数据量加到 c.N - // Add the amount of data read to c.N c.N += n - - // 如果已读取的数据量超过限制 - // If the amount of data read exceeds the limit if c.N > c.M { - // 返回读取的数据量和一个错误信息 - // Return the amount of data read and an error message return n, internal.CloseMessageTooLarge } - - // 返回读取的数据量和可能的错误 - // Return the amount of data read and possible error return } diff --git a/conn.go b/conn.go index bab046ae..77087013 100644 --- a/conn.go +++ b/conn.go @@ -89,51 +89,35 @@ type Conn struct { pd PermessageDeflate } -// ReadLoop 循环读取消息. 如果复用了HTTP Server, 建议开启goroutine, 阻塞会导致请求上下文无法被GC. +// ReadLoop +// 循环读取消息. 如果复用了HTTP Server, 建议开启goroutine, 阻塞会导致请求上下文无法被GC. // Read messages in a loop. // If HTTP Server is reused, it is recommended to enable goroutine, as blocking will prevent the context from being GC. func (c *Conn) ReadLoop() { - // 触发连接打开事件 - // Trigger the connection open event c.handler.OnOpen(c) - // 无限循环读取消息 - // Infinite loop to read messages + // 无限循环读取消息, 如果发生错误则触发错误事件并退出循环 + // Infinite loop to read messages, if an error occurs, trigger the error event and exit the loop for { - // 读取消息,如果发生错误则触发错误事件并退出循环 - // Read message, if an error occurs, trigger the error event and exit the loop if err := c.readMessage(); err != nil { c.emitError(err) break } } - // 从原子值中加载错误 - // Load error from atomic value err, ok := c.err.Load().(error) - - // 触发连接关闭事件 - // Trigger the connection close event c.handler.OnClose(c, internal.SelectValue(ok, err, errEmpty)) // 回收资源 // Reclaim resources if c.isServer { - // 重置缓冲读取器并放回缓冲池 - // Reset buffered reader and put it back to the buffer pool c.br.Reset(nil) c.config.brPool.Put(c.br) c.br = nil - - // 如果压缩接收窗口启用,放回压缩字典池 - // If compression receive window is enabled, put the compression dictionary back to the pool if c.cpsWindow.enabled { c.config.cswPool.Put(c.cpsWindow.dict) c.cpsWindow.dict = nil } - - // 如果压缩发送窗口启用,放回压缩字典池 - // If compression send window is enabled, put the compression dictionary back to the pool if c.dpsWindow.enabled { c.config.dswPool.Put(c.dpsWindow.dict) c.dpsWindow.dict = nil @@ -141,8 +125,8 @@ func (c *Conn) ReadLoop() { } } -// getCpsDict 返回用于压缩接收窗口的字典 -// getCpsDict returns the dictionary for the compression receive window +// 获取压缩字典 +// Get compressed dictionary func (c *Conn) getCpsDict(isBroadcast bool) []byte { // 广播模式必须保证每一帧都是相同的内容, 所以不使用上下文接管优化压缩率 // In broadcast mode, each frame must be the same content, so context takeover is not used to optimize compression ratio @@ -150,86 +134,65 @@ func (c *Conn) getCpsDict(isBroadcast bool) []byte { return nil } - // 如果是服务器并且服务器上下文接管启用,返回压缩接收窗口的字典 - // If it is a server and server context takeover is enabled, return the dictionary for the compression receive window + // 如果是服务器并且服务器上下文接管启用,返回压缩字典 + // If it is a server and server context takeover is enabled, return the compression dictionary if c.isServer && c.pd.ServerContextTakeover { return c.cpsWindow.dict } - // 如果不是服务器并且客户端上下文接管启用,返回压缩接收窗口的字典 - // If it is not a server and client context takeover is enabled, return the dictionary for the compression receive window + // 如果是客户端并且客户端上下文接管启用,返回压缩字典 + // If client-side and client context takeover is enabled, return the compression dictionary if !c.isServer && c.pd.ClientContextTakeover { return c.cpsWindow.dict } - // 否则返回 nil - // Otherwise, return nil return nil } -// getDpsDict 返回用于压缩发送窗口的字典 -// getDpsDict returns the dictionary for the compression send window +// 获取解压字典 +// Get decompression dictionary func (c *Conn) getDpsDict() []byte { - // 如果是服务器并且客户端上下文接管启用,返回压缩发送窗口的字典 - // If it is a server and client context takeover is enabled, return the dictionary for the compression send window + // 如果是服务器并且客户端上下文接管启用,返回解压字典 + // If it is a server and client context takeover is enabled, return the decompression dictionary if c.isServer && c.pd.ClientContextTakeover { return c.dpsWindow.dict } - // 如果不是服务器并且服务器上下文接管启用,返回压缩发送窗口的字典 - // If it is not a server and server context takeover is enabled, return the dictionary for the compression send window + // 如果是客户端并且服务器上下文接管启用,返回解压字典 + // If it is a client and server context takeover is enabled, return the decompressed dictionary if !c.isServer && c.pd.ServerContextTakeover { return c.dpsWindow.dict } - // 否则返回 nil - // Otherwise, return nil return nil } -// isTextValid 检查文本数据的有效性 -// isTextValid checks the validity of text data +// UTF8编码检查 +// UTF8 encoding check func (c *Conn) isTextValid(opcode Opcode, payload []byte) bool { - // 如果配置启用了 UTF-8 检查 - // If the configuration has UTF-8 check enabled if c.config.CheckUtf8Enabled { - // 检查编码是否有效 - // Check if the encoding is valid return internal.CheckEncoding(uint8(opcode), payload) } - - // 如果未启用 UTF-8 检查,始终返回 true - // If UTF-8 check is not enabled, always return true return true } -// isClosed 检查连接是否已关闭 -// isClosed checks if the connection is closed +// 检查连接是否已关闭 +// Checks if the connection is closed func (c *Conn) isClosed() bool { return atomic.LoadUint32(&c.closed) == 1 } -// close 关闭连接并存储错误信息 -// close closes the connection and stores the error information +// 关闭连接并存储错误信息 +// Closes the connection and stores the error information func (c *Conn) close(reason []byte, err error) { - // 存储错误信息 - // Store the error information c.err.Store(err) - - // 发送关闭连接的帧 - // Send a frame to close the connection _ = c.doWrite(OpcodeCloseConnection, internal.Bytes(reason)) - - // 关闭底层网络连接 - // Close the underlying network connection _ = c.conn.Close() } -// emitError 处理并发出错误事件 -// emitError handles and emits an error event +// 处理错误事件 +// Handle the error event func (c *Conn) emitError(err error) { - // 如果错误为空,直接返回 - // If the error is nil, return immediately if err == nil { return } @@ -237,8 +200,6 @@ func (c *Conn) emitError(err error) { // 使用原子操作检查并设置连接的关闭状态 // Use atomic operation to check and set the closed state of the connection if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { - // 初始化响应代码和响应错误 - // Initialize response code and response error var responseCode = internal.CloseNormalClosure var responseErr error = internal.CloseNormalClosure @@ -246,22 +207,14 @@ func (c *Conn) emitError(err error) { // Set response code and response error based on the error type switch v := err.(type) { case internal.StatusCode: - // 如果错误类型是 internal.StatusCode,设置响应代码为该状态码 - // If the error type is internal.StatusCode, set the response code to this status code responseCode = v case *internal.Error: - // 如果错误类型是 *internal.Error,设置响应代码为该错误的状态码,并设置响应错误为该错误的错误信息 - // If the error type is *internal.Error, set the response code to the status code of this error and set the response error to the error message of this error responseCode = v.Code responseErr = v.Err default: - // 对于其他类型的错误,直接设置响应错误为该错误 - // For other types of errors, directly set the response error to this error responseErr = err } - // 将响应代码转换为字节切片并附加错误信息 - // Convert response code to byte slice and append error message var content = responseCode.Bytes() content = append(content, err.Error()...) @@ -271,192 +224,114 @@ func (c *Conn) emitError(err error) { content = content[:internal.ThresholdV1] } - // 关闭连接并传递内容和响应错误 - // Close the connection and pass the content and response error c.close(content, responseErr) } } -// emitClose 处理关闭帧并关闭连接 -// emitClose handles the close frame and closes the connection +// 处理关闭事件 +// Handles the close event func (c *Conn) emitClose(buf *bytes.Buffer) error { - // 默认响应代码为正常关闭 - // Default response code is normal closure var responseCode = internal.CloseNormalClosure - - // 默认实际代码为正常关闭的 Uint16 值 - // Default real code is the Uint16 value of normal closure var realCode = internal.CloseNormalClosure.Uint16() - - // 根据缓冲区长度设置响应代码和实际代码 - // Set response code and real code based on buffer length switch buf.Len() { case 0: - // 如果缓冲区长度为 0,设置响应代码和实际代码为 0 - // If the buffer length is 0, set the response code and the actual code to 0 responseCode = 0 realCode = 0 - case 1: - // 如果缓冲区长度为 1,设置响应代码为协议错误,并将缓冲区的第一个字节作为实际代码 - // If the buffer length is 1, set the response code to protocol error and use the first byte of the buffer as the actual code responseCode = internal.CloseProtocolError realCode = uint16(buf.Bytes()[0]) - - // 重置缓冲区 - // Reset the buffer buf.Reset() - default: - // 如果缓冲区长度大于 1,读取前两个字节作为实际代码 - // If the buffer length is greater than 1, read the first two bytes as the actual code var b [2]byte - - // 从缓冲区读取两个字节到 b - // Read two bytes from the buffer into b _, _ = buf.Read(b[0:]) - - // 将 b 的两个字节解释为大端序的 uint16,并赋值给 realCode - // Interpret the two bytes of b as a big-endian uint16 and assign it to realCode realCode = binary.BigEndian.Uint16(b[0:]) - - // 根据实际代码设置响应代码 - // Set response code based on the real code switch realCode { case 1004, 1005, 1006, 1014, 1015: - // 这些代码表示协议错误 - // These codes indicate protocol errors responseCode = internal.CloseProtocolError default: - // 检查实际代码是否在有效范围内 - // Check if the real code is within a valid range if realCode < 1000 || realCode >= 5000 || (realCode >= 1016 && realCode < 3000) { - // 如果实际代码小于 1000 或大于等于 5000,或者在 1016 和 3000 之间,设置响应代码为协议错误 - // If the real code is less than 1000 or greater than or equal to 5000, or between 1016 and 3000, set the response code to protocol error responseCode = internal.CloseProtocolError } else if realCode < 1016 { - // 如果实际代码小于 1016,设置响应代码为正常关闭 - // If the real code is less than 1016, set the response code to normal closure responseCode = internal.CloseNormalClosure } else { - // 否则,将实际代码转换为状态码并设置为响应代码 - // Otherwise, convert the real code to a status code and set it as the response code responseCode = internal.StatusCode(realCode) } } - - // 检查文本数据的有效性 - // Check the validity of text data if !c.isTextValid(OpcodeCloseConnection, buf.Bytes()) { responseCode = internal.CloseUnsupportedData } } - - // 如果连接未关闭,关闭连接并存储错误信息 - // If the connection is not closed, close the connection and store the error information if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { c.close(responseCode.Bytes(), &CloseError{Code: realCode, Reason: buf.Bytes()}) } - - // 返回正常关闭状态码 - // Return normal closure status code return internal.CloseNormalClosure } // SetDeadline 设置连接的截止时间 -// SetDeadline sets the deadline for the connection +// Sets the deadline for the connection func (c *Conn) SetDeadline(t time.Time) error { - // 设置底层连接的截止时间 - // Set the deadline for the underlying connection err := c.conn.SetDeadline(t) - - // 触发错误处理 - // Emit error handling c.emitError(err) - - // 返回错误信息 - // Return the error return err } // SetReadDeadline 设置读取操作的截止时间 -// SetReadDeadline sets the deadline for read operations +// Sets the deadline for read operations func (c *Conn) SetReadDeadline(t time.Time) error { - // 设置底层连接的读取截止时间 - // Set the read deadline for the underlying connection err := c.conn.SetReadDeadline(t) - - // 触发错误处理 - // Emit error handling c.emitError(err) - - // 返回错误信息 - // Return the error return err } // SetWriteDeadline 设置写入操作的截止时间 -// SetWriteDeadline sets the deadline for write operations +// Sets the deadline for write operations func (c *Conn) SetWriteDeadline(t time.Time) error { - // 设置底层连接的写入截止时间 - // Set the write deadline for the underlying connection err := c.conn.SetWriteDeadline(t) - - // 触发错误处理 - // Emit error handling c.emitError(err) - - // 返回错误信息 - // Return the error return err } // LocalAddr 返回本地网络地址 -// LocalAddr returns the local network address +// Returns the local network address func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } // RemoteAddr 返回远程网络地址 -// RemoteAddr returns the remote network address +// Returns the remote network address func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } -// NetConn 获取底层的 TCP/TLS/KCP 等连接 -// NetConn gets the underlying TCP/TLS/KCP... connection +// NetConn +// 获取底层的 TCP/TLS/KCP 等连接 +// Gets the underlying TCP/TLS/KCP... connection func (c *Conn) NetConn() net.Conn { return c.conn } -// SetNoDelay 控制操作系统是否应该延迟数据包传输以期望发送更少的数据包(Nagle 算法)。 +// SetNoDelay +// 控制操作系统是否应该延迟数据包传输以期望发送更少的数据包(Nagle 算法)。 // 默认值是 true(无延迟),这意味着数据在 Write 之后尽快发送。 -// SetNoDelay controls whether the operating system should delay packet transmission in hopes of sending fewer packets (Nagle's algorithm). +// Controls whether the operating system should delay packet transmission in hopes of sending fewer packets (Nagle's algorithm). // The default is true (no delay), meaning that data is sent as soon as possible after a Write. func (c *Conn) SetNoDelay(noDelay bool) error { switch v := c.conn.(type) { case *net.TCPConn: - // 如果底层连接是 TCP 连接,设置无延迟选项 - // If the underlying connection is a TCP connection, set the no delay option return v.SetNoDelay(noDelay) case *tls.Conn: - // 如果底层连接是 TLS 连接,获取其底层的 TCP 连接并设置无延迟选项 - // If the underlying connection is a TLS connection, get its underlying TCP connection and set the no delay option if netConn, ok := v.NetConn().(*net.TCPConn); ok { return netConn.SetNoDelay(noDelay) } } - - // 如果不是 TCP 或 TLS 连接,返回 nil - // If it is not a TCP or TLS connection, return nil return nil } // SubProtocol 获取协商的子协议 -// SubProtocol gets the negotiated sub-protocol +// Gets the negotiated sub-protocol func (c *Conn) SubProtocol() string { return c.subprotocol } // Session 获取会话存储 -// Session gets the session storage +// Gets the session storage func (c *Conn) Session() SessionStorage { return c.ss } diff --git a/init.go b/init.go index b0a492d9..a47b9963 100644 --- a/init.go +++ b/init.go @@ -3,17 +3,7 @@ package gws import "github.com/lxzan/gws/internal" var ( - // framePadding 用于填充帧头 - // framePadding is used to pad the frame header - framePadding = frameHeader{} - - // binaryPool 是一个缓冲池,用于管理二进制数据缓冲区 - // binaryPool is a buffer pool used to manage binary data buffers - // 参数 128 表示缓冲区的初始大小,256*1024 表示缓冲区的最大大小 - // The parameter 128 represents the initial size of the buffer, and 256*1024 represents the maximum size of the buffer - binaryPool = internal.NewBufferPool(128, 256*1024) - - // defaultLogger 是默认的日志工具 - // defaultLogger is the default logging tool - defaultLogger = new(stdLogger) + framePadding = frameHeader{} // 帧头填充物 + binaryPool = internal.NewBufferPool(128, 256*1024) // 内存池 + defaultLogger = new(stdLogger) // 默认日志工具 ) diff --git a/internal/deque.go b/internal/deque.go index 8abfd340..57ee3210 100644 --- a/internal/deque.go +++ b/internal/deque.go @@ -1,65 +1,46 @@ package internal -// Nil 常量表示空指针 -// Nil constant represents a null pointer const Nil = 0 type ( - // Pointer 类型表示一个无符号 32 位整数,用于指向元素的位置 - // Pointer type represents an unsigned 32-bit integer used to point to the position of an element Pointer uint32 - // Element 结构体表示双端队列中的一个元素 - // Element struct represents an element in the deque Element[T any] struct { - // prev 指向前一个元素的位置 - // prev points to the position of the previous element - prev Pointer - - // addr 指向当前元素的位置 - // addr points to the position of the current element - addr Pointer - - // next 指向下一个元素的位置 - // next points to the position of the next element - next Pointer - - // value 存储元素的值 - // value stores the value of the element - value T + prev, addr, next Pointer + value T } // Deque 结构体表示一个双端队列 // Deque struct represents a double-ended queue Deque[T any] struct { // head 指向队列头部元素的位置 - // head points to the position of the head element in the queue + // points to the position of the head element in the queue head Pointer // tail 指向队列尾部元素的位置 - // tail points to the position of the tail element in the queue + // points to the position of the tail element in the queue tail Pointer - // length 表示队列的长度 - // length represents the length of the queue + // length 队列长度 + // length of the queue length int - // stack 用于存储空闲位置的栈 - // stack is used to store the stack of free positions + // stack 存储空闲位置的栈 + // store the stack of free positions stack Stack[Pointer] // elements 存储队列中的所有元素 - // elements stores all the elements in the queue + // stores all the elements in the queue elements []Element[T] - // template 用于创建新元素的模板 - // template is used as a template for creating new elements + // template 创建新元素的模板 + // template for creating new elements template Element[T] } ) // IsNil 检查指针是否为空 -// IsNil checks if the pointer is null +// checks if the pointer is null func (c Pointer) IsNil() bool { return c == Nil } @@ -71,628 +52,330 @@ func (c *Element[T]) Addr() Pointer { } // Next 返回下一个元素的地址 -// Next returns the address of the next element +// returns the address of the next element func (c *Element[T]) Next() Pointer { return c.next } // Prev 返回前一个元素的地址 -// Prev returns the address of the previous element +// returns the address of the previous element func (c *Element[T]) Prev() Pointer { return c.prev } // Value 返回元素的值 -// Value returns the value of the element +// returns the value of the element func (c *Element[T]) Value() T { return c.value } // New 创建双端队列 -// New creates a double-ended queue +// creates a double-ended queue func New[T any](capacity int) *Deque[T] { - // 初始化 Deque 结构体,elements 切片的容量为 1 + capacity - // Initialize the Deque struct, with the capacity of the elements slice set to 1 + capacity return &Deque[T]{elements: make([]Element[T], 1, 1+capacity)} } // Get 根据地址获取元素 -// Get retrieves an element based on its address +// retrieves an element based on its address func (c *Deque[T]) Get(addr Pointer) *Element[T] { - // 如果地址大于 0,返回对应地址的元素 - // If the address is greater than 0, return the element at that address if addr > 0 { return &(c.elements[addr]) } - - // 否则返回 nil - // Otherwise, return nil return nil } // getElement 追加元素一定要先调用此方法, 因为追加可能会造成扩容, 地址发生变化!!! -// getElement must be called before appending elements, as appending may cause reallocation and address changes!!! +// must be called before appending elements, as appending may cause reallocation and address changes!!! func (c *Deque[T]) getElement() *Element[T] { - // 如果 elements 切片为空,追加一个模板元素 - // If the elements slice is empty, append a template element if len(c.elements) == 0 { c.elements = append(c.elements, c.template) } - // 如果 stack 中有空闲地址,从 stack 中弹出一个地址并返回对应的元素 - // If there are free addresses in the stack, pop an address from the stack and return the corresponding element if c.stack.Len() > 0 { - // 从 stack 中弹出一个空闲地址 - // Pop a free address from the stack addr := c.stack.Pop() - - // 获取该地址对应的元素 - // Get the element corresponding to that address v := c.Get(addr) - - // 设置元素的地址 - // Set the address of the element v.addr = addr - - // 返回该元素 - // Return the element return v } - // 否则 stack 中没有空闲地址,计算新元素的地址 - // Otherwise, there are no free addresses in the stack, calculate the address of the new element addr := Pointer(len(c.elements)) - - // 将模板元素追加到 elements 列表中 - // Append the template element to the elements list c.elements = append(c.elements, c.template) - - // 获取新元素 - // Get the new element v := c.Get(addr) - - // 设置新元素的地址 - // Set the address of the new element v.addr = addr - - // 返回新元素 - // Return the new element return v } -// putElement 将元素放回空闲栈中,并重置元素内容 -// putElement puts the element back into the free stack and resets the element's content func (c *Deque[T]) putElement(ele *Element[T]) { - // 将元素的地址压入空闲栈中 - // Push the element's address into the free stack c.stack.Push(ele.addr) - - // 将元素重置为模板元素 - // Reset the element to the template element *ele = c.template } // Reset 重置双端队列 -// Reset resets the deque +// resets the deque func (c *Deque[T]) Reset() { - // 调用内部方法 autoReset 进行重置 - // Call the internal method autoReset to reset c.autoReset() } -// autoReset 内部方法,重置双端队列的状态 -// autoReset is an internal method that resets the state of the deque +// autoReset 重置双端队列的状态 +// resets the state of the deque func (c *Deque[T]) autoReset() { - // 重置头部、尾部指针和长度 - // Reset the head, tail pointers, and length c.head, c.tail, c.length = Nil, Nil, 0 - - // 清空空闲栈 - // Clear the free stack c.stack = c.stack[:0] - - // 保留 elements 列表中的第一个元素,清空其他元素 - // Keep the first element in the elements list, clear other elements c.elements = c.elements[:1] } // Len 返回双端队列的长度 -// Len returns the length of the deque +// returns the length of the deque func (c *Deque[T]) Len() int { return c.length } // Front 返回队列头部的元素 -// Front returns the element at the front of the queue +// returns the element at the front of the queue func (c *Deque[T]) Front() *Element[T] { return c.Get(c.head) } // Back 返回队列尾部的元素 -// Back returns the element at the back of the queue +// returns the element at the back of the queue func (c *Deque[T]) Back() *Element[T] { return c.Get(c.tail) } // PushFront 将一个元素添加到队列的头部 -// PushFront adds an element to the front of the deque +// adds an element to the front of the deque func (c *Deque[T]) PushFront(value T) *Element[T] { - // 获取一个空闲的元素 - // Get a free element ele := c.getElement() - - // 设置元素的值 - // Set the value of the element ele.value = value - - // 执行将元素推到队列头部的操作 - // Perform the operation to push the element to the front of the deque c.doPushFront(ele) - - // 返回该元素 - // Return the element return ele } -// doPushFront 执行将元素推到队列头部的操作 -// doPushFront performs the operation to push the element to the front of the deque func (c *Deque[T]) doPushFront(ele *Element[T]) { - // 增加队列长度 - // Increase the length of the deque c.length++ - // 如果队列为空,设置头部和尾部指针为新元素的地址 - // If the deque is empty, set the head and tail pointers to the new element's address if c.head.IsNil() { c.head, c.tail = ele.addr, ele.addr return } - // 获取当前头部元素 - // Get the current head element head := c.Get(c.head) - - // 设置当前头部元素的前一个元素为新元素 - // Set the previous element of the current head element to the new element head.prev = ele.addr - - // 设置新元素的下一个元素为当前头部元素 - // Set the next element of the new element to the current head element ele.next = head.addr - - // 更新头部指针为新元素的地址 - // Update the head pointer to the new element's address c.head = ele.addr } // PushBack 将一个元素添加到队列的尾部 -// PushBack adds an element to the back of the deque +// adds an element to the back of the deque func (c *Deque[T]) PushBack(value T) *Element[T] { - // 获取一个空闲的元素 - // Get a free element ele := c.getElement() - - // 设置元素的值 - // Set the value of the element ele.value = value - - // 执行将元素推到队列尾部的操作 - // Perform the operation to push the element to the back of the deque c.doPushBack(ele) - - // 返回该元素 - // Return the element return ele } -// doPushBack 将元素添加到队列的尾部 -// doPushBack adds an element to the back of the deque func (c *Deque[T]) doPushBack(ele *Element[T]) { - // 增加队列长度 - // Increase the length of the deque c.length++ - // 如果队列为空,设置头部和尾部指针为新元素的地址 - // If the deque is empty, set the head and tail pointers to the new element's address if c.tail.IsNil() { c.head, c.tail = ele.addr, ele.addr return } - // 获取当前尾部元素 - // Get the current tail element tail := c.Get(c.tail) - - // 设置当前尾部元素的下一个元素为新元素 - // Set the next element of the current tail element to the new element tail.next = ele.addr - - // 设置新元素的前一个元素为当前尾部元素 - // Set the previous element of the new element to the current tail element ele.prev = tail.addr - - // 更新尾部指针为新元素的地址 - // Update the tail pointer to the new element's address c.tail = ele.addr } // PopFront 从队列头部弹出一个元素并返回其值 -// PopFront pops an element from the front of the deque and returns its value +// pops an element from the front of the deque and returns its value func (c *Deque[T]) PopFront() (value T) { - // 获取队列头部的元素 - // Get the element at the front of the deque if ele := c.Front(); ele != nil { - // 获取元素的值 - // Get the value of the element value = ele.value - - // 从队列中移除该元素 - // Remove the element from the deque c.doRemove(ele) - - // 将元素放回空闲栈中 - // Put the element back into the free stack c.putElement(ele) - - // 如果队列为空,重置队列 - // If the deque is empty, reset the deque if c.length == 0 { c.autoReset() } } - - // 返回弹出的元素值 - // Return the popped element's value return value } // PopBack 从队列尾部弹出一个元素并返回其值 -// PopBack pops an element from the back of the deque and returns its value +// pops an element from the back of the deque and returns its value func (c *Deque[T]) PopBack() (value T) { - // 获取队列尾部的元素 - // Get the element at the back of the deque if ele := c.Back(); ele != nil { - // 获取元素的值 - // Get the value of the element value = ele.value - - // 从队列中移除该元素 - // Remove the element from the deque c.doRemove(ele) - - // 将元素放回空闲栈中 - // Put the element back into the free stack c.putElement(ele) - - // 如果队列为空,重置队列 - // If the deque is empty, reset the deque if c.length == 0 { c.autoReset() } } - - // 返回弹出的元素值 - // Return the popped element's value return value } // InsertAfter 在指定元素之后插入一个新元素 -// InsertAfter inserts a new element after the specified element +// inserts a new element after the specified element func (c *Deque[T]) InsertAfter(value T, mark Pointer) *Element[T] { - // 如果标记指针为空,返回 nil - // If the mark pointer is null, return nil if mark.IsNil() { return nil } - // 增加队列长度 - // Increase the length of the deque c.length++ - - // 获取一个空闲的元素 - // Get a free element e1 := c.getElement() - - // 获取标记的元素 - // Get the marked element e0 := c.Get(mark) - - // 获取标记元素的下一个元素 - // Get the next element of the marked element e2 := c.Get(e0.next) - - // 设置新元素的前一个元素、下一个元素和值 - // Set the previous element, next element, and value of the new element e1.prev, e1.next, e1.value = e0.addr, e0.next, value - // 如果下一个元素不为空,设置其前一个元素为新元素 - // If the next element is not null, set its previous element to the new element if e2 != nil { e2.prev = e1.addr } - // 设置标记元素的下一个元素为新元素 - // Set the next element of the marked element to the new element e0.next = e1.addr - - // 如果新元素是最后一个元素,更新尾部指针 - // If the new element is the last element, update the tail pointer if e1.next.IsNil() { c.tail = e1.addr } - - // 返回新插入的元素 - // Return the newly inserted element return e1 } // InsertBefore 在指定元素之前插入一个新元素 -// InsertBefore inserts a new element before the specified element +// inserts a new element before the specified element func (c *Deque[T]) InsertBefore(value T, mark Pointer) *Element[T] { - // 如果标记指针为空,返回 nil - // If the mark pointer is null, return nil if mark.IsNil() { return nil } - // 增加队列长度 - // Increase the length of the deque c.length++ - - // 获取一个空闲的元素 - // Get a free element e1 := c.getElement() - - // 获取标记的元素 - // Get the marked element e2 := c.Get(mark) - - // 获取标记元素的前一个元素 - // Get the previous element of the marked element e0 := c.Get(e2.prev) - - // 设置新元素的前一个元素、下一个元素和值 - // Set the previous element, next element, and value of the new element e1.prev, e1.next, e1.value = e2.prev, e2.addr, value - // 如果前一个元素不为空,设置其下一个元素为新元素 - // If the previous element is not null, set its next element to the new element if e0 != nil { e0.next = e1.addr } - // 设置标记元素的前一个元素为新元素 - // Set the previous element of the marked element to the new element e2.prev = e1.addr - // 如果新元素是第一个元素,更新头部指针 - // If the new element is the first element, update the head pointer if e1.prev.IsNil() { c.head = e1.addr } - - // 返回新插入的元素 - // Return the newly inserted element return e1 } // MoveToBack 将指定地址的元素移动到队列尾部 -// MoveToBack moves the element at the specified address to the back of the deque +// moves the element at the specified address to the back of the deque func (c *Deque[T]) MoveToBack(addr Pointer) { - // 获取指定地址的元素 - // Get the element at the specified address if ele := c.Get(addr); ele != nil { - // 从队列中移除该元素 - // Remove the element from the deque c.doRemove(ele) - - // 重置元素的前后指针 - // Reset the previous and next pointers of the element ele.prev, ele.next = Nil, Nil - - // 将元素推到队列尾部 - // Push the element to the back of the deque c.doPushBack(ele) } } // MoveToFront 将指定地址的元素移动到队列头部 -// MoveToFront moves the element at the specified address to the front of the deque +// moves the element at the specified address to the front of the deque func (c *Deque[T]) MoveToFront(addr Pointer) { - // 获取指定地址的元素 - // Get the element at the specified address if ele := c.Get(addr); ele != nil { - // 从队列中移除该元素 - // Remove the element from the deque c.doRemove(ele) - - // 重置元素的前后指针 - // Reset the previous and next pointers of the element ele.prev, ele.next = Nil, Nil - - // 将元素推到队列头部 - // Push the element to the front of the deque c.doPushFront(ele) } } // Update 更新指定地址的元素的值 -// Update updates the value of the element at the specified address +// updates the value of the element at the specified address func (c *Deque[T]) Update(addr Pointer, value T) { - // 获取指定地址的元素 - // Get the element at the specified address if ele := c.Get(addr); ele != nil { - // 更新元素的值 - // Update the value of the element ele.value = value } } // Remove 从队列中移除指定地址的元素 -// Remove removes the element at the specified address from the deque +// removes the element at the specified address from the deque func (c *Deque[T]) Remove(addr Pointer) { - // 获取指定地址的元素 - // Get the element at the specified address if ele := c.Get(addr); ele != nil { - // 从队列中移除该元素 - // Remove the element from the deque c.doRemove(ele) - - // 将元素放回空闲栈中 - // Put the element back into the free stack c.putElement(ele) - - // 如果队列为空,重置队列 - // If the deque is empty, reset the deque if c.length == 0 { c.autoReset() } } } -// doRemove 从队列中移除指定的元素 -// doRemove removes the specified element from the deque func (c *Deque[T]) doRemove(ele *Element[T]) { - // 初始化前后元素指针为 nil - // Initialize previous and next element pointers to nil var prev, next *Element[T] = nil, nil - - // 初始化状态为 0 - // Initialize state to 0 var state = 0 - - // 如果前一个元素不为空,获取前一个元素并更新状态 - // If the previous element is not nil, get the previous element and update the state if !ele.prev.IsNil() { - // 使用 c.Get 方法获取 ele 的前一个元素,并将其赋值给 prev - // Use the c.Get method to get the previous element of ele and assign it to prev prev = c.Get(ele.prev) - - // 将状态值 state 增加 1, 用于标记前一个元素存在 - // Increase the state value by 1, used to mark that the previous element exists state += 1 } - - // 如果下一个元素不为空,获取下一个元素并更新状态 - // If the next element is not nil, get the next element and update the state if !ele.next.IsNil() { - // 使用 c.Get 方法获取 ele 的下一个元素,并将其赋值给 next - // Use the c.Get method to get the next element of ele and assign it to next next = c.Get(ele.next) - - // 将状态值 state 增加 2, 用于标记前后元素都存在 - // Increase the state value by 2, used to mark that both previous and next elements exist state += 2 } - // 减少队列长度 - // Decrease the length of the deque c.length-- - - // 根据状态更新前后元素的指针 - // Update the pointers of the previous and next elements based on the state switch state { case 3: - // 如果前后元素都存在,更新前一个元素的 next 指针和后一个元素的 prev 指针 - // If both previous and next elements exist, update the next pointer of the previous element and the prev pointer of the next element prev.next = next.addr next.prev = prev.addr case 2: - // 如果只有后一个元素存在,更新后一个元素的 prev 指针并设置头部指针 - // If only the next element exists, update the prev pointer of the next element and set the head pointer next.prev = Nil c.head = next.addr case 1: - // 如果只有前一个元素存在,更新前一个元素的 next 指针并设置尾部指针 - // If only the previous element exists, update the next pointer of the previous element and set the tail pointer prev.next = Nil c.tail = prev.addr default: - // 如果前后元素都不存在,重置头部和尾部指针 - // If neither previous nor next elements exist, reset the head and tail pointers c.head = Nil c.tail = Nil } } // Range 遍历队列中的每个元素,并对每个元素执行给定的函数 -// Range iterates over each element in the deque and executes the given function on each element +// iterates over each element in the deque and executes the given function on each element func (c *Deque[T]) Range(f func(ele *Element[T]) bool) { - // 从队列头部开始遍历 - // Start iterating from the head of the deque for i := c.Get(c.head); i != nil; i = c.Get(i.next) { - // 如果函数返回 false,则停止遍历 - // If the function returns false, stop iterating if !f(i) { break } } } -// Clone 创建并返回队列的一个副本 -// Clone creates and returns a copy of the deque +// Clone 深拷贝 +// deep copy func (c *Deque[T]) Clone() *Deque[T] { - // 创建队列的副本 - // Create a copy of the deque var v = *c - - // 为副本分配新的元素切片 - // Allocate a new slice for the elements of the copy v.elements = make([]Element[T], len(c.elements)) - - // 为副本分配新的指针栈 - // Allocate a new slice for the stack of the copy v.stack = make([]Pointer, len(c.stack)) - - // 复制元素到副本 - // Copy the elements to the copy copy(v.elements, c.elements) - - // 复制指针栈到副本 - // Copy the stack to the copy copy(v.stack, c.stack) - - // 返回副本 - // Return the copy return &v } -// Stack 是一个泛型栈类型 -// Stack is a generic stack type +// Stack 泛型栈 +// generic stack type Stack[T any] []T // Len 获取栈中元素的数量 -// Len returns the number of elements in the stack +// returns the number of elements in the stack func (c *Stack[T]) Len() int { - // 返回栈的长度 - // Return the length of the stack return len(*c) } // Push 将元素追加到栈顶 -// Push appends an element to the top of the stack +// appends an element to the top of the stack func (c *Stack[T]) Push(v T) { - // 将元素追加到栈顶 - // Append the element to the top of the stack *c = append(*c, v) } // Pop 从栈顶弹出元素并返回其值 -// Pop removes the top element from the stack and returns its value +// removes the top element from the stack and returns its value func (c *Stack[T]) Pop() T { - // 获取栈的长度 - // Get the length of the stack n := c.Len() - - // 获取栈顶元素的值 - // Get the value of the top element value := (*c)[n-1] - - // 移除栈顶元素 - // Remove the top element from the stack *c = (*c)[:n-1] - - // 返回栈顶元素的值 - // Return the value of the top element return value } diff --git a/internal/error.go b/internal/error.go index 90b6d69f..fcfa9a89 100644 --- a/internal/error.go +++ b/internal/error.go @@ -1,7 +1,7 @@ package internal -// closeErrorMap 是一个映射,用于将状态码映射到错误信息 -// closeErrorMap is a map used to map status codes to error messages +// closeErrorMap 将状态码映射到错误信息 +// map status codes to error messages var closeErrorMap = map[StatusCode]string{ // 空状态码 // Empty status code @@ -64,8 +64,8 @@ var closeErrorMap = map[StatusCode]string{ CloseTLSHandshake: "TLS handshake error", } -// StatusCode 类型定义为一个 uint16 -// StatusCode type is defined as a uint16 +// StatusCode WebSocket错误码 +// websocket error code type StatusCode uint16 const ( @@ -112,73 +112,41 @@ const ( CloseTLSHandshake StatusCode = 1015 ) -// Uint16 将 StatusCode 转换为 uint16 -// Uint16 converts StatusCode to uint16 func (c StatusCode) Uint16() uint16 { - // 返回 StatusCode 的 uint16 表示 - // Return the uint16 representation of StatusCode return uint16(c) } -// Bytes 将 StatusCode 转换为字节切片 -// Bytes converts StatusCode to a byte slice func (c StatusCode) Bytes() []byte { - // 如果 StatusCode 为 0,返回空字节切片 - // If StatusCode is 0, return an empty byte slice if c == 0 { return []byte{} } - // 返回包含 StatusCode 高字节和低字节的字节切片 - // Return a byte slice containing the high byte and low byte of StatusCode return []byte{uint8(c >> 8), uint8(c << 8 >> 8)} } -// Error 返回 StatusCode 对应的错误字符串 -// Error returns the error string corresponding to StatusCode func (c StatusCode) Error() string { - // 返回包含错误信息的字符串 - // Return a string containing the error message return "gws: " + closeErrorMap[c] } -// NewError 创建一个新的 Error 实例 -// NewError creates a new Error instance func NewError(code StatusCode, err error) *Error { - // 返回包含指定状态码和错误的 Error 实例 - // Return an Error instance containing the specified status code and error return &Error{Code: code, Err: err} } -// Error 结构体定义了一个包含错误和状态码的错误类型 -// Error struct defines an error type containing an error and a status code type Error struct { Err error // 错误信息 Code StatusCode // 状态码 } -// Error 返回错误的字符串表示 -// Error returns the string representation of the error func (c *Error) Error() string { - // 返回错误信息的字符串表示 - // Return the string representation of the error message return c.Err.Error() } // Errors 依次执行传入的函数,返回第一个遇到的错误 -// Errors executes the passed functions in sequence and returns the first encountered error +// executes the passed functions in sequence and returns the first encountered error func Errors(funcs ...func() error) error { - // 遍历每个函数 - // Iterate over each function for _, f := range funcs { - // 执行函数并检查是否有错误 - // Execute the function and check for an error if err := f(); err != nil { - // 返回遇到的第一个错误 - // Return the first encountered error return err } } - // 如果没有遇到错误,返回 nil - // If no errors are encountered, return nil return nil } diff --git a/internal/io.go b/internal/io.go index 6597f56d..32aefa7d 100644 --- a/internal/io.go +++ b/internal/io.go @@ -6,177 +6,85 @@ import ( ) // ReadN 精准地读取 len(data) 个字节, 否则返回错误 -// ReadN reads exactly len(data) bytes, otherwise returns an error +// reads exactly len(data) bytes, otherwise returns an error func ReadN(reader io.Reader, data []byte) error { - // 使用 io.ReadFull 函数从 reader 中读取 len(data) 个字节 - // Use io.ReadFull to read len(data) bytes from the reader _, err := io.ReadFull(reader, data) - - // 返回读取过程中遇到的错误 - // Return any error encountered during reading return err } -// WriteN 将 content 写入 writer 中, 否则返回错误 -// WriteN writes the content to the writer, otherwise returns an error +// WriteN 将 content 写入 writer 中 +// writes the content to the writer func WriteN(writer io.Writer, content []byte) error { - // 使用 writer.Write 函数将 content 写入 writer 中 - // Use writer.Write to write the content to the writer _, err := writer.Write(content) - - // 返回写入过程中遇到的错误 - // Return any error encountered during writing return err } // CheckEncoding 检查 payload 的编码是否有效 -// CheckEncoding checks if the encoding of the payload is valid +// checks if the encoding of the payload is valid func CheckEncoding(opcode uint8, payload []byte) bool { - // 根据 opcode 的值进行不同的处理 - // Handle different cases based on the value of opcode switch opcode { - // 如果 opcode 的值为 1 或 8 - // If the value of opcode is 1 or 8 case 1, 8: - // 调用 utf8.Valid 函数检查 payload 是否为有效的 UTF-8 编码,如果是则返回 true,否则返回 false - // Call the utf8.Valid function to check if payload is a valid UTF-8 encoding, return true if it is, otherwise return false return utf8.Valid(payload) - - // 如果 opcode 的值为其他值 - // If the value of opcode is other values default: - // 直接返回 true - // Return true directly return true } } -// Payload 接口定义了处理负载数据的方法 -// Payload interface defines methods for handling payload data type Payload interface { - // WriterTo 接口用于将数据写入 io.Writer - // WriterTo interface is used to write data to an io.Writer io.WriterTo - - // Len 返回负载数据的长度 - // Len returns the length of the payload data Len() int - - // CheckEncoding 检查负载数据的编码是否有效 - // CheckEncoding checks if the encoding of the payload data is valid CheckEncoding(enabled bool, opcode uint8) bool } -// Buffers 类型定义为一个二维字节切片 -// Buffers type is defined as a slice of byte slices type Buffers [][]byte -// CheckEncoding 检查每个缓冲区的编码是否有效 -// CheckEncoding checks if the encoding of each buffer is valid func (b Buffers) CheckEncoding(enabled bool, opcode uint8) bool { - // 如果启用了编码检查 - // If encoding check is enabled if enabled { - // 遍历每个缓冲区 - // Iterate over each buffer for i, _ := range b { - // 如果任意一个缓冲区的编码无效,返回 false - // If any buffer's encoding is invalid, return false if !CheckEncoding(opcode, b[i]) { return false } } } - - // 如果所有缓冲区的编码都有效,返回 true - // If all buffers' encodings are valid, return true return true } -// Len 返回所有缓冲区的总长度 -// Len returns the total length of all buffers func (b Buffers) Len() int { - // 初始化总长度为 0 - // Initialize total length to 0 var sum = 0 - - // 遍历每个缓冲区 - // Iterate over each buffer for i, _ := range b { - // 累加每个缓冲区的长度 - // Accumulate the length of each buffer sum += len(b[i]) } - - // 返回总长度 - // Return the total length return sum } -// WriteTo 将所有缓冲区的数据写入指定的 io.Writer -// WriteTo writes the data of all buffers to the specified io.Writer +// WriteTo 可重复写 func (b Buffers) WriteTo(w io.Writer) (int64, error) { - // 初始化写入的总字节数为 0 - // Initialize the total number of bytes written to 0 var n = 0 - - // 遍历每个缓冲区 - // Iterate over each buffer for i, _ := range b { - // 将当前缓冲区的数据写入 io.Writer - // Write the current buffer's data to the io.Writer x, err := w.Write(b[i]) - - // 累加写入的字节数 - // Accumulate the number of bytes written n += x - - // 如果写入过程中遇到错误,返回已写入的字节数和错误 - // If an error is encountered during writing, return the number of bytes written and the error if err != nil { return int64(n), err } } - - // 返回写入的总字节数和 nil 错误 - // Return the total number of bytes written and a nil error return int64(n), nil } -// Bytes 类型定义为一个字节切片 -// Bytes type is defined as a byte slice type Bytes []byte -// CheckEncoding 检查字节切片的编码是否有效 -// CheckEncoding checks if the encoding of the byte slice is valid func (b Bytes) CheckEncoding(enabled bool, opcode uint8) bool { - // 如果启用了编码检查 - // If encoding check is enabled if enabled { - // 检查字节切片的编码是否有效 - // Check if the encoding of the byte slice is valid return CheckEncoding(opcode, b) } - - // 如果未启用编码检查,始终返回 true - // If encoding check is not enabled, always return true return true } -// Len 返回字节切片的长度 -// Len returns the length of the byte slice func (b Bytes) Len() int { return len(b) } -// WriteTo 将字节切片的数据写入指定的 io.Writer -// WriteTo writes the data of the byte slice to the specified io.Writer +// WriteTo 可重复写 func (b Bytes) WriteTo(w io.Writer) (int64, error) { - // 将字节切片的数据写入 io.Writer - // Write the byte slice's data to the io.Writer n, err := w.Write(b) - - // 返回写入的字节数和可能的错误 - // Return the number of bytes written and any potential error return int64(n), err } diff --git a/internal/others.go b/internal/others.go index 165de403..8a3a5e77 100644 --- a/internal/others.go +++ b/internal/others.go @@ -6,95 +6,39 @@ import ( ) const ( - // PermessageDeflate 表示 WebSocket 扩展 "permessage-deflate" - // PermessageDeflate represents the WebSocket extension "permessage-deflate" - PermessageDeflate = "permessage-deflate" - - // ServerMaxWindowBits 表示服务器最大窗口位数的参数 - // ServerMaxWindowBits represents the parameter for the server's maximum window bits - ServerMaxWindowBits = "server_max_window_bits" - - // ClientMaxWindowBits 表示客户端最大窗口位数的参数 - // ClientMaxWindowBits represents the parameter for the client's maximum window bits - ClientMaxWindowBits = "client_max_window_bits" - - // ServerNoContextTakeover 表示服务器不进行上下文接管的参数 - // ServerNoContextTakeover represents the parameter for the server's no context takeover + PermessageDeflate = "permessage-deflate" + ServerMaxWindowBits = "server_max_window_bits" + ClientMaxWindowBits = "client_max_window_bits" ServerNoContextTakeover = "server_no_context_takeover" - - // ClientNoContextTakeover 表示客户端不进行上下文接管的参数 - // ClientNoContextTakeover represents the parameter for the client's no context takeover ClientNoContextTakeover = "client_no_context_takeover" - - // EQ 表示等号 "=" - // EQ represents the equal sign "=" - EQ = "=" + EQ = "=" ) -// Pair 表示一个键值对 -// Pair represents a key-value pair type Pair struct { - // Key 表示键 - // Key represents the key Key string - - // Val 表示值 - // Val represents the value Val string } var ( - // SecWebSocketVersion 表示 WebSocket 版本的键值对 - // SecWebSocketVersion represents the key-value pair for WebSocket version - SecWebSocketVersion = Pair{"Sec-WebSocket-Version", "13"} - - // SecWebSocketKey 表示 WebSocket 密钥的键值对 - // SecWebSocketKey represents the key-value pair for WebSocket key - SecWebSocketKey = Pair{"Sec-WebSocket-Key", ""} - - // SecWebSocketExtensions 表示 WebSocket 扩展的键值对 - // SecWebSocketExtensions represents the key-value pair for WebSocket extensions + SecWebSocketVersion = Pair{"Sec-WebSocket-Version", "13"} + SecWebSocketKey = Pair{"Sec-WebSocket-Key", ""} SecWebSocketExtensions = Pair{"Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover"} - - // Connection 表示连接类型的键值对 - // Connection represents the key-value pair for connection type - Connection = Pair{"Connection", "Upgrade"} - - // Upgrade 表示升级协议的键值对 - // Upgrade represents the key-value pair for upgrade protocol - Upgrade = Pair{"Upgrade", "websocket"} - - // SecWebSocketAccept 表示 WebSocket 接受密钥的键值对 - // SecWebSocketAccept represents the key-value pair for WebSocket accept key - SecWebSocketAccept = Pair{"Sec-WebSocket-Accept", ""} - - // SecWebSocketProtocol 表示 WebSocket 协议的键值对 - // SecWebSocketProtocol represents the key-value pair for WebSocket protocol - SecWebSocketProtocol = Pair{"Sec-WebSocket-Protocol", ""} + Connection = Pair{"Connection", "Upgrade"} + Upgrade = Pair{"Upgrade", "websocket"} + SecWebSocketAccept = Pair{"Sec-WebSocket-Accept", ""} + SecWebSocketProtocol = Pair{"Sec-WebSocket-Protocol", ""} ) -// MagicNumber 是 WebSocket 握手过程中使用的魔术字符串 -// MagicNumber is the magic string used during the WebSocket handshake +// MagicNumber WebSocket 握手过程中使用的魔术字符串 +// the magic string used during the WebSocket handshake const MagicNumber = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" const ( - // ThresholdV1 是第一个版本的阈值,最大值为 125 - // ThresholdV1 is the threshold for the first version, with a maximum value of 125 ThresholdV1 = 125 - - // ThresholdV2 是第二个版本的阈值,最大值为 math.MaxUint16 - // ThresholdV2 is the threshold for the second version, with a maximum value of math.MaxUint16 ThresholdV2 = math.MaxUint16 - - // ThresholdV3 是第三个版本的阈值,最大值为 math.MaxUint64 - // ThresholdV3 is the threshold for the third version, with a maximum value of math.MaxUint64 ThresholdV3 = math.MaxUint64 ) -// NetConn 是一个网络连接接口,定义了一个返回 net.Conn 的方法 -// NetConn is a network connection interface that defines a method returning a net.Conn type NetConn interface { - // NetConn 返回一个底层的 net.Conn 对象 - // NetConn returns an underlying net.Conn object NetConn() net.Conn } diff --git a/internal/pool.go b/internal/pool.go index 2262fa91..9a3b73df 100644 --- a/internal/pool.go +++ b/internal/pool.go @@ -5,158 +5,92 @@ import ( "sync" ) -// BufferPool 结构体定义了一个缓冲区池 -// BufferPool struct defines a buffer pool type BufferPool struct { - // begin 表示缓冲区池的起始大小 - // begin indicates the starting size of the buffer pool - begin int - - // end 表示缓冲区池的结束大小 - // end indicates the ending size of the buffer pool - end int - - // shards 是一个映射,键是缓冲区大小,值是对应大小的 sync.Pool - // shards is a map where the key is the buffer size and the value is a sync.Pool for that size + begin int + end int shards map[int]*sync.Pool } // NewBufferPool 创建一个内存池 -// NewBufferPool creates a memory pool +// creates a memory pool // left 和 right 表示内存池的区间范围,它们将被转换为 2 的 n 次幂 // left and right indicate the interval range of the memory pool, they will be transformed into pow(2, n) // 小于 left 的情况下,Get 方法将返回至少 left 字节的缓冲区;大于 right 的情况下,Put 方法不会回收缓冲区 // Below left, the Get method will return at least left bytes; above right, the Put method will not reclaim the buffer func NewBufferPool(left, right uint32) *BufferPool { - // 计算 begin 和 end,分别为 left 和 right 向上取整到 2 的 n 次幂的值 - // Calculate begin and end, which are the ceiling values of left and right to the nearest power of 2 var begin, end = int(binaryCeil(left)), int(binaryCeil(right)) - - // 初始化 BufferPool 结构体 - // Initialize the BufferPool struct var p = &BufferPool{ begin: begin, end: end, - shards: make(map[int]*sync.Pool), + shards: map[int]*sync.Pool{}, } - - // 遍历从 begin 到 end 的所有 2 的 n 次幂的值 - // Iterate over all powers of 2 from begin to end for i := begin; i <= end; i *= 2 { - // 将当前容量赋值给局部变量 capacity - // Assign the current capacity to the local variable capacity capacity := i - - // 为当前容量创建一个 sync.Pool,并将其添加到 shards 映射中 - // Create a sync.Pool for the current capacity and add it to the shards map p.shards[i] = &sync.Pool{ - // 定义当池中没有可用缓冲区时创建新缓冲区的函数 - // Define the function to create a new buffer when there are no available buffers in the pool New: func() any { return bytes.NewBuffer(make([]byte, 0, capacity)) }, } } - - // 返回初始化后的 BufferPool - // Return the initialized BufferPool return p } -// Put 将缓冲区返回到内存池 -// Put returns the buffer to the memory pool +// Put 将缓冲区放回到内存池 +// returns the buffer to the memory pool func (p *BufferPool) Put(b *bytes.Buffer) { - // 如果缓冲区不为空 - // If the buffer is not nil if b != nil { - // 检查缓冲区的容量是否在 shards 映射中 - // Check if the buffer's capacity is in the shards map if pool, ok := p.shards[b.Cap()]; ok { - // 将缓冲区放回对应容量的池中 - // Put the buffer back into the pool of the corresponding capacity pool.Put(b) } } } // Get 从内存池中获取一个至少 n 字节的缓冲区 -// Get fetches a buffer from the memory pool, of at least n bytes +// fetches a buffer from the memory pool, of at least n bytes func (p *BufferPool) Get(n int) *bytes.Buffer { - // 计算所需的缓冲区大小,取 n 和 begin 中较大的值,并向上取整到 2 的 n 次幂 - // Calculate the required buffer size, taking the larger of n and begin, and rounding up to the nearest power of 2 var size = Max(int(binaryCeil(uint32(n))), p.begin) - - // 检查所需大小的缓冲区池是否存在于 shards 映射中 - // Check if the buffer pool of the required size exists in the shards map if pool, ok := p.shards[size]; ok { - // 从池中获取一个缓冲区 - // Get a buffer from the pool b := pool.Get().(*bytes.Buffer) - - // 如果缓冲区的容量小于所需大小,则扩展缓冲区 - // If the buffer's capacity is less than the required size, grow the buffer if b.Cap() < size { b.Grow(size) } - - // 重置缓冲区 - // Reset the buffer b.Reset() - - // 返回缓冲区 - // Return the buffer return b } - - // 如果所需大小的缓冲区池不存在,则创建一个新的缓冲区 - // If the buffer pool of the required size does not exist, create a new buffer return bytes.NewBuffer(make([]byte, 0, n)) } // binaryCeil 将给定的 uint32 值向上取整到最近的 2 的幂 -// binaryCeil rounds up the given uint32 value to the nearest power of 2 +// rounds up the given uint32 value to the nearest power of 2 func binaryCeil(v uint32) uint32 { - // 首先将 v 减 1,以处理 v 本身已经是 2 的幂的情况 - // First, decrement v by 1 to handle the case where v is already a power of 2 v-- - - // 将 v 的每一位与其右边的位进行或运算,逐步填充所有低位 - // Perform bitwise OR operations to fill all lower bits v |= v >> 1 v |= v >> 2 v |= v >> 4 v |= v >> 8 v |= v >> 16 - - // 最后将 v 加 1,得到大于或等于原始 v 的最小 2 的幂 - // Finally, increment v by 1 to get the smallest power of 2 greater than or equal to the original v v++ - - // 返回结果 - // Return the result return v } -// NewPool 创建一个新的泛型池 -// NewPool creates a new generic pool +// NewPool 创建一个新的泛型内存池 +// creates a new generic pool func NewPool[T any](f func() T) *Pool[T] { - // 返回一个包含 sync.Pool 的 Pool 结构体 - // Return a Pool struct containing a sync.Pool return &Pool[T]{p: sync.Pool{New: func() any { return f() }}} } -// Pool 是一个泛型池结构体 -// Pool is a generic pool struct +// Pool 泛型内存池 +// generic pool type Pool[T any] struct { p sync.Pool // 内嵌的 sync.Pool } // Put 将一个值放入池中 -// Put puts a value into the pool +// puts a value into the pool func (c *Pool[T]) Put(v T) { - c.p.Put(v) // 调用 sync.Pool 的 Put 方法 + c.p.Put(v) } // Get 从池中获取一个值 -// Get gets a value from the pool +// gets a value from the pool func (c *Pool[T]) Get() T { - return c.p.Get().(T) // 调用 sync.Pool 的 Get 方法并进行类型断言 + return c.p.Get().(T) } diff --git a/internal/random.go b/internal/random.go index dd69e3a6..b7968aaa 100644 --- a/internal/random.go +++ b/internal/random.go @@ -6,148 +6,69 @@ import ( "time" ) -// RandomString 结构体用于生成随机字符串 -// RandomString struct is used to generate random strings +// RandomString 随机字符串生成器 +// random string generator type RandomString struct { - // mu 是一个互斥锁,用于保护并发访问 - // mu is a mutex to protect concurrent access - mu sync.Mutex - - // r 是一个随机数生成器 - // r is a random number generator - r *rand.Rand - - // layout 是用于生成随机字符串的字符集 - // layout is the character set used to generate random strings + mu sync.Mutex + r *rand.Rand layout string } var ( - // AlphabetNumeric 是一个包含字母和数字字符集的 RandomString 实例 - // AlphabetNumeric is a RandomString instance with an alphanumeric character set + // AlphabetNumeric 包含字母和数字字符集的 RandomString 实例 + // It's a RandomString instance with an alphanumeric character set AlphabetNumeric = &RandomString{ - // layout 包含数字和大小写字母 - // layout contains numbers and uppercase and lowercase letters layout: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ", - - // r 使用当前时间的纳秒数作为种子创建一个新的随机数生成器 - // r creates a new random number generator seeded with the current time in nanoseconds - r: rand.New(rand.NewSource(time.Now().UnixNano())), - - // mu 初始化为一个新的互斥锁 - // mu is initialized as a new mutex - mu: sync.Mutex{}, + r: rand.New(rand.NewSource(time.Now().UnixNano())), + mu: sync.Mutex{}, } - // Numeric 是一个仅包含数字字符集的 RandomString 实例 - // Numeric is a RandomString instance with a numeric character set + // Numeric 仅包含数字字符集的 RandomString 实例 + // It's a RandomString instance with a numeric character set Numeric = &RandomString{ - // layout 仅包含数字 - // layout contains only numbers layout: "0123456789", - - // r 使用当前时间的纳秒数作为种子创建一个新的随机数生成器 - // r creates a new random number generator seeded with the current time in nanoseconds - r: rand.New(rand.NewSource(time.Now().UnixNano())), - - // mu 初始化为一个新的互斥锁 - // mu is initialized as a new mutex - mu: sync.Mutex{}, + r: rand.New(rand.NewSource(time.Now().UnixNano())), + mu: sync.Mutex{}, } ) // Generate 生成一个长度为 n 的随机字节切片 -// Generate generates a random byte slice of length n +// generates a random byte slice of length n func (c *RandomString) Generate(n int) []byte { - // 加锁以确保线程安全 - // Lock to ensure thread safety c.mu.Lock() - - // 创建一个长度为 n 的字节切片 - // Create a byte slice of length n var b = make([]byte, n, n) - - // 获取字符集的长度 - // Get the length of the character set var length = len(c.layout) - - // 生成随机字节 - // Generate random bytes for i := 0; i < n; i++ { - // 从字符集中随机选择一个字符的索引 - // Randomly select an index from the character set var idx = c.r.Intn(length) - - // 将字符集中的字符赋值给字节切片 - // Assign the character from the character set to the byte slice b[i] = c.layout[idx] } - - // 解锁 - // Unlock c.mu.Unlock() - - // 返回生成的字节切片 - // Return the generated byte slice return b } // Intn 返回一个 [0, n) 范围内的随机整数 -// Intn returns a random integer in the range [0, n) +// returns a random integer in the range [0, n) func (c *RandomString) Intn(n int) int { - // 加锁以确保线程安全 - // Lock to ensure thread safety c.mu.Lock() - - // 生成随机整数 - // Generate a random integer x := c.r.Intn(n) - - // 解锁 - // Unlock c.mu.Unlock() - - // 返回生成的随机整数 - // Return the generated random integer return x } // Uint32 返回一个随机的 uint32 值 -// Uint32 returns a random uint32 value +// returns a random uint32 value func (c *RandomString) Uint32() uint32 { - // 加锁以确保线程安全 - // Lock to ensure thread safety c.mu.Lock() - - // 生成随机的 uint32 值 - // Generate a random uint32 value x := c.r.Uint32() - - // 解锁 - // Unlock c.mu.Unlock() - - // 返回生成的随机 uint32 值 - // Return the generated random uint32 value return x } // Uint64 返回一个随机的 uint64 值 -// Uint64 returns a random uint64 value +// returns a random uint64 value func (c *RandomString) Uint64() uint64 { - // 加锁以确保线程安全 - // Lock to ensure thread safety c.mu.Lock() - - // 生成随机的 uint64 值 - // Generate a random uint64 value x := c.r.Uint64() - - // 解锁 - // Unlock c.mu.Unlock() - - // 返回生成的随机 uint64 值 - // Return the generated random uint64 value return x } diff --git a/internal/utils.go b/internal/utils.go index fbea140f..06839e98 100644 --- a/internal/utils.go +++ b/internal/utils.go @@ -10,117 +10,54 @@ import ( "unsafe" ) -// 定义一个常量 prime64,其值为 1099511628211 -// Define a constant prime64, its value is 1099511628211 -const prime64 = 1099511628211 - -// 定义一个常量 offset64,其值为 14695981039346656037 -// Define a constant offset64, its value is 14695981039346656037 -const offset64 = 14695981039346656037 +const ( + prime64 = 1099511628211 + offset64 = 14695981039346656037 +) -// 定义一个接口 Integer,它可以是 int、int64、int32、uint、uint64 或 uint32 类型 -// Define an interface Integer, it can be of type int, int64, int32, uint, uint64, or uint32 type Integer interface { int | int64 | int32 | uint | uint64 | uint32 } -// MaskByByte 是一个函数,接收两个字节切片参数 content 和 key,对 content 进行按位异或操作。 -// MaskByByte is a function that takes two byte slice parameters, content and key, and performs bitwise XOR operation on content. func MaskByByte(content []byte, key []byte) { - // 获取 content 的长度,并赋值给 n - // Get the length of content and assign it to n var n = len(content) - - // 遍历 content 中的每一个元素 - // Iterate over each element in content for i := 0; i < n; i++ { - // 计算 i 与 3 的按位与运算结果,并赋值给 idx - // Calculate the bitwise AND operation result of i and 3, and assign it to idx var idx = i & 3 - - // 对 content[i] 和 key[idx] 进行按位异或操作,并将结果赋值给 content[i] - // Perform bitwise XOR operation on content[i] and key[idx], and assign the result to content[i] content[i] ^= key[idx] } } -// ComputeAcceptKey 是一个函数,接收一个字符串参数 challengeKey,计算其 SHA-1 哈希值,并返回其 Base64 编码。 -// ComputeAcceptKey is a function that takes a string parameter challengeKey, calculates its SHA-1 hash, and returns its Base64 encoding. func ComputeAcceptKey(challengeKey string) string { - // 创建一个新的 SHA-1 哈希 - // Create a new SHA-1 hash h := sha1.New() - - // 将 challengeKey 转换为字节切片,并赋值给 buf - // Convert challengeKey to a byte slice and assign it to buf buf := []byte(challengeKey) - - // 将 MagicNumber 追加到 buf 的末尾 - // Append MagicNumber to the end of buf buf = append(buf, MagicNumber...) - - // 将 buf 写入到 h - // Write buf to h h.Write(buf) - - // 计算 h 的 SHA-1 哈希值,并将其转换为 Base64 编码 - // Calculate the SHA-1 hash of h and convert it to Base64 encoding return base64.StdEncoding.EncodeToString(h.Sum(nil)) } -// NewMaskKey 是一个函数,它返回一个长度为 4 的字节构成的随机数切片 -// NewMaskKey is a function that returns a random number slice consisting of 4 bytes func NewMaskKey() [4]byte { - // 调用 AlphabetNumeric.Uint32() 方法生成一个 uint32 类型的随机数 n - // Call the AlphabetNumeric.Uint32() method to generate a random number n of type uint32 n := AlphabetNumeric.Uint32() - - // 返回一个长度为 4 的字节切片,每个字节是 n(随机数) 的一个字节 - // Return a byte slice of length 4, each byte is a byte of n(random number) return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} } -// MethodExists 是一个函数,判断一个对象是否存在某个方法。 -// MethodExists is a function that determines whether an object has a method. +// MethodExists +// if nil return false func MethodExists(in any, method string) (reflect.Value, bool) { - // 如果 in 为 nil 或 method 为空字符串,那么返回一个空的 reflect.Value 和 false。 - // If in is nil or method is an empty string, then return an empty reflect.Value and false. if in == nil || method == "" { return reflect.Value{}, false } - - // 获取 in 的类型,并赋值给 p。 - // Get the type of in and assign it to p. p := reflect.TypeOf(in) - - // 如果 p 的种类(Kind)是指针,那么获取 p 的元素类型。 - // If the kind of p is a pointer, then get the element type of p. if p.Kind() == reflect.Ptr { p = p.Elem() } - - // 如果 p 的种类(Kind)不是结构体,那么返回一个空的 reflect.Value 和 false。 - // If the kind of p is not a struct, then return an empty reflect.Value and false. if p.Kind() != reflect.Struct { return reflect.Value{}, false } - - // 获取 in 的值,并赋值给 object。 - // Get the value of in and assign it to object. object := reflect.ValueOf(in) - - // 通过 method 名称获取 object 的方法,并赋值给 newMethod。 - // Get the method of object by the name of method and assign it to newMethod. newMethod := object.MethodByName(method) - - // 如果 newMethod 不是有效的,那么返回一个空的 reflect.Value 和 false。 - // If newMethod is not valid, then return an empty reflect.Value and false. if !newMethod.IsValid() { return reflect.Value{}, false } - - // 返回 newMethod 和 true。 - // Return newMethod and true. return newMethod, true } @@ -134,73 +71,30 @@ func StringToBytes(s string) []byte { return *(*[]byte)(unsafe.Pointer(&bh)) } -// FnvString 函数接收一个字符串 s,然后使用 FNV-1a 哈希算法计算其哈希值。 -// The FnvString function takes a string s and calculates its hash value using the FNV-1a hash algorithm. func FnvString(s string) uint64 { - // 初始化哈希值为 offset64 - // Initialize the hash value to offset64 var h = uint64(offset64) - - // 遍历字符串 s 中的每个字符 - // Iterate over each character in the string s for _, b := range s { - // 将哈希值乘以 prime64 - // Multiply the hash value by prime64 h *= prime64 - - // 将哈希值与字符的 ASCII 值进行异或操作 - // XOR the hash value with the ASCII value of the character h ^= uint64(b) } - - // 返回计算得到的哈希值 - // Return the calculated hash value return h } -// FnvNumber 函数接收一个整数 x,然后使用 FNV-1a 哈希算法计算其哈希值。 -// The FnvNumber function takes an integer x and calculates its hash value using the FNV-1a hash algorithm. func FnvNumber[T Integer](x T) uint64 { - // 初始化哈希值为 offset64 - // Initialize the hash value to offset64 var h = uint64(offset64) - - // 将哈希值乘以 prime64 - // Multiply the hash value by prime64 h *= prime64 - - // 将哈希值与整数 x 进行异或操作 - // XOR the hash value with the integer x h ^= uint64(x) - - // 返回计算得到的哈希值 - // Return the calculated hash value return h } -// MaskXOR 是一个函数,它接受两个字节切片作为参数:b 和 key。 -// 它使用 key 对 b 进行异或操作,然后将结果存储在 b 中。 -// MaskXOR is a function that takes two byte slices as arguments: b and key. -// It performs an XOR operation on b using key, and then stores the result in b. +// MaskXOR 计算掩码 func MaskXOR(b []byte, key []byte) { - // 将 key 转换为小端序的 uint32,然后转换为 uint64,并将其复制到 key64 的高位和低位。 - // Convert key to a little-endian uint32, then convert it to a uint64, and copy it to the high and low bits of key64. var maskKey = binary.LittleEndian.Uint32(key) var key64 = uint64(maskKey)<<32 + uint64(maskKey) - // 当 b 的长度大于或等于 64 时,将 b 的每 8 个字节与 key64 进行异或操作。 - // When the length of b is greater than or equal to 64, XOR every 8 bytes of b with key64. for len(b) >= 64 { - // 读取 b 的前 8 个字节,并将其解释为小端序的 uint64 - // Read the first 8 bytes of b and interpret it as a little-endian uint64 v := binary.LittleEndian.Uint64(b) - - // 将 v 与 key64 进行异或操作,然后将结果写回 b 的前 8 个字节 - // XOR v with key64 and then write the result back to the first 8 bytes of b binary.LittleEndian.PutUint64(b, v^key64) - - // 以下代码块重复上述操作,但是对 b 的不同部分进行操作 - // The following code blocks repeat the above operation, but operate on different parts of b v = binary.LittleEndian.Uint64(b[8:16]) binary.LittleEndian.PutUint64(b[8:16], v^key64) v = binary.LittleEndian.Uint64(b[16:24]) @@ -215,104 +109,55 @@ func MaskXOR(b []byte, key []byte) { binary.LittleEndian.PutUint64(b[48:56], v^key64) v = binary.LittleEndian.Uint64(b[56:64]) binary.LittleEndian.PutUint64(b[56:64], v^key64) - - // 将 b 的前 64 个字节移除,以便在下一次循环中处理剩余的字节 - // Remove the first 64 bytes of b so that the remaining bytes can be processed in the next loop b = b[64:] } - // 当 b 的长度小于 64 但大于或等于 8 时,将 b 的每 8 个字节与 key64 进行异或操作。 - // When the length of b is less than 64 but greater than or equal to 8, XOR every 8 bytes of b with key64. for len(b) >= 8 { - // 读取 b 的前 8 个字节,并将其解释为小端序的 uint64 - // Read the first 8 bytes of b and interpret it as a little-endian uint64 v := binary.LittleEndian.Uint64(b[:8]) - - // 将 v 与 key64 进行异或操作,然后将结果写回 b 的前 8 个字节 - // XOR v with key64 and then write the result back to the first 8 bytes of b binary.LittleEndian.PutUint64(b[:8], v^key64) - - // 将 b 的前 8 个字节移除,以便在下一次循环中处理剩余的字节 - // Remove the first 8 bytes of b so that the remaining bytes can be processed in the next loop b = b[8:] } - // 当 b 的长度小于 8 时,将 b 的每个字节与 key 的相应字节进行异或操作。 - // When the length of b is less than 8, XOR each byte of b with the corresponding byte of key. var n = len(b) for i := 0; i < n; i++ { - // 计算 key 的索引,这里使用了位运算符 &,它会返回两个数字的二进制表示中都为 1 的位的结果 - // Calculate the index of key, here we use the bitwise operator &, it will return the result of the bits that are 1 in the binary representation of both numbers idx := i & 3 - - // 将 b 的第 i 个字节与 key 的第 idx 个字节进行异或操作,然后将结果写回 b 的第 i 个字节 - // XOR the i-th byte of b with the idx-th byte of key, and then write the result back to the i-th byte of b b[i] ^= key[idx] } } -// InCollection 函数检查给定的字符串 elem 是否在字符串切片 elems 中。 -// The InCollection function checks if the given string elem is in the string slice elems. +// InCollection 检查给定的字符串 elem 是否在字符串切片 elems 中 +// Checks if the given string elem is in the string slice elems. func InCollection(elem string, elems []string) bool { - // 遍历 elems 中的每个元素 - // Iterate over each element in elems for _, item := range elems { - // 如果找到了与 elem 相等的元素,返回 true - // If an element equal to elem is found, return true if item == elem { return true } } - - // 如果没有找到与 elem 相等的元素,返回 false - // If no element equal to elem is found, return false return false } -// GetIntersectionElem 函数获取两个字符串切片 a 和 b 的交集中的一个元素。 -// The GetIntersectionElem function gets an element in the intersection of two string slices a and b. +// GetIntersectionElem 获取两个字符串切片 a 和 b 的交集中的一个元素 +// Gets an element in the intersection of two string slices a and b func GetIntersectionElem(a, b []string) string { - // 遍历 a 中的每个元素 - // Iterate over each element in a for _, item := range a { - // 如果 item 在 b 中,返回 item - // If item is in b, return item if InCollection(item, b) { return item } } - - // 如果 a 和 b 没有交集,返回空字符串 - // If a and b have no intersection, return an empty string return "" } -// Split 函数分割给定的字符串 s,使用 sep 作为分隔符。空值将会被过滤掉。 -// The Split function splits the given string s using sep as the separator. Empty values will be filtered out. +// Split 分割给定的字符串 s,使用 sep 作为分隔符。空值将会被过滤掉。 +// Splits the given string s using sep as the separator. Empty values will be filtered out. func Split(s string, sep string) []string { - // 使用 sep 分割 s,得到一个字符串切片 - // Split s using sep to get a string slice var list = strings.Split(s, sep) - - // 初始化一个索引 j - // Initialize an index j var j = 0 - - // 遍历 list 中的每个元素 - // Iterate over each element in list for _, v := range list { - // 去除 v 的前后空白字符 - // Remove the leading and trailing white space of v if v = strings.TrimSpace(v); v != "" { - // 如果 v 不为空,将其添加到 list 的 j 索引处,并将 j 加 1 - // If v is not empty, add it to the j index of list and increment j by 1 list[j] = v j++ } } - - // 返回 list 的前 j 个元素,即去除了空值的部分 - // Return the first j elements of list, i.e., the part without empty values return list[:j] } @@ -339,28 +184,16 @@ func ToBinaryNumber[T Integer](n T) T { return x } -// BinaryPow 函数接收一个整数 n,然后计算并返回 2 的 n 次方。 -// The BinaryPow function takes an integer n and then calculates and returns 2 to the power of n. func BinaryPow(n int) int { - // 初始化答案为 1 - // Initialize the answer to 1 var ans = 1 - - // 循环 n 次, 持续左移 - // Loop n times, continue to shift left for i := 0; i < n; i++ { - // 将答案左移一位,这相当于将答案乘以 2 - // Shift the answer to the left by one bit, which is equivalent to multiplying the answer by 2 ans <<= 1 } - - // 返回计算得到的答案 - // Return the calculated answer return ans } -// BufferReset 函数接收一个字节缓冲区 b 和一个字节切片 p,然后将 b 的底层切片重置为 p。 -// The BufferReset function takes a byte buffer b and a byte slice p, and then resets the underlying slice of b to p. +// BufferReset 重置buffer底层的切片 +// Reset the buffer's underlying slice // 注意:修改后面的属性一定要加偏移量,否则可能会导致未定义的行为。 // Note: Be sure to add an offset when modifying the following properties, otherwise it may lead to undefined behavior. func BufferReset(b *bytes.Buffer, p []byte) { *(*[]byte)(unsafe.Pointer(b)) = p } diff --git a/option.go b/option.go index 18db5c60..2696b27e 100644 --- a/option.go +++ b/option.go @@ -149,8 +149,8 @@ type ( Logger Logger } - // ServerOption 结构体定义,用于配置 WebSocket 服务器的选项 - // ServerOption struct definition, used to configure WebSocket server options + // ServerOption 服务端配置 + // server configurations ServerOption struct { // 配置 // Configuration @@ -234,235 +234,115 @@ func (c *PermessageDeflate) setThreshold(isServer bool) { } // deleteProtectedHeaders 删除受保护的 WebSocket 头部字段 -// deleteProtectedHeaders removes protected WebSocket header fields +// removes protected WebSocket header fields func (c *ServerOption) deleteProtectedHeaders() { - // 删除 Upgrade 头部字段 - // Remove the Upgrade header field c.ResponseHeader.Del(internal.Upgrade.Key) - - // 删除 Connection 头部字段 - // Remove the Connection header field c.ResponseHeader.Del(internal.Connection.Key) - - // 删除 Sec-WebSocket-Accept 头部字段 - // Remove the Sec-WebSocket-Accept header field c.ResponseHeader.Del(internal.SecWebSocketAccept.Key) - - // 删除 Sec-WebSocket-Extensions 头部字段 - // Remove the Sec-WebSocket-Extensions header field c.ResponseHeader.Del(internal.SecWebSocketExtensions.Key) - - // 删除 Sec-WebSocket-Protocol 头部字段 - // Remove the Sec-WebSocket-Protocol header field c.ResponseHeader.Del(internal.SecWebSocketProtocol.Key) } -// 初始化服务器选项 +// 初始化服务器配置 // Initialize server options func initServerOption(c *ServerOption) *ServerOption { - // 如果 c 为 nil,则创建一个新的 ServerOption 实例 - // If c is nil, create a new ServerOption instance if c == nil { c = new(ServerOption) } - - // 如果 ReadMaxPayloadSize 小于等于 0,则设置为默认值 - // If ReadMaxPayloadSize is less than or equal to 0, set it to the default value if c.ReadMaxPayloadSize <= 0 { c.ReadMaxPayloadSize = defaultReadMaxPayloadSize } - - // 如果 ParallelGolimit 小于等于 0,则设置为默认值 - // If ParallelGolimit is less than or equal to 0, set it to the default value if c.ParallelGolimit <= 0 { c.ParallelGolimit = defaultParallelGolimit } - - // 如果 ReadBufferSize 小于等于 0,则设置为默认值 - // If ReadBufferSize is less than or equal to 0, set it to the default value if c.ReadBufferSize <= 0 { c.ReadBufferSize = defaultReadBufferSize } - - // 如果 WriteMaxPayloadSize 小于等于 0,则设置为默认值 - // If WriteMaxPayloadSize is less than or equal to 0, set it to the default value if c.WriteMaxPayloadSize <= 0 { c.WriteMaxPayloadSize = defaultWriteMaxPayloadSize } - - // 如果 WriteBufferSize 小于等于 0,则设置为默认值 - // If WriteBufferSize is less than or equal to 0, set it to the default value if c.WriteBufferSize <= 0 { c.WriteBufferSize = defaultWriteBufferSize } - - // 如果 Authorize 函数为 nil,则设置为默认函数 - // If the Authorize function is nil, set it to the default function if c.Authorize == nil { c.Authorize = func(r *http.Request, session SessionStorage) bool { return true } } - - // 如果 NewSession 函数为 nil,则设置为默认函数 - // If the NewSession function is nil, set it to the default function if c.NewSession == nil { c.NewSession = func() SessionStorage { return newSmap() } } - - // 如果 ResponseHeader 为 nil,则初始化为一个新的 http.Header - // If ResponseHeader is nil, initialize it as a new http.Header if c.ResponseHeader == nil { c.ResponseHeader = http.Header{} } - - // 如果 HandshakeTimeout 小于等于 0,则设置为默认值 - // If HandshakeTimeout is less than or equal to 0, set it to the default value if c.HandshakeTimeout <= 0 { c.HandshakeTimeout = defaultHandshakeTimeout } - - // 如果 Logger 为 nil,则设置为默认日志记录器 - // If Logger is nil, set it to the default logger if c.Logger == nil { c.Logger = defaultLogger } - - // 如果 Recovery 函数为 nil,则设置为默认函数 - // If the Recovery function is nil, set it to the default function if c.Recovery == nil { c.Recovery = func(logger Logger) {} } - // 如果启用了 PermessageDeflate,则进行相关配置 - // If PermessageDeflate is enabled, configure related settings if c.PermessageDeflate.Enabled { - // 如果 ServerMaxWindowBits 不在 8 到 15 之间,则设置为默认值 - // If ServerMaxWindowBits is not between 8 and 15, set it to the default value if c.PermessageDeflate.ServerMaxWindowBits < 8 || c.PermessageDeflate.ServerMaxWindowBits > 15 { c.PermessageDeflate.ServerMaxWindowBits = internal.SelectValue(c.PermessageDeflate.ServerContextTakeover, 12, 15) } - - // 如果 ClientMaxWindowBits 不在 8 到 15 之间,则设置为默认值 - // If ClientMaxWindowBits is not between 8 and 15, set it to the default value if c.PermessageDeflate.ClientMaxWindowBits < 8 || c.PermessageDeflate.ClientMaxWindowBits > 15 { c.PermessageDeflate.ClientMaxWindowBits = internal.SelectValue(c.PermessageDeflate.ClientContextTakeover, 12, 15) } - - // 如果 Threshold 小于等于 0,则设置为默认值 - // If Threshold is less than or equal to 0, set it to the default value if c.PermessageDeflate.Threshold <= 0 { c.PermessageDeflate.Threshold = defaultCompressThreshold } - - // 如果 Level 等于 0,则设置为默认值 - // If Level is equal to 0, set it to the default value if c.PermessageDeflate.Level == 0 { c.PermessageDeflate.Level = defaultCompressLevel } - - // 如果 PoolSize 小于等于 0,则设置为默认值 - // If PoolSize is less than or equal to 0, set it to the default value if c.PermessageDeflate.PoolSize <= 0 { c.PermessageDeflate.PoolSize = defaultCompressorPoolSize } - - // 将 PoolSize 转换为二进制数 - // Convert PoolSize to a binary number c.PermessageDeflate.PoolSize = internal.ToBinaryNumber(c.PermessageDeflate.PoolSize) } - // 删除受保护的头部信息 - // Delete protected headers c.deleteProtectedHeaders() - // 配置 WebSocket 客户端的选项 - // Configure WebSocket client options c.config = &Config{ - // 是否启用并行处理 - // Whether parallel processing is enabled - ParallelEnabled: c.ParallelEnabled, - - // 并行协程限制 - // Parallel goroutine limit - ParallelGolimit: c.ParallelGolimit, - - // 读取最大负载大小 - // Maximum payload size for reading - ReadMaxPayloadSize: c.ReadMaxPayloadSize, - - // 读取缓冲区大小 - // Read buffer size - ReadBufferSize: c.ReadBufferSize, - - // 写入最大负载大小 - // Maximum payload size for writing + ParallelEnabled: c.ParallelEnabled, + ParallelGolimit: c.ParallelGolimit, + ReadMaxPayloadSize: c.ReadMaxPayloadSize, + ReadBufferSize: c.ReadBufferSize, WriteMaxPayloadSize: c.WriteMaxPayloadSize, - - // 写缓冲区大小 - // Write buffer size - WriteBufferSize: c.WriteBufferSize, - - // 是否启用 UTF-8 检查 - // Whether UTF-8 check is enabled - CheckUtf8Enabled: c.CheckUtf8Enabled, - - // 恢复函数 - // Recovery function - Recovery: c.Recovery, - - // 日志记录器 - // Logger - Logger: c.Logger, - - // 缓冲区读取池 - // Buffer reader pool + WriteBufferSize: c.WriteBufferSize, + CheckUtf8Enabled: c.CheckUtf8Enabled, + Recovery: c.Recovery, + Logger: c.Logger, brPool: internal.NewPool(func() *bufio.Reader { return bufio.NewReaderSize(nil, c.ReadBufferSize) }), } - // 如果启用了 PermessageDeflate,则进行相关配置 - // If PermessageDeflate is enabled, configure related settings if c.PermessageDeflate.Enabled { - // 如果服务器上下文接管启用,则配置服务器窗口大小池 - // If server context takeover is enabled, configure server window size pool if c.PermessageDeflate.ServerContextTakeover { - // 计算服务器窗口大小 - // Calculate server window size windowSize := internal.BinaryPow(c.PermessageDeflate.ServerMaxWindowBits) - - // 创建服务器窗口大小池 - // Create server window size pool c.config.cswPool = internal.NewPool[[]byte](func() []byte { return make([]byte, 0, windowSize) }) } - - // 如果客户端上下文接管启用,则配置客户端窗口大小池 - // If client context takeover is enabled, configure client window size pool if c.PermessageDeflate.ClientContextTakeover { - // 计算客户端窗口大小 - // Calculate client window size windowSize := internal.BinaryPow(c.PermessageDeflate.ClientMaxWindowBits) - - // 创建客户端窗口大小池 - // Create client window size pool c.config.dswPool = internal.NewPool[[]byte](func() []byte { return make([]byte, 0, windowSize) }) } } - // 返回配置后的客户端选项 - // Return the configured client options return c } -// 获取通用配置 -// Get common configuration +// 获取服务器配置 +// Get server configuration func (c *ServerOption) getConfig() *Config { return c.config } -// ClientOption 结构体定义,用于配置 WebSocket 客户端的选项 -// ClientOption struct definition, used to configure WebSocket client options +// ClientOption 客户端配置 +// client configurations type ClientOption struct { // 写缓冲区的大小, v1.4.5版本此参数被废弃 // Deprecated: Size of the write buffer, v1.4.5 version of this parameter is deprecated @@ -535,162 +415,76 @@ type ClientOption struct { NewSession func() SessionStorage } -// 初始化客户端选项 +// 初始化客户端配置 // Initialize client options func initClientOption(c *ClientOption) *ClientOption { - // 如果 c 为 nil,则创建一个新的 ClientOption 实例 - // If c is nil, create a new ClientOption instance if c == nil { c = new(ClientOption) } - - // 如果 ReadMaxPayloadSize 小于等于 0,则设置为默认值 - // If ReadMaxPayloadSize is less than or equal to 0, set it to the default value if c.ReadMaxPayloadSize <= 0 { c.ReadMaxPayloadSize = defaultReadMaxPayloadSize } - - // 如果 ParallelGolimit 小于等于 0,则设置为默认值 - // If ParallelGolimit is less than or equal to 0, set it to the default value if c.ParallelGolimit <= 0 { c.ParallelGolimit = defaultParallelGolimit } - - // 如果 ReadBufferSize 小于等于 0,则设置为默认值 - // If ReadBufferSize is less than or equal to 0, set it to the default value if c.ReadBufferSize <= 0 { c.ReadBufferSize = defaultReadBufferSize } - - // 如果 WriteMaxPayloadSize 小于等于 0,则设置为默认值 - // If WriteMaxPayloadSize is less than or equal to 0, set it to the default value if c.WriteMaxPayloadSize <= 0 { c.WriteMaxPayloadSize = defaultWriteMaxPayloadSize } - - // 如果 WriteBufferSize 小于等于 0,则设置为默认值 - // If WriteBufferSize is less than or equal to 0, set it to the default value if c.WriteBufferSize <= 0 { c.WriteBufferSize = defaultWriteBufferSize } - - // 如果 HandshakeTimeout 小于等于 0,则设置为默认值 - // If HandshakeTimeout is less than or equal to 0, set it to the default value if c.HandshakeTimeout <= 0 { c.HandshakeTimeout = defaultHandshakeTimeout } - - // 如果 RequestHeader 为 nil,则初始化为一个新的 http.Header - // If RequestHeader is nil, initialize it as a new http.Header if c.RequestHeader == nil { c.RequestHeader = http.Header{} } - - // 如果 NewDialer 函数为 nil,则设置为默认函数 - // If the NewDialer function is nil, set it to the default function if c.NewDialer == nil { c.NewDialer = func() (Dialer, error) { return &net.Dialer{Timeout: defaultDialTimeout}, nil } } - - // 如果 NewSession 函数为 nil,则设置为默认函数 - // If the NewSession function is nil, set it to the default function if c.NewSession == nil { c.NewSession = func() SessionStorage { return newSmap() } } - - // 如果 Logger 为 nil,则设置为默认日志记录器 - // If Logger is nil, set it to the default logger if c.Logger == nil { c.Logger = defaultLogger } - - // 如果 Recovery 函数为 nil,则设置为默认函数 - // If the Recovery function is nil, set it to the default function if c.Recovery == nil { c.Recovery = func(logger Logger) {} } - - // 如果启用了 PermessageDeflate,则进行相关配置 - // If PermessageDeflate is enabled, configure related settings if c.PermessageDeflate.Enabled { - // 如果 ServerMaxWindowBits 不在 8 到 15 之间,则设置为默认值 - // If ServerMaxWindowBits is not between 8 and 15, set it to the default value if c.PermessageDeflate.ServerMaxWindowBits < 8 || c.PermessageDeflate.ServerMaxWindowBits > 15 { c.PermessageDeflate.ServerMaxWindowBits = 15 } - - // 如果 ClientMaxWindowBits 不在 8 到 15 之间,则设置为默认值 - // If ClientMaxWindowBits is not between 8 and 15, set it to the default value if c.PermessageDeflate.ClientMaxWindowBits < 8 || c.PermessageDeflate.ClientMaxWindowBits > 15 { c.PermessageDeflate.ClientMaxWindowBits = 15 } - - // 如果 Threshold 小于等于 0,则设置为默认值 - // If Threshold is less than or equal to 0, set it to the default value if c.PermessageDeflate.Threshold <= 0 { c.PermessageDeflate.Threshold = defaultCompressThreshold } - - // 如果 Level 等于 0,则设置为默认值 - // If Level is equal to 0, set it to the default value if c.PermessageDeflate.Level == 0 { c.PermessageDeflate.Level = defaultCompressLevel } - - // 设置 PoolSize 为 1 - // Set PoolSize to 1 c.PermessageDeflate.PoolSize = 1 } - - // 返回配置后的 ClientOption 实例 - // Return the configured ClientOption instance return c } -// getConfig 方法将 ClientOption 的配置转换为 Config 并返回 -// The getConfig method converts the ClientOption configuration to Config and returns it +// 将 ClientOption 的配置转换为 Config 并返回 +// converts the ClientOption configuration to Config and returns it func (c *ClientOption) getConfig() *Config { - // 创建一个新的 Config 实例,并将 ClientOption 的各项配置赋值给它 - // Create a new Config instance and assign the ClientOption configurations to it config := &Config{ - // 并行处理是否启用 - // Whether parallel processing is enabled - ParallelEnabled: c.ParallelEnabled, - - // 并行处理的协程限制 - // The goroutine limit for parallel processing - ParallelGolimit: c.ParallelGolimit, - - // 读取的最大有效负载大小 - // The maximum payload size for reading - ReadMaxPayloadSize: c.ReadMaxPayloadSize, - - // 读取缓冲区大小 - // The buffer size for reading - ReadBufferSize: c.ReadBufferSize, - - // 写入的最大有效负载大小 - // The maximum payload size for writing + ParallelEnabled: c.ParallelEnabled, + ParallelGolimit: c.ParallelGolimit, + ReadMaxPayloadSize: c.ReadMaxPayloadSize, + ReadBufferSize: c.ReadBufferSize, WriteMaxPayloadSize: c.WriteMaxPayloadSize, - - // 写入缓冲区大小 - // The buffer size for writing - WriteBufferSize: c.WriteBufferSize, - - // 是否启用 UTF-8 检查 - // Whether UTF-8 checking is enabled - CheckUtf8Enabled: c.CheckUtf8Enabled, - - // 恢复函数 - // The recovery function - Recovery: c.Recovery, - - // 日志记录器 - // The logger - Logger: c.Logger, + WriteBufferSize: c.WriteBufferSize, + CheckUtf8Enabled: c.CheckUtf8Enabled, + Recovery: c.Recovery, + Logger: c.Logger, } - - // 返回配置后的 Config 实例 - // Return the configured Config instance return config } diff --git a/reader.go b/reader.go index c5735c82..d53a82a3 100644 --- a/reader.go +++ b/reader.go @@ -8,24 +8,19 @@ import ( "github.com/lxzan/gws/internal" ) -// checkMask 检查掩码设置是否符合 RFC6455 协议。 -// checkMask checks if the mask setting complies with the RFC6455 protocol. +// 检查掩码设置是否符合 RFC6455 协议。 +// Checks if the mask setting complies with the RFC6455 protocol. func (c *Conn) checkMask(enabled bool) error { // RFC6455: 所有从客户端发送到服务器的帧都必须设置掩码位为 1。 // RFC6455: All frames sent from client to server must have the mask bit set to 1. if (c.isServer && !enabled) || (!c.isServer && enabled) { - // 如果服务器端未启用掩码或客户端启用了掩码,则返回协议错误。 - // Return a protocol error if the server has the mask disabled or the client has the mask enabled. return internal.CloseProtocolError } - - // 掩码设置正确,返回 nil 表示没有错误。 - // The mask setting is correct, return nil indicating no error. return nil } -// readControl 读取控制帧 -// readControl reads a control frame +// 读取控制帧 +// Reads a control frame func (c *Conn) readControl() error { // RFC6455: 控制帧本身不能被分片。 // RFC6455: Control frames themselves MUST NOT be fragmented. @@ -36,9 +31,6 @@ func (c *Conn) readControl() error { // RFC6455: 所有控制帧的有效载荷长度必须为 125 字节或更少,并且不能被分片。 // RFC6455: All control frames MUST have a payload length of 125 bytes or fewer and MUST NOT be fragmented. var n = c.fh.GetLengthCode() - - // 控制帧的有效载荷长度不能超过 125 字节 - // The payload length of the control frame cannot exceed 125 bytes if n > internal.ThresholdV1 { return internal.CloseProtocolError } @@ -46,62 +38,34 @@ func (c *Conn) readControl() error { // 不回收小块 buffer,控制帧一般 payload 长度为 0 // Do not recycle small buffers, control frames generally have a payload length of 0 var payload []byte - - // 如果有效载荷长度大于 0,则读取有效载荷数据 - // If the payload length is greater than 0, read the payload data if n > 0 { - // 创建一个长度为 n 的 payload 切片 - // Create a payload slice with length n payload = make([]byte, n) - - // 读取 n 字节的数据到 payload 中 - // Read n bytes of data into the payload if err := internal.ReadN(c.br, payload); err != nil { return err } - - // 如果启用了掩码,则对 payload 进行掩码操作 - // If masking is enabled, apply the mask to the payload if maskEnabled := c.fh.GetMask(); maskEnabled { internal.MaskXOR(payload, c.fh.GetMaskKey()) } } - // 获取操作码 - // Get the opcode var opcode = c.fh.GetOpcode() - - // 根据操作码调用相应的方法 - // Call the corresponding method based on the opcode switch opcode { - case OpcodePing: - // 如果操作码为 Ping,调用 OnPing 方法 - // If the opcode is Ping, call the OnPing method c.handler.OnPing(c, payload) return nil - case OpcodePong: - // 如果操作码为 Pong,调用 OnPong 方法 - // If the opcode is Pong, call the OnPong method c.handler.OnPong(c, payload) return nil - case OpcodeCloseConnection: - // 如果操作码为 CloseConnection,调用 emitClose 方法 - // If the opcode is CloseConnection, call the emitClose method return c.emitClose(bytes.NewBuffer(payload)) - default: - // 如果操作码为其他值,返回一个错误 - // If the opcode is other values, return an error var err = fmt.Errorf("gws: unexpected opcode %d", opcode) return internal.NewError(internal.CloseProtocolError, err) } } -// readMessage 读取消息 -// readMessage reads a message +// 读取消息 +// Reads a message func (c *Conn) readMessage() error { // 解析帧头并获取内容长度 // Parse the frame header and get the content length @@ -109,9 +73,6 @@ func (c *Conn) readMessage() error { if err != nil { return err } - - // 检查内容长度是否超过配置的最大有效载荷大小 - // Check if the content length exceeds the configured maximum payload size if contentLength > c.config.ReadMaxPayloadSize { return internal.CloseMessageTooLarge } @@ -127,204 +88,90 @@ func (c *Conn) readMessage() error { return internal.CloseProtocolError } - // 获取掩码标志 - // Get the mask flag maskEnabled := c.fh.GetMask() - - // 检查掩码设置是否符合协议 - // Check if the mask setting complies with the protocol if err := c.checkMask(maskEnabled); err != nil { return err } - // 读取控制帧 - // Read control frame var opcode = c.fh.GetOpcode() - - // 检查是否启用了压缩并且 RSV1 标志已设置 - // Check if compression is enabled and the RSV1 flag is set var compressed = c.pd.Enabled && c.fh.GetRSV1() - - // 如果操作码不是数据帧,则读取控制帧 - // If the opcode is not a data frame, read the control frame if !opcode.isDataFrame() { return c.readControl() } - // 获取 FIN 标志 - // Get the FIN flag var fin = c.fh.GetFIN() - - // 从内存池中获取一个缓冲区 - // Get a buffer from the memory pool var buf = binaryPool.Get(contentLength + len(flateTail)) - - // 将缓冲区切片到内容长度 - // Slice the buffer to the content length var p = buf.Bytes()[:contentLength] - - // 创建一个 Message 实例,并在函数退出时关闭缓冲区 - // Create a Message instance and close the buffer when the function exits var closer = Message{Data: buf} defer closer.Close() - // 读取指定长度的数据到缓冲区 - // Read the specified length of data into the buffer if err := internal.ReadN(c.br, p); err != nil { return err } - - // 如果启用了掩码,对数据进行掩码操作 - // If masking is enabled, apply the mask to the data if maskEnabled { internal.MaskXOR(p, c.fh.GetMaskKey()) } - // 检查操作码是否不是继续帧并且 continuationFrame 已初始化 - // Check if the opcode is not a continuation frame and the continuationFrame is initialized + if opcode != OpcodeContinuation && c.continuationFrame.initialized { - // 如果是,则返回协议错误 - // If so, return a protocol error return internal.CloseProtocolError } - // 如果是最后一帧并且操作码不是继续帧 - // If it is the final frame and the opcode is not a continuation frame if fin && opcode != OpcodeContinuation { - // 将缓冲区转换为字节切片 - // Convert the buffer to a byte slice *(*[]byte)(unsafe.Pointer(buf)) = p - - // 如果未启用压缩,则将 closer.Data 置为 nil - // If compression is not enabled, set closer.Data to nil if !compressed { closer.Data = nil } - - // 发出消息并返回 - // Emit the message and return return c.emitMessage(&Message{Opcode: opcode, Data: buf, compressed: compressed}) } - // 如果不是最后一帧并且操作码不是继续帧 - // If it is not the final frame and the opcode is not a continuation frame + // 处理分片消息 + // processing segmented messages if !fin && opcode != OpcodeContinuation { - // 初始化 continuationFrame - // Initialize the continuationFrame c.continuationFrame.initialized = true - - // 设置 continuationFrame 的压缩标志 - // Set the compressed flag of the continuationFrame c.continuationFrame.compressed = compressed - - // 设置 continuationFrame 的操作码 - // Set the opcode of the continuationFrame c.continuationFrame.opcode = opcode - - // 初始化 continuationFrame 的缓冲区,容量为 contentLength - // Initialize the buffer of the continuationFrame with a capacity of contentLength c.continuationFrame.buffer = bytes.NewBuffer(make([]byte, 0, contentLength)) } - - // 如果 continuationFrame 未初始化 - // If the continuationFrame is not initialized if !c.continuationFrame.initialized { - // 返回协议错误 - // Return a protocol error return internal.CloseProtocolError } - // 将数据写入 continuationFrame 的缓冲区 - // Write data to the continuationFrame's buffer c.continuationFrame.buffer.Write(p) - - // 如果缓冲区长度超过最大有效载荷大小 - // If the buffer length exceeds the maximum payload size if c.continuationFrame.buffer.Len() > c.config.ReadMaxPayloadSize { - // 返回消息过大错误 - // Return a message too large error return internal.CloseMessageTooLarge } - - // 如果不是最后一帧,返回 nil - // If it is not the final frame, return nil if !fin { return nil } - // 创建一个新的 Message 实例 - // Create a new Message instance - msg := &Message{ - // 设置操作码为 continuationFrame 的操作码 - // Set the opcode to the opcode of the continuationFrame - Opcode: c.continuationFrame.opcode, - - // 设置数据为 continuationFrame 的缓冲区 - // Set the data to the buffer of the continuationFrame - Data: c.continuationFrame.buffer, - - // 设置压缩标志为 continuationFrame 的压缩标志 - // Set the compressed flag to the compressed flag of the continuationFrame - compressed: c.continuationFrame.compressed, - } - - // 重置 continuationFrame - // Reset the continuationFrame + msg := &Message{Opcode: c.continuationFrame.opcode, Data: c.continuationFrame.buffer, compressed: c.continuationFrame.compressed} c.continuationFrame.reset() - - // 发出消息并返回 - // Emit the message and return return c.emitMessage(msg) } -// dispatch 分发消息给消息处理器 -// dispatch dispatches the message to the message handler +// 分发消息和异常恢复 +// Dispatch message & Recovery func (c *Conn) dispatch(msg *Message) error { - // 使用 defer 确保在函数退出时调用 Recovery 方法进行错误恢复 - // Use defer to ensure the Recovery method is called for error recovery when the function exits defer c.config.Recovery(c.config.Logger) - - // 调用消息处理器的 OnMessage 方法处理消息 - // Call the OnMessage method of the message handler to process the message c.handler.OnMessage(c, msg) - - // 返回 nil 表示没有错误 - // Return nil indicating no error return nil } -// emitMessage 处理并发出消息 -// emitMessage processes and emits the message +// 发射消息事件 +// Emit onmessage event func (c *Conn) emitMessage(msg *Message) (err error) { - // 如果消息是压缩的,先解压缩消息数据 - // If the message is compressed, decompress the message data first if msg.compressed { msg.Data, err = c.deflater.Decompress(msg.Data, c.getDpsDict()) if err != nil { - // 如果解压缩失败,返回内部服务器错误 - // If decompression fails, return an internal server error return internal.NewError(internal.CloseInternalServerErr, err) } - - // 将解压缩后的数据写入 dpsWindow - // Write the decompressed data to dpsWindow c.dpsWindow.Write(msg.Bytes()) } - - // 检查文本消息的编码是否有效 - // Check if the text message encoding is valid if !c.isTextValid(msg.Opcode, msg.Bytes()) { - // 如果编码无效,返回不支持的数据错误 - // If the encoding is invalid, return an unsupported data error return internal.NewError(internal.CloseUnsupportedData, ErrTextEncoding) } - - // 如果启用了并行处理,则将消息放入读取队列并发处理 - // If parallel processing is enabled, put the message into the read queue for concurrent processing if c.config.ParallelEnabled { return c.readQueue.Go(msg, c.dispatch) } - - // 否则,直接分发消息 - // Otherwise, directly dispatch the message return c.dispatch(msg) } diff --git a/task.go b/task.go index 677813c3..f4159269 100644 --- a/task.go +++ b/task.go @@ -7,156 +7,91 @@ import ( ) type ( - // workerQueue 代表一个任务队列 - // workerQueue represents a task queue + // workerQueue 任务队列 + // task queue workerQueue struct { - // mu 是一个互斥锁,用于保护对队列的并发访问 - // mu is a mutex to protect concurrent access to the queue + // mu 互斥锁 + // mutex mu sync.Mutex - // q 是一个双端队列,用于存储异步任务 - // q is a double-ended queue to store asynchronous jobs + // q 双端队列,用于存储异步任务 + // double-ended queue to store asynchronous jobs q internal.Deque[asyncJob] - // maxConcurrency 是最大并发数 - // maxConcurrency is the maximum concurrency + // maxConcurrency 最大并发数 + // maximum concurrency maxConcurrency int32 - // curConcurrency 是当前并发数 - // curConcurrency is the current concurrency + // curConcurrency 当前并发数 + // current concurrency curConcurrency int32 } - // asyncJob 代表一个异步任务 - // asyncJob represents an asynchronous job + // asyncJob 异步任务 + // asynchronous job asyncJob func() ) // newWorkerQueue 创建一个任务队列 -// newWorkerQueue creates a task queue +// creates a task queue func newWorkerQueue(maxConcurrency int32) *workerQueue { c := &workerQueue{ - // 初始化互斥锁 - // Initialize the mutex - mu: sync.Mutex{}, - - // 设置最大并发数 - // Set the maximum concurrency + mu: sync.Mutex{}, maxConcurrency: maxConcurrency, - - // 初始化当前并发数为 0 - // Initialize the current concurrency to 0 curConcurrency: 0, } - - // 返回初始化的任务队列 - // Return the initialized task queue return c } // 获取一个任务 // getJob retrieves a job from the worker queue func (c *workerQueue) getJob(newJob asyncJob, delta int32) asyncJob { - // 加锁以确保线程安全 - // Lock to ensure thread safety c.mu.Lock() - // 在函数结束时解锁 - // Unlock at the end of the function defer c.mu.Unlock() - // 如果有新任务,将其添加到队列中 - // If there is a new job, add it to the queue if newJob != nil { c.q.PushBack(newJob) } - - // 更新当前并发数 - // Update the current concurrency count c.curConcurrency += delta - - // 如果当前并发数达到或超过最大并发数,返回 nil - // If the current concurrency count reaches or exceeds the maximum concurrency, return nil if c.curConcurrency >= c.maxConcurrency { return nil } - - // 从队列中取出一个任务 - // Retrieve a job from the queue var job = c.q.PopFront() - - // 如果队列为空,返回 nil - // If the queue is empty, return nil if job == nil { return nil } - - // 增加当前并发数 - // Increment the current concurrency count c.curConcurrency++ - - // 返回取出的任务 - // Return the retrieved job return job } // 循环执行任务 // do continuously executes jobs in the worker queue func (c *workerQueue) do(job asyncJob) { - // 当任务不为空时,循环执行任务 - // Loop to execute jobs as long as the job is not nil for job != nil { - // 执行当前任务 - // Execute the current job job() - // 获取下一个任务并减少当前并发数 - // Get the next job and decrement the current concurrency count job = c.getJob(nil, -1) } } // Push 追加任务, 有资源空闲的话会立即执行 -// Push adds a job to the queue and executes it immediately if resources are available +// adds a job to the queue and executes it immediately if resources are available func (c *workerQueue) Push(job asyncJob) { - // 获取下一个任务,如果有资源空闲的话 - // Get the next job if resources are available if nextJob := c.getJob(job, 0); nextJob != nil { - // 启动一个新的 goroutine 来执行任务 - // Start a new goroutine to execute the job go c.do(nextJob) } } -// 定义一个名为 channel 的类型,底层类型为 struct{} 的通道 -// Define a type named channel, which is a channel of struct{} type channel chan struct{} -// add 方法向通道发送一个空的 struct{},表示增加一个任务 -// The add method sends an empty struct{} to the channel, indicating the addition of a task func (c channel) add() { c <- struct{}{} } -// done 方法从通道接收一个空的 struct{},表示完成一个任务 -// The done method receives an empty struct{} from the channel, indicating the completion of a task func (c channel) done() { <-c } -// Go 方法接收一个消息和一个函数,启动一个新的 goroutine 来执行该函数 -// The Go method receives a message and a function, and starts a new goroutine to execute the function func (c channel) Go(m *Message, f func(*Message) error) error { - // 增加一个任务 - // Add a task c.add() - - // 启动一个新的 goroutine 来执行函数 f - // Start a new goroutine to execute the function f go func() { - // 执行函数 f,并忽略其返回值 - // Execute the function f and ignore its return value _ = f(m) - // 完成一个任务 - // Complete a task c.done() }() - - // 返回 nil 表示成功 - // Return nil to indicate success return nil } diff --git a/types.go b/types.go index d67e36e9..0036299d 100644 --- a/types.go +++ b/types.go @@ -14,52 +14,26 @@ import ( "github.com/lxzan/gws/internal" ) -// 定义帧头的大小常量 -// Define a constant for the frame header size const frameHeaderSize = 14 -// 定义 Opcode 类型,底层类型为 uint8 -// Define the Opcode type, which is an alias for uint8 +// Opcode 操作码 type Opcode uint8 -// 定义各种操作码常量 -// Define constants for various opcodes const ( - // 继续帧操作码 - // Continuation frame opcode - OpcodeContinuation Opcode = 0x0 - - // 文本帧操作码 - // Text frame opcode - OpcodeText Opcode = 0x1 - - // 二进制帧操作码 - // Binary frame opcode - OpcodeBinary Opcode = 0x2 - - // 关闭连接操作码 - // Close connection opcode - OpcodeCloseConnection Opcode = 0x8 - - // Ping 操作码 - // Ping opcode - OpcodePing Opcode = 0x9 - - // Pong 操作码 - // Pong opcode - OpcodePong Opcode = 0xA + OpcodeContinuation Opcode = 0x0 // 继续 + OpcodeText Opcode = 0x1 // 文本 + OpcodeBinary Opcode = 0x2 // 二级制 + OpcodeCloseConnection Opcode = 0x8 // 关闭 + OpcodePing Opcode = 0x9 // 心跳探测 + OpcodePong Opcode = 0xA // 心跳回应 ) -// isDataFrame 方法判断操作码是否为数据帧 -// The isDataFrame method checks if the opcode is a data frame +// isDataFrame 判断操作码是否为数据帧 +// checks if the opcode is a data frame func (c Opcode) isDataFrame() bool { - // 如果操作码小于等于二进制帧操作码,则返回 true - // Return true if the opcode is less than or equal to the binary frame opcode return c <= OpcodeBinary } -// 定义 CloseError 类型,包含关闭代码和原因 -// Define the CloseError type, which includes a close code and a reason type CloseError struct { // 关闭代码,表示关闭连接的原因 // Close code, indicating the reason for closing the connection @@ -70,17 +44,13 @@ type CloseError struct { Reason []byte } -// Error 方法返回关闭错误的描述 -// The Error method returns a description of the close error +// Error 关闭错误的描述 +// returns a description of the close error func (c *CloseError) Error() string { - // 返回格式化的错误信息,包含关闭代码和原因 - // Return a formatted error message that includes the close code and reason return fmt.Sprintf("gws: connection closed, code=%d, reason=%s", c.Code, string(c.Reason)) } var ( - // ErrEmpty 空错误 - // Empty error errEmpty = errors.New("") // ErrUnauthorized 未通过鉴权认证 @@ -138,339 +108,192 @@ type Event interface { OnMessage(socket *Conn, message *Message) } -// BuiltinEventHandler 是一个内置的事件处理器结构体 -// BuiltinEventHandler is a built-in event handler struct type BuiltinEventHandler struct{} -// OnOpen 在连接打开时调用 -// OnOpen is called when the connection is opened func (b BuiltinEventHandler) OnOpen(socket *Conn) {} -// OnClose 在连接关闭时调用 -// OnClose is called when the connection is closed func (b BuiltinEventHandler) OnClose(socket *Conn, err error) {} -// OnPing 在接收到 Ping 帧时调用 -// OnPing is called when a Ping frame is received -func (b BuiltinEventHandler) OnPing(socket *Conn, payload []byte) { - // 发送 Pong 帧作为响应 - // Send a Pong frame in response - _ = socket.WritePong(nil) -} +func (b BuiltinEventHandler) OnPing(socket *Conn, payload []byte) { _ = socket.WritePong(nil) } -// OnPong 在接收到 Pong 帧时调用 -// OnPong is called when a Pong frame is received func (b BuiltinEventHandler) OnPong(socket *Conn, payload []byte) {} -// OnMessage 在接收到消息时调用 -// OnMessage is called when a message is received func (b BuiltinEventHandler) OnMessage(socket *Conn, message *Message) {} -// 定义帧头类型,大小为 frameHeaderSize 的字节数组 -// Define the frameHeader type, which is an array of bytes with size frameHeaderSize type frameHeader [frameHeaderSize]byte -// GetFIN 方法返回 FIN 位的值 -// The GetFIN method returns the value of the FIN bit +// GetFIN 返回 FIN 位的值 +// returns the value of the FIN bit func (c *frameHeader) GetFIN() bool { - // 通过右移 7 位获取第一个字节的最高位 - // Get the highest bit of the first byte by shifting right 7 bits return ((*c)[0] >> 7) == 1 } -// GetRSV1 方法返回 RSV1 位的值 -// The GetRSV1 method returns the value of the RSV1 bit +// GetRSV1 返回 RSV1 位的值 +// returns the value of the RSV1 bit func (c *frameHeader) GetRSV1() bool { - // 通过左移 1 位再右移 7 位获取第一个字节的第二高位 - // Get the second highest bit of the first byte by shifting left 1 bit and then right 7 bits return ((*c)[0] << 1 >> 7) == 1 } -// GetRSV2 方法返回 RSV2 位的值 -// The GetRSV2 method returns the value of the RSV2 bit +// GetRSV2 返回 RSV2 位的值 +// returns the value of the RSV2 bit func (c *frameHeader) GetRSV2() bool { - // 通过左移 2 位再右移 7 位获取第一个字节的第三高位 - // Get the third highest bit of the first byte by shifting left 2 bits and then right 7 bits return ((*c)[0] << 2 >> 7) == 1 } -// GetRSV3 方法返回 RSV3 位的值 -// The GetRSV3 method returns the value of the RSV3 bit +// GetRSV3 返回 RSV3 位的值 +// returns the value of the RSV3 bit func (c *frameHeader) GetRSV3() bool { - // 通过左移 3 位再右移 7 位获取第一个字节的第四高位 - // Get the fourth highest bit of the first byte by shifting left 3 bits and then right 7 bits return ((*c)[0] << 3 >> 7) == 1 } -// GetOpcode 方法返回操作码 -// The GetOpcode method returns the opcode +// GetOpcode 返回操作码 +// returns the opcode func (c *frameHeader) GetOpcode() Opcode { - // 通过左移 4 位再右移 4 位获取第一个字节的低 4 位 - // Get the lowest 4 bits of the first byte by shifting left 4 bits and then right 4 bits return Opcode((*c)[0] << 4 >> 4) } -// GetMask 方法返回 Mask 位的值 -// The GetMask method returns the value of the Mask bit +// GetMask 返回掩码 +// returns the value of the mask bytes func (c *frameHeader) GetMask() bool { - // 通过右移 7 位获取第二个字节的最高位 - // Get the highest bit of the second byte by shifting right 7 bits return ((*c)[1] >> 7) == 1 } -// GetLengthCode 方法返回长度代码 -// The GetLengthCode method returns the length code +// GetLengthCode 返回长度代码 +// returns the length code func (c *frameHeader) GetLengthCode() uint8 { - // 通过左移 1 位再右移 1 位获取第二个字节的低 7 位 - // Get the lowest 7 bits of the second byte by shifting left 1 bit and then right 1 bit return (*c)[1] << 1 >> 1 } -// SetMask 方法设置 Mask 位为 1 -// The SetMask method sets the Mask bit to 1 +// SetMask 设置 Mask 位为 1 +// sets the Mask bit to 1 func (c *frameHeader) SetMask() { - // 将第二个字节的最高位置为 1 - // Set the highest bit of the second byte to 1 (*c)[1] |= uint8(128) } -// SetLength 方法设置帧的长度,并返回偏移量 -// The SetLength method sets the frame length and returns the offset +// SetLength 设置帧的长度,并返回偏移量 +// sets the frame length and returns the offset func (c *frameHeader) SetLength(n uint64) (offset int) { - // 如果长度小于等于 ThresholdV1 - // If the length is less than or equal to ThresholdV1 if n <= internal.ThresholdV1 { - // 将长度直接设置到帧头的第二个字节 - // Set the length directly in the second byte of the frame header (*c)[1] += uint8(n) - - // 返回 0 偏移量 - // Return 0 offset return 0 - } else if n <= internal.ThresholdV2 { - // 如果长度小于等于 ThresholdV2 - // If the length is less than or equal to ThresholdV2 - // 将长度代码设置为 126 - // Set the length code to 126 (*c)[1] += 126 - - // 将长度的值存储在帧头的第 3 到第 4 字节 - // Store the length value in the 3rd to 4th bytes of the frame header binary.BigEndian.PutUint16((*c)[2:4], uint16(n)) - - // 返回 2 偏移量 - // Return 2 offset return 2 - } else { - // 如果长度大于 ThresholdV2 - // If the length is greater than ThresholdV2 - // 将长度代码设置为 127 - // Set the length code to 127 (*c)[1] += 127 - - // 将长度的值存储在帧头的第 3 到第 10 字节 - // Store the length value in the 3rd to 10th bytes of the frame header binary.BigEndian.PutUint64((*c)[2:10], n) - - // 返回 8 偏移量 - // Return 8 offset return 8 } } -// SetMaskKey 方法设置掩码键 -// The SetMaskKey method sets the mask key +// SetMaskKey 设置掩码 +// sets the mask func (c *frameHeader) SetMaskKey(offset int, key [4]byte) { - // 将掩码键复制到帧头的指定偏移量位置 - // Copy the mask key to the specified offset in the frame header copy((*c)[offset:offset+4], key[0:]) } -// GenerateHeader 生成用于写入的帧头 -// GenerateHeader generates a frame header for writing +// GenerateHeader 生成帧头 +// generates a frame header // 可以考虑每个客户端连接带一个随机数发生器 // Consider having a random number generator for each client connection func (c *frameHeader) GenerateHeader(isServer bool, fin bool, compress bool, opcode Opcode, length int) (headerLength int, maskBytes []byte) { - // 初始化帧头长度为 2 - // Initialize the header length to 2 headerLength = 2 - - // 初始化第一个字节为操作码 - // Initialize the first byte with the opcode var b0 = uint8(opcode) - - // 如果是最后一帧,设置 FIN 位 - // If this is the final frame, set the FIN bit if fin { b0 += 128 } - - // 如果需要压缩,设置压缩位 - // If compression is needed, set the compression bit if compress { b0 += 64 } - - // 设置帧头的第一个字节 - // Set the first byte of the frame header (*c)[0] = b0 - - // 设置帧的长度,并增加帧头长度 - // Set the frame length and increase the header length headerLength += c.SetLength(uint64(length)) - // 如果不是服务器,设置掩码位并生成掩码键 - // If not a server, set the mask bit and generate a mask key if !isServer { - // 设置掩码位 - // Set the mask bit (*c)[1] |= 128 - - // 生成一个随机掩码键 - // Generate a random mask key maskNum := internal.AlphabetNumeric.Uint32() - - // 将掩码键写入帧头 - // Write the mask key into the frame header binary.LittleEndian.PutUint32((*c)[headerLength:headerLength+4], maskNum) - - // 设置掩码字节 - // Set the mask bytes maskBytes = (*c)[headerLength : headerLength+4] - - // 增加帧头长度 - // Increase the header length headerLength += 4 } - - // 无效代码 - // Invalid code return } // Parse 解析完整协议头, 最多14字节, 返回payload长度 -// Parse parses the complete protocol header, up to 14 bytes, and returns the payload length +// parses the complete protocol header, up to 14 bytes, and returns the payload length func (c *frameHeader) Parse(reader io.Reader) (int, error) { - // 读取前两个字节到帧头 - // Read the first two bytes into the frame header if err := internal.ReadN(reader, (*c)[0:2]); err != nil { return 0, err } - // 初始化 payload 长度为 0 - // Initialize payload length to 0 var payloadLength = 0 - // 获取长度代码 - // Get the length code var lengthCode = c.GetLengthCode() - - // 根据长度代码解析 payload 长度 - // Parse the payload length based on the length code switch lengthCode { case 126: - // 如果长度代码是 126,读取接下来的两个字节 - // If the length code is 126, read the next two bytes if err := internal.ReadN(reader, (*c)[2:4]); err != nil { return 0, err } - - // 将这两个字节转换为 payload 长度 - // Convert these two bytes to the payload length payloadLength = int(binary.BigEndian.Uint16((*c)[2:4])) case 127: - // 如果长度代码是 127,读取接下来的八个字节 - // If the length code is 127, read the next eight bytes if err := internal.ReadN(reader, (*c)[2:10]); err != nil { return 0, err } - - // 将这八个字节转换为 payload 长度 - // Convert these eight bytes to the payload length payloadLength = int(binary.BigEndian.Uint64((*c)[2:10])) - default: - // 否则,长度代码就是 payload 长度 - // Otherwise, the length code is the payload length payloadLength = int(lengthCode) } - // 检查是否有掩码 - // Check if there is a mask var maskOn = c.GetMask() if maskOn { - // 如果有掩码,读取接下来的四个字节 - // If there is a mask, read the next four bytes if err := internal.ReadN(reader, (*c)[10:14]); err != nil { return 0, err } } - // 返回 payload 长度 - // Return the payload length return payloadLength, nil } -// GetMaskKey 方法返回掩码键 -// The GetMaskKey method returns the mask key -// parser把maskKey放到了末尾 -// The parser places the mask key at the end +// GetMaskKey 返回掩码 +// returns the mask func (c *frameHeader) GetMaskKey() []byte { - // 返回帧头中第 10 到第 14 字节作为掩码键 - // Return the 10th to 14th bytes of the frame header as the mask key return (*c)[10:14] } -// Message 结构体表示一个消息 -// The Message struct represents a message type Message struct { // 是否压缩 - // Indicates if the message is compressed + // if the message is compressed compressed bool // 操作码 - // The opcode of the message + // opcode of the message Opcode Opcode // 消息内容 - // The content of the message + // content of the message Data *bytes.Buffer } // Read 从消息中读取数据到给定的字节切片 p 中 -// Read reads data from the message into the given byte slice p +// reads data from the message into the given byte slice p func (c *Message) Read(p []byte) (n int, err error) { - // 从消息的数据缓冲区中读取数据 - // Read data from the message's data buffer return c.Data.Read(p) } // Bytes 返回消息的数据缓冲区的字节切片 -// Bytes returns the byte slice of the message's data buffer +// returns the byte slice of the message's data buffer func (c *Message) Bytes() []byte { return c.Data.Bytes() } -// Close 回收缓冲区 -// Close recycles the buffer +// Close 关闭消息, 回收资源 +// close message, recycling resources func (c *Message) Close() error { - // 将数据缓冲区放回缓冲池 - // Put the data buffer back into the buffer pool binaryPool.Put(c.Data) - - // 将数据缓冲区设置为 nil - // Set the data buffer to nil c.Data = nil - - // 返回 nil 表示没有错误 - // Return nil to indicate no error return nil } -// continuationFrame 结构体表示一个延续帧 -// The continuationFrame struct represents a continuation frame type continuationFrame struct { // 是否已初始化 // Indicates if the frame is initialized @@ -489,68 +312,41 @@ type continuationFrame struct { buffer *bytes.Buffer } -// reset 方法重置延续帧的状态 -// The reset method resets the state of the continuation frame +// reset 重置延续帧的状态 +// resets the state of the continuation frame func (c *continuationFrame) reset() { - // 将 initialized 设置为 false - // Set initialized to false c.initialized = false - - // 将 compressed 设置为 false - // Set compressed to false c.compressed = false - - // 将 opcode 设置为 0 - // Set opcode to 0 c.opcode = 0 - - // 将 buffer 设置为 nil - // Set buffer to nil c.buffer = nil } -// Logger 接口定义了一个日志记录器 -// The Logger interface defines a logger +// Logger 日志接口 +// logger interface type Logger interface { - // Error 方法记录错误信息 - // The Error method logs error messages + // Error 打印错误日志 + // Printing the error log Error(v ...any) } -// stdLogger 结构体实现了 Logger 接口 -// The stdLogger struct implements the Logger interface +// stdLogger 标准日志库 +// standard Log Library type stdLogger struct{} -// Error 方法实现了 Logger 接口的 Error 方法,使用标准日志记录错误信息 -// The Error method implements the Logger interface's Error method, using the standard log to record error messages +// Error 打印错误日志 +// Printing the error log func (c *stdLogger) Error(v ...any) { log.Println(v...) } -// Recovery 函数用于从 panic 中恢复,并记录错误信息 -// The Recovery function is used to recover from a panic and log error messages +// Recovery 异常恢复,并记录错误信息 +// Exception recovery with logging of error messages func Recovery(logger Logger) { - // 如果有 panic 发生 - // If a panic occurs if e := recover(); e != nil { - // 定义缓冲区大小为 64KB - // Define the buffer size as 64KB const size = 64 << 10 - - // 创建缓冲区 - // Create a buffer buf := make([]byte, size) - - // 获取当前 goroutine 的堆栈信息 - // Get the stack trace of the current goroutine buf = buf[:runtime.Stack(buf, false)] - - // 将缓冲区转换为字符串 - // Convert the buffer to a string msg := *(*string)(unsafe.Pointer(&buf)) - - // 记录错误信息,包括 panic 的值和堆栈信息 - // Log the error message, including the panic value and stack trace logger.Error("fatal error:", e, msg) } } diff --git a/upgrader.go b/upgrader.go index 6eeb7bbb..df515f0c 100644 --- a/upgrader.go +++ b/upgrader.go @@ -14,8 +14,6 @@ import ( "github.com/lxzan/gws/internal" ) -// responseWriter 结构体定义 -// responseWriter struct definition type responseWriter struct { // 错误信息 // Error information @@ -30,529 +28,236 @@ type responseWriter struct { subprotocol string } -// Init 方法初始化 responseWriter 结构体 -// Init method initializes the responseWriter struct +// Init 初始化 +// initializes the responseWriter struct func (c *responseWriter) Init() *responseWriter { - // 从 binaryPool 获取一个大小为 512 的缓冲区 - // Get a buffer of size 512 from binaryPool c.b = binaryPool.Get(512) - - // 写入 HTTP 101 切换协议的响应头 - // Write the HTTP 101 Switching Protocols response header c.b.WriteString("HTTP/1.1 101 Switching Protocols\r\n") - - // 写入 Upgrade: websocket 头 - // Write the Upgrade: websocket header c.b.WriteString("Upgrade: websocket\r\n") - - // 写入 Connection: Upgrade 头 - // Write the Connection: Upgrade header c.b.WriteString("Connection: Upgrade\r\n") - - // 返回初始化后的 responseWriter 结构体 - // Return the initialized responseWriter struct return c } -// Close 关闭 responseWriter 并将缓冲区放回池中 -// Close closes the responseWriter and puts the buffer back into the pool +// Close 回收资源 +// recycling resources func (c *responseWriter) Close() { - // 将缓冲区放回池中 - // Put the buffer back into the pool binaryPool.Put(c.b) - - // 将缓冲区指针置为 nil - // Set the buffer pointer to nil c.b = nil } -// WithHeader 向缓冲区中添加一个 HTTP 头部 -// WithHeader adds an HTTP header to the buffer +// WithHeader 添加 HTTP Header +// adds an http header func (c *responseWriter) WithHeader(k, v string) { - // 写入头部键 - // Write the header key c.b.WriteString(k) - - // 写入冒号和空格 - // Write the colon and space c.b.WriteString(": ") - - // 写入头部值 - // Write the header value c.b.WriteString(v) - - // 写入回车换行符 - // Write the carriage return and newline characters c.b.WriteString("\r\n") } -// WithExtraHeader 向缓冲区中添加多个 HTTP 头部 -// WithExtraHeader adds multiple HTTP headers to the buffer +// WithExtraHeader 添加额外的 HTTP Header +// adds extra http header func (c *responseWriter) WithExtraHeader(h http.Header) { - // 遍历所有头部 - // Iterate over all headers for k, _ := range h { - // 添加每个头部键值对 - // Add each header key-value pair c.WithHeader(k, h.Get(k)) } } // WithSubProtocol 根据请求头和预期的子协议列表设置子协议 -// WithSubProtocol sets the subprotocol based on the request header and the expected subprotocols list +// sets the subprotocol based on the request header and the expected subprotocols list func (c *responseWriter) WithSubProtocol(requestHeader http.Header, expectedSubProtocols []string) { - // 如果预期的子协议列表不为空 - // If the expected subprotocols list is not empty if len(expectedSubProtocols) > 0 { - // 获取请求头中与预期子协议列表的交集元素 - // Get the intersection element from the request header and the expected subprotocols list c.subprotocol = internal.GetIntersectionElem(expectedSubProtocols, internal.Split(requestHeader.Get(internal.SecWebSocketProtocol.Key), ",")) - - // 如果没有匹配的子协议 - // If there is no matching subprotocol if c.subprotocol == "" { - // 设置错误为子协议协商失败 - // Set the error to subprotocol negotiation failure c.err = ErrSubprotocolNegotiation return } - - // 添加子协议头部 - // Add the subprotocol header c.WithHeader(internal.SecWebSocketProtocol.Key, c.subprotocol) } } // Write 将缓冲区内容写入连接,并设置超时 -// Write writes the buffer content to the connection and sets the timeout +// writes the buffer content to the connection and sets the timeout func (c *responseWriter) Write(conn net.Conn, timeout time.Duration) error { - // 如果存在错误 - // If there is an error if c.err != nil { return c.err } - - // 在缓冲区末尾添加回车换行符 - // Add carriage return and newline characters at the end of the buffer c.b.WriteString("\r\n") - - // 设置连接的写入超时 - // Set the write timeout for the connection if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { return err } - - // 将缓冲区内容写入连接 - // Write the buffer content to the connection if _, err := c.b.WriteTo(conn); err != nil { return err } - - // 重置连接的超时设置 - // Reset the timeout setting for the connection return conn.SetDeadline(time.Time{}) } -// Upgrader 结构体定义,用于处理 WebSocket 升级 -// Upgrader struct definition, used for handling WebSocket upgrades type Upgrader struct { - // 服务器选项 - // Server options - option *ServerOption - - // deflater 池 - // Deflater pool + option *ServerOption deflaterPool *deflaterPool - - // 事件处理器 - // Event handler eventHandler Event } // NewUpgrader 创建一个新的 Upgrader 实例 -// NewUpgrader creates a new instance of Upgrader +// creates a new instance of Upgrader func NewUpgrader(eventHandler Event, option *ServerOption) *Upgrader { - // 初始化 Upgrader 实例 - // Initialize the Upgrader instance u := &Upgrader{ - // 初始化服务器选项 - // Initialize server options - option: initServerOption(option), - - // 设置事件处理器 - // Set the event handler + option: initServerOption(option), eventHandler: eventHandler, - - // 创建新的 deflater 池 - // Create a new deflater pool deflaterPool: new(deflaterPool), } - - // 如果启用了 PermessageDeflate - // If PermessageDeflate is enabled if u.option.PermessageDeflate.Enabled { - // 初始化 deflater 池 - // Initialize the deflater pool u.deflaterPool.initialize(u.option.PermessageDeflate, option.ReadMaxPayloadSize) } - - // 返回 Upgrader 实例 - // Return the Upgrader instance return u } // hijack 劫持 HTTP 连接并返回底层的网络连接和缓冲读取器 // hijack hijacks the HTTP connection and returns the underlying network connection and buffered reader func (c *Upgrader) hijack(w http.ResponseWriter) (net.Conn, *bufio.Reader, error) { - // 尝试将响应写入器转换为 Hijacker 接口 - // Attempt to cast the response writer to the Hijacker interface hj, ok := w.(http.Hijacker) - - // 如果转换失败,返回错误 - // If the cast fails, return an error if !ok { return nil, nil, internal.CloseInternalServerErr } - - // 劫持连接,获取底层网络连接 - // Hijack the connection to get the underlying network connection netConn, _, err := hj.Hijack() - - // 如果劫持失败,返回错误 - // If hijacking fails, return an error if err != nil { return nil, nil, err } - - // 从连接池中获取一个缓冲读取器 - // Get a buffered reader from the connection pool br := c.option.config.brPool.Get() - - // 重置缓冲读取器以使用新的网络连接 - // Reset the buffered reader to use the new network connection br.Reset(netConn) - - // 返回网络连接和缓冲读取器 - // Return the network connection and buffered reader return netConn, br, nil } // getPermessageDeflate 根据客户端和服务器的扩展协商结果获取 PermessageDeflate 配置 -// getPermessageDeflate gets the PermessageDeflate configuration based on the negotiation results between the client and server extensions +// gets the PermessageDeflate configuration based on the negotiation results between the client and server extensions func (c *Upgrader) getPermessageDeflate(extensions string) PermessageDeflate { - // 从客户端扩展字符串中解析出客户端的 PermessageDeflate 配置 - // Parse the client's PermessageDeflate configuration from the extensions string clientPD := permessageNegotiation(extensions) - - // 获取服务器的 PermessageDeflate 配置 - // Get the server's PermessageDeflate configuration serverPD := c.option.PermessageDeflate - - // 初始化 PermessageDeflate 配置 - // Initialize the PermessageDeflate configuration pd := PermessageDeflate{ - // 启用状态取决于服务器是否启用并且扩展字符串中包含 PermessageDeflate - // Enabled status depends on whether the server is enabled and the extensions string contains PermessageDeflate - Enabled: serverPD.Enabled && strings.Contains(extensions, internal.PermessageDeflate), - - // 设置压缩阈值 - // Set the compression threshold - Threshold: serverPD.Threshold, - - // 设置压缩级别 - // Set the compression level - Level: serverPD.Level, - - // 设置池大小 - // Set the pool size - PoolSize: serverPD.PoolSize, - - // 设置服务器上下文接管 - // Set the server context takeover + Enabled: serverPD.Enabled && strings.Contains(extensions, internal.PermessageDeflate), + Threshold: serverPD.Threshold, + Level: serverPD.Level, + PoolSize: serverPD.PoolSize, ServerContextTakeover: clientPD.ServerContextTakeover && serverPD.ServerContextTakeover, - - // 设置客户端上下文接管 - // Set the client context takeover ClientContextTakeover: clientPD.ClientContextTakeover && serverPD.ClientContextTakeover, - - // 设置服务器最大窗口位 - // Set the server max window bits - ServerMaxWindowBits: serverPD.ServerMaxWindowBits, - - // 设置客户端最大窗口位 - // Set the client max window bits - ClientMaxWindowBits: serverPD.ClientMaxWindowBits, + ServerMaxWindowBits: serverPD.ServerMaxWindowBits, + ClientMaxWindowBits: serverPD.ClientMaxWindowBits, } - - // 设置压缩阈值 - // Set the compression threshold pd.setThreshold(true) - - // 返回 PermessageDeflate 配置 - // Return the PermessageDeflate configuration return pd } // Upgrade 升级 HTTP 连接到 WebSocket 连接 -// Upgrade upgrades the HTTP connection to a WebSocket connection +// upgrades the HTTP connection to a WebSocket connection func (c *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request) (*Conn, error) { - // 劫持 HTTP 连接,获取底层网络连接和缓冲读取器 - // Hijack the HTTP connection to get the underlying network connection and buffered reader netConn, br, err := c.hijack(w) - - // 如果劫持失败,返回错误 - // If hijacking fails, return an error if err != nil { return nil, err } - - // 从网络连接升级到 WebSocket 连接 - // Upgrade from the network connection to a WebSocket connection return c.UpgradeFromConn(netConn, br, r) } // UpgradeFromConn 从现有的网络连接升级到 WebSocket 连接 -// UpgradeFromConn upgrades from an existing network connection to a WebSocket connection +// upgrades from an existing network connection to a WebSocket connection func (c *Upgrader) UpgradeFromConn(conn net.Conn, br *bufio.Reader, r *http.Request) (*Conn, error) { - // 执行连接升级操作 - // Perform the connection upgrade operation socket, err := c.doUpgradeFromConn(conn, br, r) - - // 如果升级失败,写入错误信息并关闭连接 - // If the upgrade fails, write the error message and close the connection if err != nil { _ = c.writeErr(conn, err) _ = conn.Close() } - - // 返回 WebSocket 连接和错误信息 - // Return the WebSocket connection and error information return socket, err } // writeErr 向客户端写入 HTTP 错误响应 -// writeErr writes an HTTP error response to the client +// writes an HTTP error response to the client func (c *Upgrader) writeErr(conn net.Conn, err error) error { - // 获取错误信息字符串 - // Get the error message string var str = err.Error() - - // 从缓冲池中获取一个缓冲区 - // Get a buffer from the buffer pool var buf = binaryPool.Get(256) - - // 写入 HTTP 状态行 - // Write the HTTP status line buf.WriteString("HTTP/1.1 400 Bad Request\r\n") - - // 写入当前日期 - // Write the current date buf.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n") - - // 写入内容长度 - // Write the content length buf.WriteString("Content-Length: " + strconv.Itoa(len(str)) + "\r\n") - - // 写入内容类型 - // Write the content type buf.WriteString("Content-Type: text/plain; charset=utf-8\r\n") - - // 写入空行,表示头部结束 - // Write an empty line to indicate the end of the headers buf.WriteString("\r\n") - - // 写入错误信息 - // Write the error message buf.WriteString(str) - - // 将缓冲区内容写入连接 - // Write the buffer content to the connection _, result := buf.WriteTo(conn) - - // 将缓冲区放回缓冲池 - // Put the buffer back into the buffer pool binaryPool.Put(buf) - - // 返回写入结果 - // Return the write result return result } // doUpgradeFromConn 从现有的网络连接升级到 WebSocket 连接 -// doUpgradeFromConn upgrades from an existing network connection to a WebSocket connection +// upgrades from an existing network connection to a WebSocket connection func (c *Upgrader) doUpgradeFromConn(netConn net.Conn, br *bufio.Reader, r *http.Request) (*Conn, error) { - // 创建一个新的会话 - // Create a new session - var session = c.option.NewSession() - // 授权请求,如果授权失败,返回未授权错误 // Authorize the request, if authorization fails, return an unauthorized error + var session = c.option.NewSession() if !c.option.Authorize(r, session) { return nil, ErrUnauthorized } - // 检查请求方法是否为 GET,如果不是,返回握手错误 - // Check if the request method is GET, if not, return a handshake error + // 检查请求头 + // check request headers if r.Method != http.MethodGet { return nil, ErrHandshake } - - // 检查 WebSocket 版本是否支持,如果不支持,返回错误 - // Check if the WebSocket version is supported, if not, return an error if !strings.EqualFold(r.Header.Get(internal.SecWebSocketVersion.Key), internal.SecWebSocketVersion.Val) { return nil, errors.New("gws: websocket version not supported") } - - // 检查 Connection 头是否包含正确的值,如果不包含,返回握手错误 - // Check if the Connection header contains the correct value, if not, return a handshake error if !internal.HttpHeaderContains(r.Header.Get(internal.Connection.Key), internal.Connection.Val) { return nil, ErrHandshake } - - // 检查 Upgrade 头是否包含正确的值,如果不包含,返回握手错误 - // Check if the Upgrade header contains the correct value, if not, return a handshake error if !strings.EqualFold(r.Header.Get(internal.Upgrade.Key), internal.Upgrade.Val) { return nil, ErrHandshake } - // 初始化响应写入器 - // Initialize the response writer var rw = new(responseWriter).Init() defer rw.Close() - // 获取扩展头 - // Get the extensions header var extensions = r.Header.Get(internal.SecWebSocketExtensions.Key) - - // 获取 PermessageDeflate 配置 - // Get the PermessageDeflate configuration var pd = c.getPermessageDeflate(extensions) - - // 如果启用了 PermessageDeflate,添加相应的响应头 - // If PermessageDeflate is enabled, add the corresponding response header if pd.Enabled { rw.WithHeader(internal.SecWebSocketExtensions.Key, pd.genResponseHeader()) } - // 获取 WebSocket 密钥 - // Get the WebSocket key var websocketKey = r.Header.Get(internal.SecWebSocketKey.Key) - - // 如果 WebSocket 密钥为空,返回握手错误 - // If the WebSocket key is empty, return a handshake error if websocketKey == "" { return nil, ErrHandshake } - - // 添加 Sec-WebSocket-Accept 头 - // Add the Sec-WebSocket-Accept header rw.WithHeader(internal.SecWebSocketAccept.Key, internal.ComputeAcceptKey(websocketKey)) - - // 添加子协议头 - // Add the subprotocol header rw.WithSubProtocol(r.Header, c.option.SubProtocols) - - // 添加额外的响应头 - // Add extra response headers rw.WithExtraHeader(c.option.ResponseHeader) - - // 写入响应,如果失败,返回错误 - // Write the response, if it fails, return an error if err := rw.Write(netConn, c.option.HandshakeTimeout); err != nil { return nil, err } - // 获取配置选项 - // Get configuration options config := c.option.getConfig() - - // 创建 WebSocket 连接实例 - // Create a WebSocket connection instance socket := &Conn{ - // 会话 - // Session - ss: session, - - // 是否为服务器端 - // Is server side - isServer: true, - - // 子协议 - // Subprotocol - subprotocol: rw.subprotocol, - - // PermessageDeflate 配置 - // PermessageDeflate configuration - pd: pd, - - // 网络连接 - // Network connection - conn: netConn, - - // 配置 - // Configuration - config: config, - - // 缓冲读取器 - // Buffered reader - br: br, - - // 连续帧 - // Continuation frame + ss: session, + isServer: true, + subprotocol: rw.subprotocol, + pd: pd, + conn: netConn, + config: config, + br: br, continuationFrame: continuationFrame{}, - - // 帧头 - // Frame header - fh: frameHeader{}, - - // 事件处理器 - // Event handler - handler: c.eventHandler, - - // 关闭状态 - // Closed status - closed: 0, - - // 写队列,最大并发数为 1 - // Write queue,maximum concurrency is 1 - writeQueue: workerQueue{maxConcurrency: 1}, - - // 读队列,channel 缓冲区大小为 option 的 ParallelGolimit - // Read queue,channel buffer size is option's ParallelGolimit - readQueue: make(channel, c.option.ParallelGolimit), + fh: frameHeader{}, + handler: c.eventHandler, + closed: 0, + writeQueue: workerQueue{maxConcurrency: 1}, + readQueue: make(channel, c.option.ParallelGolimit), } - - // 如果启用了 PermessageDeflate - // If PermessageDeflate is enabled if pd.Enabled { - // 选择 deflater - // Select the deflater socket.deflater = c.deflaterPool.Select() - - // 如果服务器上下文接管启用 - // If server context takeover is enabled if c.option.PermessageDeflate.ServerContextTakeover { - // 初始化服务器上下文窗口 - // Initialize the server context window socket.cpsWindow.initialize(config.cswPool, c.option.PermessageDeflate.ServerMaxWindowBits) } - - // 如果客户端上下文接管启用 - // If client context takeover is enabled if c.option.PermessageDeflate.ClientContextTakeover { - // 初始化客户端上下文窗口 - // Initialize the client context window socket.dpsWindow.initialize(config.dswPool, c.option.PermessageDeflate.ClientMaxWindowBits) } } - - // 返回 WebSocket 连接 - // Return the WebSocket connection return socket, nil } -// Server 结构体定义,用于处理 WebSocket 服务器的相关操作 -// Server struct definition, used for handling WebSocket server-related operations +// Server WebSocket服务器 +// websocket server type Server struct { // 升级器,用于将 HTTP 连接升级到 WebSocket 连接 // Upgrader, used to upgrade HTTP connections to WebSocket connections @@ -572,158 +277,78 @@ type Server struct { } // NewServer 创建一个新的 WebSocket 服务器实例 -// NewServer creates a new WebSocket server instance +// creates a new WebSocket server instance func NewServer(eventHandler Event, option *ServerOption) *Server { - // 初始化服务器实例,并设置升级器 - // Initialize the server instance and set the upgrader var c = &Server{upgrader: NewUpgrader(eventHandler, option)} - - // 设置服务器选项配置 - // Set the server option configuration c.option = c.upgrader.option - - // 设置默认的错误处理回调函数 - // Set the default error handling callback function - c.OnError = func(conn net.Conn, err error) { - // 记录错误日志 - // Log the error - c.option.Logger.Error("gws: " + err.Error()) - } - - // 设置默认的请求处理回调函数 - // Set the default request handling callback function + c.OnError = func(conn net.Conn, err error) { c.option.Logger.Error("gws: " + err.Error()) } c.OnRequest = func(conn net.Conn, br *bufio.Reader, r *http.Request) { - // 尝试将 HTTP 连接升级到 WebSocket 连接 - // Attempt to upgrade the HTTP connection to a WebSocket connection socket, err := c.GetUpgrader().UpgradeFromConn(conn, br, r) - - // 如果升级失败,调用错误处理回调函数 - // If the upgrade fails, call the error handling callback function if err != nil { c.OnError(conn, err) } else { - // 否则,启动 WebSocket 连接的读取循环 - // Otherwise, start the read loop for the WebSocket connection socket.ReadLoop() } } - - // 返回服务器实例 - // Return the server instance return c } // GetUpgrader 获取服务器的升级器实例 -// GetUpgrader retrieves the upgrader instance of the server +// retrieves the upgrader instance of the server func (c *Server) GetUpgrader() *Upgrader { return c.upgrader } // Run 启动 WebSocket 服务器,监听指定地址 -// Run starts the WebSocket server and listens on the specified address +// starts the WebSocket server and listens on the specified address func (c *Server) Run(addr string) error { - // 创建 TCP 监听器 - // Create a TCP listener listener, err := net.Listen("tcp", addr) - - // 如果监听失败,返回错误 - // If listening fails, return an error if err != nil { return err } - - // 使用监听器运行服务器 - // Run the server using the listener return c.RunListener(listener) } // RunTLS 启动支持 TLS 的 WebSocket 服务器,监听指定地址 -// RunTLS starts the WebSocket server with TLS support and listens on the specified address +// starts the WebSocket server with TLS support and listens on the specified address func (c *Server) RunTLS(addr string, certFile, keyFile string) error { - // 加载 TLS 证书和私钥 - // Load the TLS certificate and private key cert, err := tls.LoadX509KeyPair(certFile, keyFile) - - // 如果加载失败,返回错误 - // If loading fails, return an error if err != nil { return err } - // 如果服务器的 TLS 配置为空,初始化一个新的配置 - // If the server's TLS configuration is nil, initialize a new configuration if c.option.TlsConfig == nil { c.option.TlsConfig = &tls.Config{} } - - // 克隆服务器的 TLS 配置 - // Clone the server's TLS configuration config := c.option.TlsConfig.Clone() - - // 设置证书 - // Set the certificate config.Certificates = []tls.Certificate{cert} - - // 设置下一个协议为 HTTP/1.1 - // Set the next protocol to HTTP/1.1 config.NextProtos = []string{"http/1.1"} - // 创建 TCP 监听器 - // Create a TCP listener listener, err := net.Listen("tcp", addr) - - // 如果监听失败,返回错误 - // If listening fails, return an error if err != nil { return err } - - // 使用 TLS 监听器运行服务器 - // Run the server using the TLS listener return c.RunListener(tls.NewListener(listener, config)) } // RunListener 使用指定的监听器运行 WebSocket 服务器 -// RunListener runs the WebSocket server using the specified listener +// runs the WebSocket server using the specified listener func (c *Server) RunListener(listener net.Listener) error { - // 确保在函数返回时关闭监听器 - // Ensure the listener is closed when the function returns defer listener.Close() - // 无限循环,接受新的连接 - // Infinite loop to accept new connections for { - // 接受新的网络连接 - // Accept a new network connection netConn, err := listener.Accept() - - // 如果接受连接时发生错误,调用错误处理回调函数并继续 - // If an error occurs while accepting the connection, call the error handling callback and continue if err != nil { c.OnError(netConn, err) continue } - // 启动一个新的 goroutine 处理连接 - // Start a new goroutine to handle the connection go func(conn net.Conn) { - // 从缓冲池中获取一个缓冲读取器 - // Get a buffered reader from the buffer pool br := c.option.config.brPool.Get() - - // 重置缓冲读取器以使用新的连接 - // Reset the buffered reader to use the new connection br.Reset(conn) - - // 尝试读取 HTTP 请求 - // Attempt to read the HTTP request if r, err := http.ReadRequest(br); err != nil { - // 如果读取请求失败,调用错误处理回调函数 - // If reading the request fails, call the error handling callback c.OnError(conn, err) } else { - // 如果读取请求成功,调用请求处理回调函数 - // If reading the request succeeds, call the request handling callback c.OnRequest(conn, br, r) } }(netConn) diff --git a/writer.go b/writer.go index 4939a3e8..311685cc 100644 --- a/writer.go +++ b/writer.go @@ -10,7 +10,7 @@ import ( "github.com/lxzan/gws/internal" ) -// WriteClose 发送关闭帧, 主动断开连接 +// WriteClose 发送关闭帧并断开连接 // 没有特殊需求的话, 推荐code=1000, reason=nil // Send shutdown frame, active disconnection // If you don't have any special needs, we recommend code=1000, reason=nil @@ -23,85 +23,63 @@ func (c *Conn) WriteClose(code uint16, reason []byte) { c.emitError(err) } -// WritePing 写入Ping消息, 携带的信息不要超过125字节 +// WritePing +// 写入Ping消息, 携带的信息不要超过125字节 // Control frame length cannot exceed 125 bytes func (c *Conn) WritePing(payload []byte) error { return c.WriteMessage(OpcodePing, payload) } -// WritePong 写入Pong消息, 携带的信息不要超过125字节 +// WritePong +// 写入Pong消息, 携带的信息不要超过125字节 // Control frame length cannot exceed 125 bytes func (c *Conn) WritePong(payload []byte) error { return c.WriteMessage(OpcodePong, payload) } -// WriteString 写入文本消息, 使用UTF8编码. +// WriteString +// 写入文本消息, 使用UTF8编码. // Write text messages, should be encoded in UTF8. func (c *Conn) WriteString(s string) error { return c.WriteMessage(OpcodeText, internal.StringToBytes(s)) } -// WriteMessage 写入文本/二进制消息, 文本消息应该使用UTF8编码 -// WriteMessage writes text/binary messages, text messages should be encoded in UTF8. +// WriteMessage +// 写入文本/二进制消息, 文本消息应该使用UTF8编码 +// Writes text/binary messages, text messages should be encoded in UTF8. func (c *Conn) WriteMessage(opcode Opcode, payload []byte) error { - // 调用 doWrite 方法写入消息 - // Call the doWrite method to write the message err := c.doWrite(opcode, internal.Bytes(payload)) - - // 触发错误处理 - // Emit error handling c.emitError(err) - - // 返回错误信息 - // Return the error return err } // WriteAsync 异步写 -// WriteAsync writes messages asynchronously +// Writes messages asynchronously // 异步非阻塞地将消息写入到任务队列, 收到回调后才允许回收payload内存 -// Write messages to the task queue asynchronously and non-blockingly, allowing payload memory to be recycled only after receiving the callback +// Write messages to the task queue asynchronously and non-blockingly, +// allowing payload memory to be recycled only after receiving the callback func (c *Conn) WriteAsync(opcode Opcode, payload []byte, callback func(error)) { - // 将写操作推送到写队列中 - // Push the write operation to the write queue c.writeQueue.Push(func() { - // 调用 WriteMessage 方法写入消息 - // Call the WriteMessage method to write the message if err := c.WriteMessage(opcode, payload); callback != nil { - // 如果有回调函数,调用回调函数并传递错误信息 - // If there is a callback function, call it and pass the error callback(err) } }) } -// Writev 类似 WriteMessage, 区别是可以一次写入多个切片 +// Writev +// 类似 WriteMessage, 区别是可以一次写入多个切片 // Writev is similar to WriteMessage, except that you can write multiple slices at once. func (c *Conn) Writev(opcode Opcode, payloads ...[]byte) error { - // 调用 doWrite 方法写入多个切片 - // Call the doWrite method to write multiple slices var err = c.doWrite(opcode, internal.Buffers(payloads)) - - // 触发错误处理 - // Emit error handling c.emitError(err) - - // 返回错误信息 - // Return the error return err } // WritevAsync 类似 WriteAsync, 区别是可以一次写入多个切片 -// WritevAsync is similar to WriteAsync, except that you can write multiple slices at once. +// It's similar to WriteAsync, except that you can write multiple slices at once. func (c *Conn) WritevAsync(opcode Opcode, payloads [][]byte, callback func(error)) { - // 将写操作推送到写队列中 - // Push the write operation to the write queue c.writeQueue.Push(func() { - // 调用 Writev 方法写入多个切片 - // Call the Writev method to write multiple slices if err := c.Writev(opcode, payloads...); callback != nil { - // 如果有回调函数,调用回调函数并传递错误信息 - // If there is a callback function, call it and pass the error callback(err) } }) @@ -119,313 +97,148 @@ func (c *Conn) Async(f func()) { // 执行写入逻辑, 注意妥善维护压缩字典 // doWrite executes the write logic, ensuring proper maintenance of the compression dictionary func (c *Conn) doWrite(opcode Opcode, payload internal.Payload) error { - // 加锁以确保线程安全 - // Lock to ensure thread safety c.mu.Lock() - // 在函数结束时解锁 - // Unlock at the end of the function defer c.mu.Unlock() - // 如果操作码不是关闭连接且连接已关闭,返回连接关闭错误 - // If the opcode is not CloseConnection and the connection is closed, return a connection closed error if opcode != OpcodeCloseConnection && c.isClosed() { return ErrConnClosed } - // 生成帧数据 - // Generate the frame data frame, err := c.genFrame(opcode, payload, false) if err != nil { return err } - // 将帧数据写入连接 - // Write the frame data to the connection err = internal.WriteN(c.conn, frame.Bytes()) - - // 将 payload 写入压缩窗口 - // Write the payload to the compression window _, _ = payload.WriteTo(&c.cpsWindow) - - // 将帧放回缓冲池 - // Put the frame back into the buffer pool binaryPool.Put(frame) - - // 返回写入操作的错误信息 - // Return the error from the write operation return err } // genFrame 生成帧数据 -// genFrame generates the frame data +// generates the frame data func (c *Conn) genFrame(opcode Opcode, payload internal.Payload, isBroadcast bool) (*bytes.Buffer, error) { - // 如果操作码是文本且编码检查未通过,返回不支持的数据错误 - // If the opcode is text and the encoding check fails, return an unsupported data error if opcode == OpcodeText && !payload.CheckEncoding(c.config.CheckUtf8Enabled, uint8(opcode)) { return nil, internal.NewError(internal.CloseUnsupportedData, ErrTextEncoding) } - // 获取负载的长度 - // Get the length of the payload var n = payload.Len() - // 如果负载长度超过配置的最大负载大小,返回消息过大错误 - // If the payload length exceeds the configured maximum payload size, return a message too large error if n > c.config.WriteMaxPayloadSize { return nil, internal.CloseMessageTooLarge } - // 从缓冲池获取一个缓冲区,大小为负载长度加上帧头大小 - // Get a buffer from the buffer pool, with size equal to payload length plus frame header size var buf = binaryPool.Get(n + frameHeaderSize) - - // 写入帧填充数据 - // Write frame padding data buf.Write(framePadding[0:]) - // 如果启用了压缩且操作码是数据帧且负载长度大于等于压缩阈值,进行数据压缩 - // If compression is enabled, the opcode is a data frame, and the payload length is greater than or equal to the compression threshold, compress the data if c.pd.Enabled && opcode.isDataFrame() && n >= c.pd.Threshold { return c.compressData(buf, opcode, payload, isBroadcast) } - // 生成帧头 - // Generate the frame header var header = frameHeader{} headerLength, maskBytes := header.GenerateHeader(c.isServer, true, false, opcode, n) - - // 将负载写入缓冲区 - // Write the payload to the buffer _, _ = payload.WriteTo(buf) - - // 获取缓冲区的字节切片 - // Get the byte slice of the buffer var contents = buf.Bytes() - - // 如果不是服务器端,进行掩码异或操作 - // If not server-side, perform mask XOR operation if !c.isServer { internal.MaskXOR(contents[frameHeaderSize:], maskBytes) } - - // 计算帧头的偏移量 - // Calculate the offset of the frame header var m = frameHeaderSize - headerLength - - // 将帧头复制到缓冲区 - // Copy the frame header to the buffer copy(contents[m:], header[:headerLength]) - - // 调整缓冲区的读取位置 - // Adjust the read position of the buffer buf.Next(m) - - // 返回缓冲区和 nil 错误 - // Return the buffer and nil error return buf, nil } // compressData 压缩数据并生成帧 -// compressData compresses the data and generates the frame +// compresses the data and generates the frame func (c *Conn) compressData(buf *bytes.Buffer, opcode Opcode, payload internal.Payload, isBroadcast bool) (*bytes.Buffer, error) { - // 使用 deflater 压缩数据并写入缓冲区 - // Use deflater to compress the data and write it to the buffer err := c.deflater.Compress(payload, buf, c.getCpsDict(isBroadcast)) if err != nil { return nil, err } - - // 获取缓冲区的字节切片 - // Get the byte slice of the buffer var contents = buf.Bytes() - - // 计算压缩后的负载大小 - // Calculate the size of the compressed payload var payloadSize = buf.Len() - frameHeaderSize - - // 生成帧头 - // Generate the frame header var header = frameHeader{} headerLength, maskBytes := header.GenerateHeader(c.isServer, true, true, opcode, payloadSize) - - // 如果不是服务器端,进行掩码异或操作 - // If not server-side, perform mask XOR operation if !c.isServer { internal.MaskXOR(contents[frameHeaderSize:], maskBytes) } - - // 计算帧头的偏移量 - // Calculate the offset of the frame header var m = frameHeaderSize - headerLength - - // 将帧头复制到缓冲区 - // Copy the frame header to the buffer copy(contents[m:], header[:headerLength]) - - // 调整缓冲区的读取位置 - // Adjust the read position of the buffer buf.Next(m) - - // 返回缓冲区和 nil 错误 - // Return the buffer and nil error return buf, nil } type ( - // Broadcaster 结构体用于广播消息 - // Broadcaster struct is used for broadcasting messages Broadcaster struct { - // opcode 表示操作码 - // opcode represents the operation code - opcode Opcode - - // payload 表示消息的负载 - // payload represents the message payload + opcode Opcode payload []byte - - // msgs 是一个包含两个广播消息包装器的数组 - // msgs is an array containing two broadcast message wrappers - msgs [2]*broadcastMessageWrapper - - // state 表示广播器的状态 - // state represents the state of the broadcaster - state int64 + msgs [2]*broadcastMessageWrapper + state int64 } - // broadcastMessageWrapper 结构体用于包装广播消息 - // broadcastMessageWrapper struct is used to wrap broadcast messages broadcastMessageWrapper struct { - // once 用于确保某些操作只执行一次 - // once is used to ensure certain operations are executed only once - once sync.Once - - // err 表示广播消息的错误状态 - // err represents the error state of the broadcast message - err error - - // frame 表示广播消息的帧数据 - // frame represents the frame data of the broadcast message + once sync.Once + err error frame *bytes.Buffer } ) // NewBroadcaster 创建广播器 -// NewBroadcaster creates a broadcaster +// creates a broadcaster // 相比循环调用 WriteAsync, Broadcaster 只会压缩一次消息, 可以节省大量 CPU 开销. // Instead of calling WriteAsync in a loop, Broadcaster compresses the message only once, saving a lot of CPU overhead. func NewBroadcaster(opcode Opcode, payload []byte) *Broadcaster { - // 初始化一个 Broadcaster 实例 - // Initialize a Broadcaster instance c := &Broadcaster{ - // 设置操作码 - // Set the operation code - opcode: opcode, - - // 设置消息负载 - // Set the message payload + opcode: opcode, payload: payload, - - // 初始化广播消息包装器数组 - // Initialize the broadcast message wrapper array - msgs: [2]*broadcastMessageWrapper{{}, {}}, - - // 设置初始状态 - // Set the initial state - state: int64(math.MaxInt32), + msgs: [2]*broadcastMessageWrapper{{}, {}}, + state: int64(math.MaxInt32), } - - // 返回 Broadcaster 实例 - // Return the Broadcaster instance return c } // writeFrame 将帧数据写入连接 -// writeFrame writes the frame data to the connection +// writes the frame data to the connection func (c *Broadcaster) writeFrame(socket *Conn, frame *bytes.Buffer) error { - // 如果连接已关闭,返回连接关闭错误 - // If the connection is closed, return a connection closed error if socket.isClosed() { return ErrConnClosed } - - // 加锁以确保线程安全 - // Lock to ensure thread safety socket.mu.Lock() - - // 写入帧数据到连接 - // Write the frame data to the connection var err = internal.WriteN(socket.conn, frame.Bytes()) - - // 将负载写入压缩窗口 - // Write the payload to the compression window socket.cpsWindow.Write(c.payload) - - // 解锁 - // Unlock socket.mu.Unlock() - - // 返回写入操作的错误信息 - // Return the error from the write operation return err } +// Broadcast 广播 +// 向客户端发送广播消息 +// Send a broadcast message to a client. func (c *Broadcaster) Broadcast(socket *Conn) error { - // 根据是否启用压缩选择索引值 - // Select index value based on whether compression is enabled var idx = internal.SelectValue(socket.pd.Enabled, 1, 0) - // 获取对应索引的广播消息包装器 - // Get the broadcast message wrapper for the corresponding index var msg = c.msgs[idx] - // 使用 sync.Once 确保帧数据只生成一次 - // Use sync.Once to ensure the frame data is generated only once msg.once.Do(func() { - // 生成帧数据 - // Generate the frame data msg.frame, msg.err = socket.genFrame(c.opcode, internal.Bytes(c.payload), true) }) - - // 如果生成帧数据时发生错误,返回错误 - // If there is an error generating the frame data, return the error if msg.err != nil { return msg.err } - // 原子性地增加广播器的状态值 - // Atomically increment the state value of the broadcaster atomic.AddInt64(&c.state, 1) - - // 将写入操作推入连接的写队列 - // Push the write operation into the connection's write queue socket.writeQueue.Push(func() { - // 将帧数据写入连接 - // Write the frame data to the connection var err = c.writeFrame(socket, msg.frame) - - // 触发错误事件 - // Emit the error event socket.emitError(err) - - // 原子性地减少广播器的状态值,如果状态值为 0,关闭广播器 - // Atomically decrement the state value of the broadcaster, if the state value is 0, close the broadcaster if atomic.AddInt64(&c.state, -1) == 0 { c.doClose() } }) - - // 返回 nil 表示成功 - // Return nil to indicate success return nil } // doClose 关闭广播器并释放资源 -// doClose closes the broadcaster and releases resources +// closes the broadcaster and releases resources func (c *Broadcaster) doClose() { - // 遍历广播消息包装器数组 - // Iterate over the broadcast message wrapper array for _, item := range c.msgs { - // 如果包装器不为空,释放其帧数据 - // If the wrapper is not nil, release its frame data if item != nil { binaryPool.Put(item.frame) } @@ -433,19 +246,12 @@ func (c *Broadcaster) doClose() { } // Close 释放资源 -// Close releases resources +// releases resources // 在完成所有 Broadcast 调用之后执行 Close 方法释放资源。 // Call the Close method after all the Broadcasts have been completed to release the resources. func (c *Broadcaster) Close() error { - // 原子性地减少广播器的状态值 - // Atomically decrement the state value of the broadcaster if atomic.AddInt64(&c.state, -1*math.MaxInt32) == 0 { - // 如果状态值为 0,关闭广播器并释放资源 - // If the state value is 0, close the broadcaster and release resources c.doClose() } - - // 返回 nil 表示成功 - // Return nil to indicate success return nil }