diff --git a/core/transport/websocket_transport.go b/core/transport/websocket_transport.go index ec93c58..cbde2c3 100644 --- a/core/transport/websocket_transport.go +++ b/core/transport/websocket_transport.go @@ -26,6 +26,7 @@ type wsServerTransport struct { l net.Listener m map[*Transport]struct{} done chan struct{} + mux *http.ServeMux } func (ws *wsServerTransport) Close() (err error) { @@ -80,8 +81,11 @@ func (ws *wsServerTransport) Listen(ctx context.Context, notifier chan<- bool) ( } }() - mux := http.NewServeMux() - mux.HandleFunc(ws.path, func(w http.ResponseWriter, r *http.Request) { + if ws.mux == nil { + ws.mux = http.NewServeMux() + } + + ws.mux.HandleFunc(ws.path, func(w http.ResponseWriter, r *http.Request) { // upgrade websocket c, err := ws.upgrader.Upgrade(w, r, nil) if err != nil { @@ -103,7 +107,7 @@ func (ws *wsServerTransport) Listen(ctx context.Context, notifier chan<- bool) ( } }) - err = http.Serve(ws.l, mux) + err = http.Serve(ws.l, ws.mux) if err == io.EOF || isClosedErr(err) { err = nil } else { @@ -138,8 +142,12 @@ func (ws *wsServerTransport) putTransport(tp *Transport) bool { } } +func NewWebsocketServerTransportWithMux(f ListenerFactory, path string, upgrader *websocket.Upgrader, mux *http.ServeMux) ServerTransport { + return NewWebsocketServerTransport(f, path, upgrader, mux) +} + // NewWebsocketServerTransport creates a new server-side transport. -func NewWebsocketServerTransport(f ListenerFactory, path string, upgrader *websocket.Upgrader) ServerTransport { +func NewWebsocketServerTransport(f ListenerFactory, path string, upgrader *websocket.Upgrader, mux *http.ServeMux) ServerTransport { if path == "" { path = defaultWebsocketPath } @@ -158,6 +166,7 @@ func NewWebsocketServerTransport(f ListenerFactory, path string, upgrader *webso f: f, m: make(map[*Transport]struct{}), done: make(chan struct{}), + mux: mux, } } @@ -174,7 +183,7 @@ func NewWebsocketServerTransportWithAddr(addr string, path string, upgrader *web } return tls.NewListener(l, config), nil } - return NewWebsocketServerTransport(f, path, upgrader) + return NewWebsocketServerTransport(f, path, upgrader, nil) } // NewWebsocketClientTransport creates a new client-side transport.