Skip to content

Commit

Permalink
Follow redirects in the new asyncio implementation.
Browse files Browse the repository at this point in the history
Fix #631.
  • Loading branch information
aaugustin committed Sep 7, 2024
1 parent 1f89db7 commit 7650f95
Show file tree
Hide file tree
Showing 7 changed files with 401 additions and 130 deletions.
27 changes: 4 additions & 23 deletions docs/howto/upgrade.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@ It provides a very similar API. However, there are a few differences.

The recommended upgrade process is:

1. Make sure that your application doesn't use any `deprecated APIs`_. If it
#. Make sure that your application doesn't use any `deprecated APIs`_. If it
doesn't raise any warnings, you can skip this step.
2. Check if your application depends on `missing features`_. If it does, you
should stick to the original implementation until they're added.
3. `Update import paths`_. For straightforward usage of websockets, this could
#. `Update import paths`_. For straightforward usage of websockets, this could
be the only step you need to take. Upgrading could be transparent.
4. Check out `new features and improvements`_ and consider taking advantage of
#. Check out `new features and improvements`_ and consider taking advantage of
them to improve your application.
5. Review `API changes`_ and adapt your application to preserve its current
#. Review `API changes`_ and adapt your application to preserve its current
functionality.

In the interest of brevity, only :func:`~asyncio.client.connect` and
Expand Down Expand Up @@ -62,23 +60,6 @@ the release notes of the version in which the feature was deprecated.
* The ``host``, ``port``, and ``secure`` attributes of connections — deprecated
in :ref:`8.0`.

.. _missing features:

Missing features
----------------

.. admonition:: All features listed below will be provided in a future release.
:class: tip

If your application relies on one of them, you should stick to the original
implementation until the new implementation supports it in a future release.

Following redirects
...................

The new implementation of :func:`~asyncio.client.connect` doesn't follow HTTP
redirects yet.

.. _Update import paths:

Import paths
Expand Down
3 changes: 3 additions & 0 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ New features
:func:`~asyncio.client.connect` as an asynchronous iterator to the new
:mod:`asyncio` implementation.

* :func:`~asyncio.client.connect` now follows redirects in the new
:mod:`asyncio` implementation.

* Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading`
implementations of servers.

Expand Down
2 changes: 1 addition & 1 deletion docs/reference/features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ Client
+------------------------------------+--------+--------+--------+--------+
| Connect to non-ASCII IRIs |||||
+------------------------------------+--------+--------+--------+--------+
| Follow HTTP redirects | ||||
| Follow HTTP redirects | ||||
+------------------------------------+--------+--------+--------+--------+
| Perform HTTP Basic Authentication |||||
+------------------------------------+--------+--------+--------+--------+
Expand Down
9 changes: 9 additions & 0 deletions docs/reference/variables.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,12 @@ Reconnection attempts are spaced out with truncated exponential backoff.
The delay between attempts is capped at ``BACKOFF_MAX_DELAY`` seconds.

The default value is ``90.0`` seconds.

Redirections
------------

.. envvar:: WEBSOCKETS_MAX_REDIRECTS

Maximum number of redirects that :func:`~asyncio.client.connect` follows.

The default value is ``10``.
169 changes: 133 additions & 36 deletions src/websockets/asyncio/client.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
from __future__ import annotations

import asyncio
import functools
import logging
import os
import urllib.parse
from types import TracebackType
from typing import Any, AsyncIterator, Callable, Generator, Sequence

from ..client import ClientProtocol, backoff
from ..datastructures import HeadersLike
from ..exceptions import InvalidStatus
from ..exceptions import InvalidStatus, SecurityError
from ..extensions.base import ClientExtensionFactory
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import validate_subprotocols
from ..http11 import USER_AGENT, Response
from ..protocol import CONNECTING, Event
from ..typing import LoggerLike, Origin, Subprotocol
from ..uri import parse_uri
from ..uri import WebSocketURI, parse_uri
from .compatibility import TimeoutError, asyncio_timeout
from .connection import Connection


__all__ = ["connect", "unix_connect", "ClientConnection"]

MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))


class ClientConnection(Connection):
"""
Expand Down Expand Up @@ -126,7 +129,7 @@ def connection_lost(self, exc: Exception | None) -> None:

def process_exception(exc: Exception) -> Exception | None:
"""
Determine whether an error is retryable or fatal.
Determine whether a connection error is retryable or fatal.
When reconnecting automatically with ``async for ... in connect(...)``, if a
connection attempt fails, :func:`process_exception` is called to determine
Expand Down Expand Up @@ -297,16 +300,7 @@ def __init__(
# Other keyword arguments are passed to loop.create_connection
**kwargs: Any,
) -> None:
wsuri = parse_uri(uri)

if wsuri.secure:
kwargs.setdefault("ssl", True)
kwargs.setdefault("server_hostname", wsuri.host)
if kwargs.get("ssl") is None:
raise TypeError("ssl=None is incompatible with a wss:// URI")
else:
if kwargs.get("ssl") is not None:
raise TypeError("ssl argument is incompatible with a ws:// URI")
self.uri = uri

if subprotocols is not None:
validate_subprotocols(subprotocols)
Expand All @@ -316,10 +310,13 @@ def __init__(
elif compression is not None:
raise ValueError(f"unsupported compression: {compression}")

if logger is None:
logger = logging.getLogger("websockets.client")

if create_connection is None:
create_connection = ClientConnection

def factory() -> ClientConnection:
def protocol_factory(wsuri: WebSocketURI) -> ClientConnection:
# This is a protocol in the Sans-I/O implementation of websockets.
protocol = ClientProtocol(
wsuri,
Expand All @@ -340,28 +337,104 @@ def factory() -> ClientConnection:
)
return connection

self.protocol_factory = protocol_factory
self.handshake_args = (
additional_headers,
user_agent_header,
)
self.process_exception = process_exception
self.open_timeout = open_timeout
self.logger = logger
self.connection_kwargs = kwargs

async def create_connection(self) -> ClientConnection:
"""Create TCP or Unix connection."""
loop = asyncio.get_running_loop()

wsuri = parse_uri(self.uri)
kwargs = self.connection_kwargs.copy()

def factory() -> ClientConnection:
return self.protocol_factory(wsuri)

if wsuri.secure:
kwargs.setdefault("ssl", True)
kwargs.setdefault("server_hostname", wsuri.host)
if kwargs.get("ssl") is None:
raise TypeError("ssl=None is incompatible with a wss:// URI")
else:
if kwargs.get("ssl") is not None:
raise TypeError("ssl argument is incompatible with a ws:// URI")

if kwargs.pop("unix", False):
self.create_connection = functools.partial(
loop.create_unix_connection, factory, **kwargs
)
_, connection = await loop.create_unix_connection(factory, **kwargs)
else:
if kwargs.get("sock") is None:
kwargs.setdefault("host", wsuri.host)
kwargs.setdefault("port", wsuri.port)
self.create_connection = functools.partial(
loop.create_connection, factory, **kwargs
_, connection = await loop.create_connection(factory, **kwargs)
return connection

def process_redirect(self, exc: Exception) -> Exception | str:
"""
Determine whether a connection error is a redirect that can be followed.
Return the new URI if it's a valid redirect. Else, return an exception.
"""
if not (
isinstance(exc, InvalidStatus)
and exc.response.status_code
in [
300, # Multiple Choices
301, # Moved Permanently
302, # Found
303, # See Other
307, # Temporary Redirect
308, # Permanent Redirect
]
and "Location" in exc.response.headers
):
return exc

old_wsuri = parse_uri(self.uri)
new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"])
new_wsuri = parse_uri(new_uri)

# If connect() received a socket, it is closed and cannot be reused.
if self.connection_kwargs.get("sock") is not None:
return ValueError(
f"cannot follow redirect to {new_uri} with a preexisting socket"
)

self.handshake_args = (
additional_headers,
user_agent_header,
)
self.process_exception = process_exception
self.open_timeout = open_timeout
if logger is None:
logger = logging.getLogger("websockets.client")
self.logger = logger
# TLS downgrade is forbidden.
if old_wsuri.secure and not new_wsuri.secure:
return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}")

# Apply restrictions to cross-origin redirects.
if (
old_wsuri.secure != new_wsuri.secure
or old_wsuri.host != new_wsuri.host
or old_wsuri.port != new_wsuri.port
):
# Cross-origin redirects on Unix sockets don't quite make sense.
if self.connection_kwargs.get("unix", False):
return ValueError(
f"cannot follow cross-origin redirect to {new_uri} "
f"with a Unix socket"
)

# Cross-origin redirects when host and port are overridden are ill-defined.
if (
self.connection_kwargs.get("host") is not None
or self.connection_kwargs.get("port") is not None
):
return ValueError(
f"cannot follow cross-origin redirect to {new_uri} "
f"with an explicit host or port"
)

return new_uri

# ... = await connect(...)

Expand All @@ -372,14 +445,38 @@ def __await__(self) -> Generator[Any, None, ClientConnection]:
async def __await_impl__(self) -> ClientConnection:
try:
async with asyncio_timeout(self.open_timeout):
_transport, self.connection = await self.create_connection()
try:
await self.connection.handshake(*self.handshake_args)
except (Exception, asyncio.CancelledError):
self.connection.transport.close()
raise
for _ in range(MAX_REDIRECTS):
self.connection = await self.create_connection()
try:
await self.connection.handshake(*self.handshake_args)
except asyncio.CancelledError:
self.connection.transport.close()
raise
except Exception as exc:
# Always close the connection even though keep-alive is
# the default in HTTP/1.1 because create_connection ties
# opening the network connection with initializing the
# protocol. In the current design of connect(), there is
# no easy way to reuse the network connection that works
# in every case nor to reinitialize the protocol.
self.connection.transport.close()

uri_or_exc = self.process_redirect(exc)
# Response is a valid redirect; follow it.
if isinstance(uri_or_exc, str):
self.uri = uri_or_exc
continue
# Response isn't a valid redirect; raise the exception.
if uri_or_exc is exc:
raise
else:
raise uri_or_exc from exc

else:
return self.connection
else:
return self.connection
raise SecurityError(f"more than {MAX_REDIRECTS} redirects")

except TimeoutError:
# Re-raise exception with an informative error message.
raise TimeoutError("timed out during handshake") from None
Expand Down
2 changes: 1 addition & 1 deletion src/websockets/legacy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ class Connect:
"""

MAX_REDIRECTS_ALLOWED = 10
MAX_REDIRECTS_ALLOWED = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10"))

def __init__(
self,
Expand Down
Loading

0 comments on commit 7650f95

Please sign in to comment.