Skip to content

Commit

Permalink
Implement custom mux for Websocket servers
Browse files Browse the repository at this point in the history
  • Loading branch information
anderspitman committed Jan 22, 2024
1 parent 473989b commit 0fe2cbc
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions core/transport/websocket_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type wsServerTransport struct {
l net.Listener
m map[*Transport]struct{}
done chan struct{}
mux *http.ServeMux
}

func (ws *wsServerTransport) Close() (err error) {
Expand Down Expand Up @@ -80,8 +81,11 @@ func (ws *wsServerTransport) Listen(ctx context.Context, notifier chan<- bool) (
}
}()

mux := http.NewServeMux()
mux.HandleFunc(ws.path, func(w http.ResponseWriter, r *http.Request) {
if ws.mux == nil {
ws.mux = http.NewServeMux()
}

ws.mux.HandleFunc(ws.path, func(w http.ResponseWriter, r *http.Request) {
// upgrade websocket
c, err := ws.upgrader.Upgrade(w, r, nil)
if err != nil {
Expand All @@ -103,7 +107,7 @@ func (ws *wsServerTransport) Listen(ctx context.Context, notifier chan<- bool) (
}
})

err = http.Serve(ws.l, mux)
err = http.Serve(ws.l, ws.mux)
if err == io.EOF || isClosedErr(err) {
err = nil
} else {
Expand Down Expand Up @@ -138,8 +142,12 @@ func (ws *wsServerTransport) putTransport(tp *Transport) bool {
}
}

func NewWebsocketServerTransportWithMux(f ListenerFactory, path string, upgrader *websocket.Upgrader, mux *http.ServeMux) ServerTransport {
return NewWebsocketServerTransport(f, path, upgrader, mux)
}

// NewWebsocketServerTransport creates a new server-side transport.
func NewWebsocketServerTransport(f ListenerFactory, path string, upgrader *websocket.Upgrader) ServerTransport {
func NewWebsocketServerTransport(f ListenerFactory, path string, upgrader *websocket.Upgrader, mux *http.ServeMux) ServerTransport {
if path == "" {
path = defaultWebsocketPath
}
Expand All @@ -158,6 +166,7 @@ func NewWebsocketServerTransport(f ListenerFactory, path string, upgrader *webso
f: f,
m: make(map[*Transport]struct{}),
done: make(chan struct{}),
mux: mux,
}
}

Expand All @@ -174,7 +183,7 @@ func NewWebsocketServerTransportWithAddr(addr string, path string, upgrader *web
}
return tls.NewListener(l, config), nil
}
return NewWebsocketServerTransport(f, path, upgrader)
return NewWebsocketServerTransport(f, path, upgrader, nil)
}

// NewWebsocketClientTransport creates a new client-side transport.
Expand Down

0 comments on commit 0fe2cbc

Please sign in to comment.