diff --git a/client.go b/client.go index 1779ea98..5fe35a97 100644 --- a/client.go +++ b/client.go @@ -118,7 +118,7 @@ func (c *connector) handshake() (*Conn, *http.Response, error) { return nil, c.resp, err } var compressEnabled = c.option.CompressEnabled && strings.Contains(c.resp.Header.Get(internal.SecWebSocketExtensions.Key), "permessage-deflate") - return serveWebSocket(false, c.option.getConfig(), new(sliceMap), c.conn, br, c.eventHandler, compressEnabled), c.resp, nil + return serveWebSocket(false, c.option.getConfig(), c.option.NewSessionStorage(), c.conn, br, c.eventHandler, compressEnabled), c.resp, nil } func (c *connector) checkHeaders() error { diff --git a/option.go b/option.go index 6fc49692..47c391b6 100644 --- a/option.go +++ b/option.go @@ -107,6 +107,11 @@ type ( // 鉴权 // Authentication of requests for connection establishment Authorize func(r *http.Request, session SessionStorage) bool + + // 创建session存储空间 + // 用于自定义SessionStorage实现 + // For custom SessionStorage implementations + NewSessionStorage func() SessionStorage } ) @@ -139,9 +144,10 @@ func initServerOption(c *ServerOption) *ServerOption { c.CompressorNum = defaultCompressorNum } if c.Authorize == nil { - c.Authorize = func(r *http.Request, session SessionStorage) bool { - return true - } + c.Authorize = func(r *http.Request, session SessionStorage) bool { return true } + } + if c.NewSessionStorage == nil { + c.NewSessionStorage = func() SessionStorage { return new(sliceMap) } } if c.ResponseHeader == nil { c.ResponseHeader = http.Header{} @@ -212,6 +218,11 @@ type ClientOption struct { // return proxy.SOCKS5("tcp", "127.0.0.1:1080", nil, nil) // }, NewDialer func() (Dialer, error) + + // 创建session存储空间 + // 用于自定义SessionStorage实现 + // For custom SessionStorage implementations + NewSessionStorage func() SessionStorage } func initClientOption(c *ClientOption) *ClientOption { @@ -248,6 +259,9 @@ func initClientOption(c *ClientOption) *ClientOption { if c.NewDialer == nil { c.NewDialer = func() (Dialer, error) { return &net.Dialer{Timeout: defaultDialTimeout}, nil } } + if c.NewSessionStorage == nil { + c.NewSessionStorage = func() SessionStorage { return new(sliceMap) } + } return c } diff --git a/option_test.go b/option_test.go index 4411d7b4..5c3a6a00 100644 --- a/option_test.go +++ b/option_test.go @@ -22,6 +22,9 @@ func validateServerOption(as *assert.Assertions, u *Upgrader) { as.Equal(config.ReadBufferSize, option.ReadBufferSize) as.Equal(config.WriteBufferSize, option.WriteBufferSize) as.Equal(config.CompressorNum, option.CompressorNum) + + _, ok := u.option.NewSessionStorage().(*sliceMap) + as.True(ok) } func validateClientOption(as *assert.Assertions, option *ClientOption) { @@ -36,6 +39,9 @@ func validateClientOption(as *assert.Assertions, option *ClientOption) { as.Equal(config.CheckUtf8Enabled, option.CheckUtf8Enabled) as.Equal(config.ReadBufferSize, option.ReadBufferSize) as.Equal(config.WriteBufferSize, option.WriteBufferSize) + + _, ok := option.NewSessionStorage().(*sliceMap) + as.True(ok) } // 检查默认配置 @@ -56,6 +62,7 @@ func TestDefaultUpgrader(t *testing.T) { as.NotNil(updrader.option) as.NotNil(updrader.option.ResponseHeader) as.NotNil(updrader.option.Authorize) + as.NotNil(updrader.option.NewSessionStorage) as.Nil(updrader.option.Subprotocols) validateServerOption(as, updrader) } @@ -122,6 +129,7 @@ func TestDefaultClientOption(t *testing.T) { as.Equal(1, config.CompressorNum) as.NotNil(config) as.Equal(0, len(option.RequestHeader)) + as.NotNil(option.NewSessionStorage) validateClientOption(as, option) } @@ -144,6 +152,7 @@ func TestCompressClientOption(t *testing.T) { CompressLevel: flate.BestCompression, CompressThreshold: 1024, } + initClientOption(option) var config = option.getConfig() as.Equal(true, config.CompressEnabled) as.Equal(flate.BestCompression, config.CompressLevel) @@ -151,3 +160,23 @@ func TestCompressClientOption(t *testing.T) { validateClientOption(as, option) }) } + +func TestNewSessionStorage(t *testing.T) { + { + var option = &ServerOption{ + NewSessionStorage: func() SessionStorage { return NewConcurrentMap[string, any](16) }, + } + initServerOption(option) + _, ok := option.NewSessionStorage().(*ConcurrentMap[string, any]) + assert.True(t, ok) + } + + { + var option = &ClientOption{ + NewSessionStorage: func() SessionStorage { return NewConcurrentMap[string, any](16) }, + } + initClientOption(option) + _, ok := option.NewSessionStorage().(*ConcurrentMap[string, any]) + assert.True(t, ok) + } +} diff --git a/session_storage.go b/session_storage.go index 95cba255..ed5ca0e1 100644 --- a/session_storage.go +++ b/session_storage.go @@ -5,8 +5,6 @@ import ( "sync" ) -// SessionStorage because sync.Map is not easy to debug, so I implemented my own map. -// if you don't like it, use sync.Map instead. type SessionStorage interface { Load(key string) (value interface{}, exist bool) Delete(key string) @@ -94,11 +92,6 @@ func (c *sliceMap) Range(f func(key string, value interface{}) bool) { } } -/* -ConcurrentMap -used to store websocket connections in the IM server -用来存储IM等服务的连接 -*/ type ( Comparable interface { string | int | int64 | int32 | uint | uint64 | uint32 diff --git a/updrader.go b/updrader.go index 3551ff5c..81b5f9e3 100644 --- a/updrader.go +++ b/updrader.go @@ -95,7 +95,7 @@ func (c *Upgrader) doUpgrade(r *http.Request, netConn net.Conn, br *bufio.Reader return nil, err } - var session = new(sliceMap) + var session = c.option.NewSessionStorage() var header = c.option.ResponseHeader.Clone() if !c.option.Authorize(r, session) { return nil, internal.ErrUnauthorized