Skip to content

Commit

Permalink
add write reader method
Browse files Browse the repository at this point in the history
  • Loading branch information
lxzan committed Aug 18, 2024
1 parent 3dc044f commit 2d79d12
Show file tree
Hide file tree
Showing 10 changed files with 528 additions and 23 deletions.
1 change: 1 addition & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ linters:
# Disable specific linter
# https://golangci-lint.run/usage/linters/#disabled-by-default
disable:
- maintidx
- mnd
- testpackage
- nlreturn
Expand Down
4 changes: 2 additions & 2 deletions benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func BenchmarkConn_ReadMessage(b *testing.B) {
conn: &benchConn{},
config: upgrader.option.getConfig(),
}
var buf, _ = conn1.genFrame(OpcodeText, internal.Bytes(githubData), false)
var buf, _ = conn1.genFrame(OpcodeText, true, false, internal.Bytes(githubData), false)

var reader = bytes.NewBuffer(buf.Bytes())
var conn2 = &Conn{
Expand Down Expand Up @@ -98,7 +98,7 @@ func BenchmarkConn_ReadMessage(b *testing.B) {
deflater: new(deflater),
}
conn1.deflater.initialize(false, conn1.pd, config.ReadMaxPayloadSize)
var buf, _ = conn1.genFrame(OpcodeText, internal.Bytes(githubData), false)
var buf, _ = conn1.genFrame(OpcodeText, true, true, internal.Bytes(githubData), false)

var reader = bytes.NewBuffer(buf.Bytes())
var conn2 = &Conn{
Expand Down
215 changes: 215 additions & 0 deletions bigfile.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
package gws

import (
"bytes"
"encoding/binary"
"errors"
"github.com/klauspost/compress/flate"
"github.com/lxzan/gws/internal"
"io"
"math"
)

const segmentSize = 128 * 1024

// 获取大文件压缩器
func (c *Conn) getBigDeflater() *bigDeflater {
if c.isServer {
return c.config.bdPool.Get()
}
return c.deflater.ToBigDeflater()
}

// 回收大文件压缩器
func (c *Conn) putBigDeflater(d *bigDeflater) {
if c.isServer {
c.config.bdPool.Put(d)
}
}

// 拆分io.Reader为小切片
func (c *Conn) splitReader(r io.Reader, f func(index int, eof bool, p []byte) error) error {
var buf = binaryPool.Get(segmentSize)
var p = buf.Bytes()[:segmentSize]
var n, index = 0, 0
var err error
for n, err = r.Read(p); err == nil || errors.Is(err, io.EOF); n, err = r.Read(p) {
eof := errors.Is(err, io.EOF)
if err = f(index, eof, p[:n]); err != nil {
return err
}
index++
if eof {
break
}
}
return err
}

// WriteReader 大文件写入
// 采用分段写入技术, 大大减少内存占用
func (c *Conn) WriteReader(opcode Opcode, payload io.Reader) error {
err := c.doWriteReader(opcode, payload)
c.emitError(err)
return err
}

func (c *Conn) doWriteReader(opcode Opcode, payload io.Reader) error {
c.mu.Lock()
defer c.mu.Unlock()

var cb = func(index int, eof bool, p []byte) error {
op := internal.SelectValue(index == 0, opcode, OpcodeContinuation)
frame, err := c.genFrame(op, eof, false, internal.Bytes(p), false)
if err != nil {
return err
}
if c.pd.Enabled && index == 0 {
frame.Bytes()[0] |= uint8(64)
}
if c.isClosed() {
return ErrConnClosed
}
err = internal.WriteN(c.conn, frame.Bytes())
binaryPool.Put(frame)
return err
}

if c.pd.Enabled {
var deflater = c.getBigDeflater()
var fw = &flateWriter{cb: cb}
err := deflater.Compress(payload, fw, c.getCpsDict(false), &c.cpsWindow)
c.putBigDeflater(deflater)
return err
} else {
return c.splitReader(payload, cb)
}
}

// 大文件压缩器
type bigDeflater struct {
cpsWriter *flate.Writer
}

// 初始化大文件压缩器
// Initialize the bigDeflater
func (c *bigDeflater) initialize(isServer bool, options PermessageDeflate) *bigDeflater {
windowBits := internal.SelectValue(isServer, options.ServerMaxWindowBits, options.ClientMaxWindowBits)
if windowBits == 15 {
c.cpsWriter, _ = flate.NewWriter(nil, options.Level)
} else {
c.cpsWriter, _ = flate.NewWriterWindow(nil, internal.BinaryPow(windowBits))
}
return c
}

// Compress 压缩
func (c *bigDeflater) Compress(src io.Reader, dst *flateWriter, dict []byte, sw *slideWindow) error {
if err := compressTo(c.cpsWriter, &readerWrapper{r: src, sw: sw}, dst, dict); err != nil {
return err
}
return dst.Flush()
}

// 写入代理
// 将切片透传给回调函数, 以实现分段写入功能
type flateWriter struct {
index int
buffers []*bytes.Buffer
cb func(index int, eof bool, p []byte) error
}

// 是否可以执行回调函数
func (c *flateWriter) shouldCall() bool {
var n = len(c.buffers)
if n < 2 {
return false
}
var sum = 0
for i := 1; i < n; i++ {
sum += c.buffers[i].Len()
}
return sum >= 4
}

// 聚合写入, 减少syscall.write次数
func (c *flateWriter) write(p []byte) {
if len(c.buffers) == 0 {
var buf = binaryPool.Get(segmentSize)
c.buffers = append(c.buffers, buf)
}
var n = len(c.buffers)
var tail = c.buffers[n-1]
if tail.Len()+len(p) >= segmentSize {
var buf = binaryPool.Get(segmentSize)
c.buffers = append(c.buffers, buf)
tail = buf
}
tail.Write(p)
}

func (c *flateWriter) Write(p []byte) (n int, err error) {
c.write(p)
if c.shouldCall() {
err = c.cb(c.index, false, c.buffers[0].Bytes())
binaryPool.Put(c.buffers[0])
c.buffers = c.buffers[1:]
c.index++
}
return n, err
}

func (c *flateWriter) Flush() error {
var buf = c.buffers[0]
for i := 1; i < len(c.buffers); i++ {
buf.Write(c.buffers[i].Bytes())
binaryPool.Put(c.buffers[i])
}
if n := buf.Len(); n >= 4 {
compressedContent := buf.Bytes()
if tail := compressedContent[n-4:]; binary.BigEndian.Uint32(tail) == math.MaxUint16 {
buf.Truncate(n - 4)
}
}
var err = c.cb(c.index, true, buf.Bytes())
c.index++
binaryPool.Put(buf)
return err
}

// 将io.Reader包装为io.WriterTo
type readerWrapper struct {
r io.Reader
sw *slideWindow
}

// WriteTo 写入内容, 并更新字典
func (c *readerWrapper) WriteTo(w io.Writer) (int64, error) {
var buf = binaryPool.Get(segmentSize)
defer binaryPool.Put(buf)

var p = buf.Bytes()[:segmentSize]
var sum, n = 0, 0
var err error
for n, err = c.r.Read(p); err == nil || errors.Is(err, io.EOF); n, err = c.r.Read(p) {
eof := errors.Is(err, io.EOF)
if _, err = w.Write(p[:n]); err != nil {
return int64(sum), err
}
sum += n
_, _ = c.sw.Write(p[:n])
if eof {
break
}
}
return int64(sum), err
}

// 压缩公共函数
func compressTo(cpsWriter *flate.Writer, r io.WriterTo, w io.Writer, dict []byte) error {
cpsWriter.ResetDict(w, dict)
if _, err := r.WriteTo(cpsWriter); err != nil {
return err
}
return cpsWriter.Flush()
}
9 changes: 3 additions & 6 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,7 @@ func (c *deflater) Decompress(src *bytes.Buffer, dict []byte) (*bytes.Buffer, er
func (c *deflater) Compress(src internal.Payload, dst *bytes.Buffer, dict []byte) error {
c.cpsLocker.Lock()
defer c.cpsLocker.Unlock()

c.cpsWriter.ResetDict(dst, dict)
if _, err := src.WriteTo(c.cpsWriter); err != nil {
return err
}
if err := c.cpsWriter.Flush(); err != nil {
if err := compressTo(c.cpsWriter, src, dst, dict); err != nil {
return err
}
if n := dst.Len(); n >= 4 {
Expand All @@ -117,6 +112,8 @@ func (c *deflater) Compress(src internal.Payload, dst *bytes.Buffer, dict []byte
return nil
}

func (c *deflater) ToBigDeflater() *bigDeflater { return &bigDeflater{cpsWriter: c.cpsWriter} }

// 滑动窗口
// Sliding window
type slideWindow struct {
Expand Down
4 changes: 4 additions & 0 deletions compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,7 @@ func (c *writerTo) Len() int {
func (c *writerTo) WriteTo(w io.Writer) (n int64, err error) {
return 0, errors.New("1")
}

func (c *writerTo) Read(p []byte) (n int, err error) {
return 0, errors.New("1")
}
10 changes: 7 additions & 3 deletions examples/echo/main.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package main

import (
"github.com/lxzan/gws"
"log"
"net/http"

"github.com/lxzan/gws"
"os"
)

func main() {
Expand Down Expand Up @@ -41,5 +41,9 @@ func (c *Handler) OnPing(socket *gws.Conn, payload []byte) {

func (c *Handler) OnMessage(socket *gws.Conn, message *gws.Message) {
defer message.Close()
_ = socket.WriteMessage(message.Opcode, message.Bytes())
//file, _ := os.OpenFile("C:\\msys64\\home\\lxzan\\Open\\gws\\assets\\github.json", os.O_RDONLY, 0644)
file, _ := os.OpenFile("C:\\Users\\lxzan\\Pictures\\mg.png", os.O_RDONLY, 0644)
defer file.Close()
_ = socket.WriteReader(gws.OpcodeBinary, file)
//_ = socket.WriteReader(message.Opcode, message)
}
6 changes: 6 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ type (
// Memory pool for bufio.Reader
brPool *internal.Pool[*bufio.Reader]

// 大文件压缩器
bdPool *internal.Pool[*bigDeflater]

// 压缩器滑动窗口内存池
// Memory pool for compressor sliding window
cswPool *internal.Pool[[]byte]
Expand Down Expand Up @@ -320,6 +323,9 @@ func initServerOption(c *ServerOption) *ServerOption {
}

if c.PermessageDeflate.Enabled {
c.config.bdPool = internal.NewPool[*bigDeflater](func() *bigDeflater {
return new(bigDeflater).initialize(true, c.PermessageDeflate)
})
if c.PermessageDeflate.ServerContextTakeover {
windowSize := internal.BinaryPow(c.PermessageDeflate.ServerMaxWindowBits)
c.config.cswPool = internal.NewPool[[]byte](func() []byte {
Expand Down
6 changes: 3 additions & 3 deletions reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func TestSegments(t *testing.T) {
go client.ReadLoop()

go func() {
frame, _ := client.genFrame(OpcodeText, internal.Bytes(testdata), false)
frame, _ := client.genFrame(OpcodeText, true, true, internal.Bytes(testdata), false)
data := frame.Bytes()
data[20] = 'x'
client.conn.Write(data)
Expand Down Expand Up @@ -366,7 +366,7 @@ func TestConn_ReadMessage(t *testing.T) {
var serverHandler = &webSocketMocker{}
serverHandler.onOpen = func(socket *Conn) {
var p = []byte("123")
frame, _ := socket.genFrame(OpcodePing, internal.Bytes(p), false)
frame, _ := socket.genFrame(OpcodePing, true, socket.pd.Enabled, internal.Bytes(p), false)
socket.conn.Write(frame.Bytes()[:2])
socket.conn.Close()
}
Expand All @@ -391,7 +391,7 @@ func TestConn_ReadMessage(t *testing.T) {
var serverHandler = &webSocketMocker{}
serverHandler.onOpen = func(socket *Conn) {
var p = []byte("123")
frame, _ := socket.genFrame(OpcodeText, internal.Bytes(p), false)
frame, _ := socket.genFrame(OpcodeText, true, socket.pd.Enabled, internal.Bytes(p), false)
socket.conn.Write(frame.Bytes()[:2])
socket.conn.Close()
}
Expand Down
Loading

0 comments on commit 2d79d12

Please sign in to comment.