Skip to content

Commit

Permalink
For custom SessionStorage implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
lixizan committed Jul 27, 2023
1 parent 7c8cd8f commit 01617f5
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 12 deletions.
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (c *connector) handshake() (*Conn, *http.Response, error) {
return nil, c.resp, err
}
var compressEnabled = c.option.CompressEnabled && strings.Contains(c.resp.Header.Get(internal.SecWebSocketExtensions.Key), "permessage-deflate")
return serveWebSocket(false, c.option.getConfig(), new(sliceMap), c.conn, br, c.eventHandler, compressEnabled), c.resp, nil
return serveWebSocket(false, c.option.getConfig(), c.option.NewSessionStorage(), c.conn, br, c.eventHandler, compressEnabled), c.resp, nil
}

func (c *connector) checkHeaders() error {
Expand Down
20 changes: 17 additions & 3 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ type (
// 鉴权
// Authentication of requests for connection establishment
Authorize func(r *http.Request, session SessionStorage) bool

// 创建session存储空间
// 用于自定义SessionStorage实现
// For custom SessionStorage implementations
NewSessionStorage func() SessionStorage
}
)

Expand Down Expand Up @@ -139,9 +144,10 @@ func initServerOption(c *ServerOption) *ServerOption {
c.CompressorNum = defaultCompressorNum
}
if c.Authorize == nil {
c.Authorize = func(r *http.Request, session SessionStorage) bool {
return true
}
c.Authorize = func(r *http.Request, session SessionStorage) bool { return true }
}
if c.NewSessionStorage == nil {
c.NewSessionStorage = func() SessionStorage { return new(sliceMap) }
}
if c.ResponseHeader == nil {
c.ResponseHeader = http.Header{}
Expand Down Expand Up @@ -212,6 +218,11 @@ type ClientOption struct {
// return proxy.SOCKS5("tcp", "127.0.0.1:1080", nil, nil)
// },
NewDialer func() (Dialer, error)

// 创建session存储空间
// 用于自定义SessionStorage实现
// For custom SessionStorage implementations
NewSessionStorage func() SessionStorage
}

func initClientOption(c *ClientOption) *ClientOption {
Expand Down Expand Up @@ -248,6 +259,9 @@ func initClientOption(c *ClientOption) *ClientOption {
if c.NewDialer == nil {
c.NewDialer = func() (Dialer, error) { return &net.Dialer{Timeout: defaultDialTimeout}, nil }
}
if c.NewSessionStorage == nil {
c.NewSessionStorage = func() SessionStorage { return new(sliceMap) }
}
return c
}

Expand Down
29 changes: 29 additions & 0 deletions option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ func validateServerOption(as *assert.Assertions, u *Upgrader) {
as.Equal(config.ReadBufferSize, option.ReadBufferSize)
as.Equal(config.WriteBufferSize, option.WriteBufferSize)
as.Equal(config.CompressorNum, option.CompressorNum)

_, ok := u.option.NewSessionStorage().(*sliceMap)
as.True(ok)
}

func validateClientOption(as *assert.Assertions, option *ClientOption) {
Expand All @@ -36,6 +39,9 @@ func validateClientOption(as *assert.Assertions, option *ClientOption) {
as.Equal(config.CheckUtf8Enabled, option.CheckUtf8Enabled)
as.Equal(config.ReadBufferSize, option.ReadBufferSize)
as.Equal(config.WriteBufferSize, option.WriteBufferSize)

_, ok := option.NewSessionStorage().(*sliceMap)
as.True(ok)
}

// 检查默认配置
Expand All @@ -56,6 +62,7 @@ func TestDefaultUpgrader(t *testing.T) {
as.NotNil(updrader.option)
as.NotNil(updrader.option.ResponseHeader)
as.NotNil(updrader.option.Authorize)
as.NotNil(updrader.option.NewSessionStorage)
as.Nil(updrader.option.Subprotocols)
validateServerOption(as, updrader)
}
Expand Down Expand Up @@ -122,6 +129,7 @@ func TestDefaultClientOption(t *testing.T) {
as.Equal(1, config.CompressorNum)
as.NotNil(config)
as.Equal(0, len(option.RequestHeader))
as.NotNil(option.NewSessionStorage)
validateClientOption(as, option)
}

Expand All @@ -144,10 +152,31 @@ func TestCompressClientOption(t *testing.T) {
CompressLevel: flate.BestCompression,
CompressThreshold: 1024,
}
initClientOption(option)
var config = option.getConfig()
as.Equal(true, config.CompressEnabled)
as.Equal(flate.BestCompression, config.CompressLevel)
as.Equal(1024, config.CompressThreshold)
validateClientOption(as, option)
})
}

func TestNewSessionStorage(t *testing.T) {
{
var option = &ServerOption{
NewSessionStorage: func() SessionStorage { return NewConcurrentMap[string, any](16) },
}
initServerOption(option)
_, ok := option.NewSessionStorage().(*ConcurrentMap[string, any])
assert.True(t, ok)
}

{
var option = &ClientOption{
NewSessionStorage: func() SessionStorage { return NewConcurrentMap[string, any](16) },
}
initClientOption(option)
_, ok := option.NewSessionStorage().(*ConcurrentMap[string, any])
assert.True(t, ok)
}
}
7 changes: 0 additions & 7 deletions session_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"sync"
)

// SessionStorage because sync.Map is not easy to debug, so I implemented my own map.
// if you don't like it, use sync.Map instead.
type SessionStorage interface {
Load(key string) (value interface{}, exist bool)
Delete(key string)
Expand Down Expand Up @@ -94,11 +92,6 @@ func (c *sliceMap) Range(f func(key string, value interface{}) bool) {
}
}

/*
ConcurrentMap
used to store websocket connections in the IM server
用来存储IM等服务的连接
*/
type (
Comparable interface {
string | int | int64 | int32 | uint | uint64 | uint32
Expand Down
2 changes: 1 addition & 1 deletion updrader.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (c *Upgrader) doUpgrade(r *http.Request, netConn net.Conn, br *bufio.Reader
return nil, err
}

var session = new(sliceMap)
var session = c.option.NewSessionStorage()
var header = c.option.ResponseHeader.Clone()
if !c.option.Authorize(r, session) {
return nil, internal.ErrUnauthorized
Expand Down

0 comments on commit 01617f5

Please sign in to comment.