From 7c8cd8f275b3a02735273bac68986642c0676566 Mon Sep 17 00:00:00 2001 From: lxzan Date: Sun, 23 Jul 2023 10:44:01 +0800 Subject: [PATCH] adapt to firefox --- client.go | 4 ++-- internal/utils.go | 4 ++++ internal/utils_test.go | 5 +++++ updrader.go | 6 +++--- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index f590a35b..1779ea98 100644 --- a/client.go +++ b/client.go @@ -125,10 +125,10 @@ func (c *connector) checkHeaders() error { if c.resp.StatusCode != http.StatusSwitchingProtocols { return internal.ErrStatusCode } - if !internal.HttpHeaderEqual(c.resp.Header.Get(internal.Connection.Key), internal.Connection.Val) { + if !internal.HttpHeaderContains(c.resp.Header.Get(internal.Connection.Key), internal.Connection.Val) { return internal.ErrHandshake } - if !internal.HttpHeaderEqual(c.resp.Header.Get(internal.Upgrade.Key), internal.Upgrade.Val) { + if !strings.EqualFold(c.resp.Header.Get(internal.Upgrade.Key), internal.Upgrade.Val) { return internal.ErrHandshake } if c.resp.Header.Get(internal.SecWebSocketAccept.Key) != internal.ComputeAcceptKey(c.secWebsocketKey) { diff --git a/internal/utils.go b/internal/utils.go index ed348039..0dd447b8 100644 --- a/internal/utils.go +++ b/internal/utils.go @@ -198,6 +198,10 @@ func HttpHeaderEqual(a, b string) bool { return strings.ToLower(a) == strings.ToLower(b) } +func HttpHeaderContains(a, b string) bool { + return strings.Contains(strings.ToLower(a), strings.ToLower(b)) +} + func SelectValue[T any](ok bool, a, b T) T { if ok { return a diff --git a/internal/utils_test.go b/internal/utils_test.go index f96248af..62b7ae9e 100644 --- a/internal/utils_test.go +++ b/internal/utils_test.go @@ -186,6 +186,11 @@ func TestHttpHeaderEqual(t *testing.T) { assert.Equal(t, false, HttpHeaderEqual("WebSocket@", "websocket")) } +func TestHttpHeaderContains(t *testing.T) { + assert.Equal(t, true, HttpHeaderContains("WebSocket", "websocket")) + assert.Equal(t, true, HttpHeaderContains("WebSocket@", "websocket")) +} + func TestSelectInt(t *testing.T) { assert.Equal(t, 1, SelectValue(true, 1, 2)) assert.Equal(t, 2, SelectValue(false, 1, 2)) diff --git a/updrader.go b/updrader.go index 2b44f88c..3551ff5c 100644 --- a/updrader.go +++ b/updrader.go @@ -105,14 +105,14 @@ func (c *Upgrader) doUpgrade(r *http.Request, netConn net.Conn, br *bufio.Reader if r.Method != http.MethodGet { return nil, internal.ErrGetMethodRequired } - if !internal.HttpHeaderEqual(r.Header.Get(internal.SecWebSocketVersion.Key), internal.SecWebSocketVersion.Val) { + if !strings.EqualFold(r.Header.Get(internal.SecWebSocketVersion.Key), internal.SecWebSocketVersion.Val) { msg := "websocket version not supported" return nil, errors.New(msg) } - if !internal.HttpHeaderEqual(r.Header.Get(internal.Connection.Key), internal.Connection.Val) { + if !internal.HttpHeaderContains(r.Header.Get(internal.Connection.Key), internal.Connection.Val) { return nil, internal.ErrHandshake } - if !internal.HttpHeaderEqual(r.Header.Get(internal.Upgrade.Key), internal.Upgrade.Val) { + if !strings.EqualFold(r.Header.Get(internal.Upgrade.Key), internal.Upgrade.Val) { return nil, internal.ErrHandshake } if val := r.Header.Get(internal.SecWebSocketExtensions.Key); strings.Contains(val, "permessage-deflate") && c.option.CompressEnabled {