From 7650f95b3b38aa775118d64c747a3837573e32e7 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 7 Sep 2024 21:40:07 +0200 Subject: [PATCH] Follow redirects in the new asyncio implementation. Fix #631. --- docs/howto/upgrade.rst | 27 +-- docs/project/changelog.rst | 3 + docs/reference/features.rst | 2 +- docs/reference/variables.rst | 9 + src/websockets/asyncio/client.py | 169 ++++++++++++---- src/websockets/legacy/client.py | 2 +- tests/asyncio/test_client.py | 319 ++++++++++++++++++++++++------- 7 files changed, 401 insertions(+), 130 deletions(-) diff --git a/docs/howto/upgrade.rst b/docs/howto/upgrade.rst index 42edb978..120509c9 100644 --- a/docs/howto/upgrade.rst +++ b/docs/howto/upgrade.rst @@ -10,15 +10,13 @@ It provides a very similar API. However, there are a few differences. The recommended upgrade process is: -1. Make sure that your application doesn't use any `deprecated APIs`_. If it +#. Make sure that your application doesn't use any `deprecated APIs`_. If it doesn't raise any warnings, you can skip this step. -2. Check if your application depends on `missing features`_. If it does, you - should stick to the original implementation until they're added. -3. `Update import paths`_. For straightforward usage of websockets, this could +#. `Update import paths`_. For straightforward usage of websockets, this could be the only step you need to take. Upgrading could be transparent. -4. Check out `new features and improvements`_ and consider taking advantage of +#. Check out `new features and improvements`_ and consider taking advantage of them to improve your application. -5. Review `API changes`_ and adapt your application to preserve its current +#. Review `API changes`_ and adapt your application to preserve its current functionality. In the interest of brevity, only :func:`~asyncio.client.connect` and @@ -62,23 +60,6 @@ the release notes of the version in which the feature was deprecated. * The ``host``, ``port``, and ``secure`` attributes of connections — deprecated in :ref:`8.0`. -.. _missing features: - -Missing features ----------------- - -.. admonition:: All features listed below will be provided in a future release. - :class: tip - - If your application relies on one of them, you should stick to the original - implementation until the new implementation supports it in a future release. - -Following redirects -................... - -The new implementation of :func:`~asyncio.client.connect` doesn't follow HTTP -redirects yet. - .. _Update import paths: Import paths diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 61113fb8..c77876a7 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -50,6 +50,9 @@ New features :func:`~asyncio.client.connect` as an asynchronous iterator to the new :mod:`asyncio` implementation. +* :func:`~asyncio.client.connect` now follows redirects in the new + :mod:`asyncio` implementation. + * Added HTTP Basic Auth to the new :mod:`asyncio` and :mod:`threading` implementations of servers. diff --git a/docs/reference/features.rst b/docs/reference/features.rst index d9941e40..32fc05ba 100644 --- a/docs/reference/features.rst +++ b/docs/reference/features.rst @@ -154,7 +154,7 @@ Client +------------------------------------+--------+--------+--------+--------+ | Connect to non-ASCII IRIs | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ - | Follow HTTP redirects | ❌ | ❌ | — | ✅ | + | Follow HTTP redirects | ✅ | ❌ | — | ✅ | +------------------------------------+--------+--------+--------+--------+ | Perform HTTP Basic Authentication | ✅ | ✅ | ✅ | ✅ | +------------------------------------+--------+--------+--------+--------+ diff --git a/docs/reference/variables.rst b/docs/reference/variables.rst index 49813225..26213bd9 100644 --- a/docs/reference/variables.rst +++ b/docs/reference/variables.rst @@ -77,3 +77,12 @@ Reconnection attempts are spaced out with truncated exponential backoff. The delay between attempts is capped at ``BACKOFF_MAX_DELAY`` seconds. The default value is ``90.0`` seconds. + +Redirections +------------ + +.. envvar:: WEBSOCKETS_MAX_REDIRECTS + + Maximum number of redirects that :func:`~asyncio.client.connect` follows. + + The default value is ``10``. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 5f7a3719..50f67b95 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -1,27 +1,30 @@ from __future__ import annotations import asyncio -import functools import logging +import os +import urllib.parse from types import TracebackType from typing import Any, AsyncIterator, Callable, Generator, Sequence from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike -from ..exceptions import InvalidStatus +from ..exceptions import InvalidStatus, SecurityError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..typing import LoggerLike, Origin, Subprotocol -from ..uri import parse_uri +from ..uri import WebSocketURI, parse_uri from .compatibility import TimeoutError, asyncio_timeout from .connection import Connection __all__ = ["connect", "unix_connect", "ClientConnection"] +MAX_REDIRECTS = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10")) + class ClientConnection(Connection): """ @@ -126,7 +129,7 @@ def connection_lost(self, exc: Exception | None) -> None: def process_exception(exc: Exception) -> Exception | None: """ - Determine whether an error is retryable or fatal. + Determine whether a connection error is retryable or fatal. When reconnecting automatically with ``async for ... in connect(...)``, if a connection attempt fails, :func:`process_exception` is called to determine @@ -297,16 +300,7 @@ def __init__( # Other keyword arguments are passed to loop.create_connection **kwargs: Any, ) -> None: - wsuri = parse_uri(uri) - - if wsuri.secure: - kwargs.setdefault("ssl", True) - kwargs.setdefault("server_hostname", wsuri.host) - if kwargs.get("ssl") is None: - raise TypeError("ssl=None is incompatible with a wss:// URI") - else: - if kwargs.get("ssl") is not None: - raise TypeError("ssl argument is incompatible with a ws:// URI") + self.uri = uri if subprotocols is not None: validate_subprotocols(subprotocols) @@ -316,10 +310,13 @@ def __init__( elif compression is not None: raise ValueError(f"unsupported compression: {compression}") + if logger is None: + logger = logging.getLogger("websockets.client") + if create_connection is None: create_connection = ClientConnection - def factory() -> ClientConnection: + def protocol_factory(wsuri: WebSocketURI) -> ClientConnection: # This is a protocol in the Sans-I/O implementation of websockets. protocol = ClientProtocol( wsuri, @@ -340,28 +337,104 @@ def factory() -> ClientConnection: ) return connection + self.protocol_factory = protocol_factory + self.handshake_args = ( + additional_headers, + user_agent_header, + ) + self.process_exception = process_exception + self.open_timeout = open_timeout + self.logger = logger + self.connection_kwargs = kwargs + + async def create_connection(self) -> ClientConnection: + """Create TCP or Unix connection.""" loop = asyncio.get_running_loop() + + wsuri = parse_uri(self.uri) + kwargs = self.connection_kwargs.copy() + + def factory() -> ClientConnection: + return self.protocol_factory(wsuri) + + if wsuri.secure: + kwargs.setdefault("ssl", True) + kwargs.setdefault("server_hostname", wsuri.host) + if kwargs.get("ssl") is None: + raise TypeError("ssl=None is incompatible with a wss:// URI") + else: + if kwargs.get("ssl") is not None: + raise TypeError("ssl argument is incompatible with a ws:// URI") + if kwargs.pop("unix", False): - self.create_connection = functools.partial( - loop.create_unix_connection, factory, **kwargs - ) + _, connection = await loop.create_unix_connection(factory, **kwargs) else: if kwargs.get("sock") is None: kwargs.setdefault("host", wsuri.host) kwargs.setdefault("port", wsuri.port) - self.create_connection = functools.partial( - loop.create_connection, factory, **kwargs + _, connection = await loop.create_connection(factory, **kwargs) + return connection + + def process_redirect(self, exc: Exception) -> Exception | str: + """ + Determine whether a connection error is a redirect that can be followed. + + Return the new URI if it's a valid redirect. Else, return an exception. + + """ + if not ( + isinstance(exc, InvalidStatus) + and exc.response.status_code + in [ + 300, # Multiple Choices + 301, # Moved Permanently + 302, # Found + 303, # See Other + 307, # Temporary Redirect + 308, # Permanent Redirect + ] + and "Location" in exc.response.headers + ): + return exc + + old_wsuri = parse_uri(self.uri) + new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"]) + new_wsuri = parse_uri(new_uri) + + # If connect() received a socket, it is closed and cannot be reused. + if self.connection_kwargs.get("sock") is not None: + return ValueError( + f"cannot follow redirect to {new_uri} with a preexisting socket" ) - self.handshake_args = ( - additional_headers, - user_agent_header, - ) - self.process_exception = process_exception - self.open_timeout = open_timeout - if logger is None: - logger = logging.getLogger("websockets.client") - self.logger = logger + # TLS downgrade is forbidden. + if old_wsuri.secure and not new_wsuri.secure: + return SecurityError(f"cannot follow redirect to non-secure URI {new_uri}") + + # Apply restrictions to cross-origin redirects. + if ( + old_wsuri.secure != new_wsuri.secure + or old_wsuri.host != new_wsuri.host + or old_wsuri.port != new_wsuri.port + ): + # Cross-origin redirects on Unix sockets don't quite make sense. + if self.connection_kwargs.get("unix", False): + return ValueError( + f"cannot follow cross-origin redirect to {new_uri} " + f"with a Unix socket" + ) + + # Cross-origin redirects when host and port are overridden are ill-defined. + if ( + self.connection_kwargs.get("host") is not None + or self.connection_kwargs.get("port") is not None + ): + return ValueError( + f"cannot follow cross-origin redirect to {new_uri} " + f"with an explicit host or port" + ) + + return new_uri # ... = await connect(...) @@ -372,14 +445,38 @@ def __await__(self) -> Generator[Any, None, ClientConnection]: async def __await_impl__(self) -> ClientConnection: try: async with asyncio_timeout(self.open_timeout): - _transport, self.connection = await self.create_connection() - try: - await self.connection.handshake(*self.handshake_args) - except (Exception, asyncio.CancelledError): - self.connection.transport.close() - raise + for _ in range(MAX_REDIRECTS): + self.connection = await self.create_connection() + try: + await self.connection.handshake(*self.handshake_args) + except asyncio.CancelledError: + self.connection.transport.close() + raise + except Exception as exc: + # Always close the connection even though keep-alive is + # the default in HTTP/1.1 because create_connection ties + # opening the network connection with initializing the + # protocol. In the current design of connect(), there is + # no easy way to reuse the network connection that works + # in every case nor to reinitialize the protocol. + self.connection.transport.close() + + uri_or_exc = self.process_redirect(exc) + # Response is a valid redirect; follow it. + if isinstance(uri_or_exc, str): + self.uri = uri_or_exc + continue + # Response isn't a valid redirect; raise the exception. + if uri_or_exc is exc: + raise + else: + raise uri_or_exc from exc + + else: + return self.connection else: - return self.connection + raise SecurityError(f"more than {MAX_REDIRECTS} redirects") + except TimeoutError: # Re-raise exception with an informative error message. raise TimeoutError("timed out during handshake") from None diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index a1bc5cba..ec4c2ff6 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -418,7 +418,7 @@ class Connect: """ - MAX_REDIRECTS_ALLOWED = 10 + MAX_REDIRECTS_ALLOWED = int(os.environ.get("WEBSOCKETS_MAX_REDIRECTS", "10")) def __init__( self, diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 7467d215..b0487552 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -10,7 +10,12 @@ from websockets.asyncio.compatibility import TimeoutError from websockets.asyncio.server import serve, unix_serve from websockets.client import backoff -from websockets.exceptions import InvalidHandshake, InvalidStatus, InvalidURI +from websockets.exceptions import ( + InvalidHandshake, + InvalidStatus, + InvalidURI, + SecurityError, +) from websockets.extensions.permessage_deflate import PerMessageDeflate from ..utils import CLIENT_CONTEXT, MS, SERVER_CONTEXT, temp_unix_socket_path @@ -34,6 +39,20 @@ async def short_backoff_delay(): backoff.__defaults__ = defaults +# Decorate tests that need it with @few_redirects() instead of using it as a +# context manager when dropping support for Python < 3.10. +@contextlib.asynccontextmanager +async def few_redirects(): + from websockets.asyncio import client + + max_redirects = client.MAX_REDIRECTS + client.MAX_REDIRECTS = 2 + try: + yield + finally: + client.MAX_REDIRECTS = max_redirects + + class ClientTests(unittest.IsolatedAsyncioTestCase): async def test_connection(self): """Client connects to server.""" @@ -41,7 +60,93 @@ async def test_connection(self): async with connect(get_uri(server)) as client: self.assertEqual(client.protocol.state.name, "OPEN") - async def test_reconnection(self): + async def test_explicit_host_port(self): + """Client connects using an explicit host / port.""" + async with serve(*args) as server: + host, port = get_host_port(server) + async with connect("ws://overridden/", host=host, port=port) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_existing_socket(self): + """Client connects using a pre-existing socket.""" + async with serve(*args) as server: + with socket.create_connection(get_host_port(server)) as sock: + # Use a non-existing domain to ensure we connect to sock. + async with connect("ws://invalid/", sock=sock) as client: + self.assertEqual(client.protocol.state.name, "OPEN") + + async def test_additional_headers(self): + """Client can set additional headers with additional_headers.""" + async with serve(*args) as server: + async with connect( + get_uri(server), additional_headers={"Authorization": "Bearer ..."} + ) as client: + self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + + async def test_override_user_agent(self): + """Client can override User-Agent header with user_agent_header.""" + async with serve(*args) as server: + async with connect(get_uri(server), user_agent_header="Smith") as client: + self.assertEqual(client.request.headers["User-Agent"], "Smith") + + async def test_remove_user_agent(self): + """Client can remove User-Agent header with user_agent_header.""" + async with serve(*args) as server: + async with connect(get_uri(server), user_agent_header=None) as client: + self.assertNotIn("User-Agent", client.request.headers) + + async def test_compression_is_enabled(self): + """Client enables compression by default.""" + async with serve(*args) as server: + async with connect(get_uri(server)) as client: + self.assertEqual( + [type(ext) for ext in client.protocol.extensions], + [PerMessageDeflate], + ) + + async def test_disable_compression(self): + """Client disables compression.""" + async with serve(*args) as server: + async with connect(get_uri(server), compression=None) as client: + self.assertEqual(client.protocol.extensions, []) + + async def test_keepalive_is_enabled(self): + """Client enables keepalive and measures latency by default.""" + async with serve(*args) as server: + async with connect(get_uri(server), ping_interval=MS) as client: + self.assertEqual(client.latency, 0) + await asyncio.sleep(2 * MS) + self.assertGreater(client.latency, 0) + + async def test_disable_keepalive(self): + """Client disables keepalive.""" + async with serve(*args) as server: + async with connect(get_uri(server), ping_interval=None) as client: + await asyncio.sleep(2 * MS) + self.assertEqual(client.latency, 0) + + async def test_logger(self): + """Client accepts a logger argument.""" + logger = logging.getLogger("test") + async with serve(*args) as server: + async with connect(get_uri(server), logger=logger) as client: + self.assertEqual(client.logger.name, logger.name) + + async def test_custom_connection_factory(self): + """Client runs ClientConnection factory provided in create_connection.""" + + def create_connection(*args, **kwargs): + client = ClientConnection(*args, **kwargs) + client.create_connection_ran = True + return client + + async with serve(*args) as server: + async with connect( + get_uri(server), create_connection=create_connection + ) as client: + self.assertTrue(client.create_connection_ran) + + async def test_reconnect(self): """Client reconnects to server.""" iterations = 0 successful = 0 @@ -76,7 +181,7 @@ def process_request(connection, request): hasattr(http.HTTPStatus, "IM_A_TEAPOT"), "test requires Python 3.9", ) - async def test_reconnection_with_custom_process_exception(self): + async def test_reconnect_with_custom_process_exception(self): """Client runs process_exception to tell if errors are retryable or fatal.""" iteration = 0 @@ -113,7 +218,7 @@ def process_exception(exc): hasattr(http.HTTPStatus, "IM_A_TEAPOT"), "test requires Python 3.9", ) - async def test_reconnection_with_custom_process_exception_raising_exception(self): + async def test_reconnect_with_custom_process_exception_raising_exception(self): """Client supports raising an exception in process_exception.""" def process_request(connection, request): @@ -137,84 +242,107 @@ def process_exception(exc): "🫖 💔 ☕️", ) - async def test_existing_socket(self): - """Client connects using a pre-existing socket.""" - async with serve(*args) as server: - with socket.create_connection(get_host_port(server)) as sock: - # Use a non-existing domain to ensure we connect to the right socket. - async with connect("ws://invalid/", sock=sock) as client: - self.assertEqual(client.protocol.state.name, "OPEN") + async def test_redirect(self): + """Client follows redirect.""" - async def test_additional_headers(self): - """Client can set additional headers with additional_headers.""" - async with serve(*args) as server: - async with connect( - get_uri(server), additional_headers={"Authorization": "Bearer ..."} - ) as client: - self.assertEqual(client.request.headers["Authorization"], "Bearer ...") + def redirect(connection, request): + if request.path == "/redirect": + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response - async def test_override_user_agent(self): - """Client can override User-Agent header with user_agent_header.""" - async with serve(*args) as server: - async with connect(get_uri(server), user_agent_header="Smith") as client: - self.assertEqual(client.request.headers["User-Agent"], "Smith") + async with serve(*args, process_request=redirect) as server: + async with connect(get_uri(server) + "/redirect") as client: + self.assertEqual(client.protocol.wsuri.path, "/") - async def test_remove_user_agent(self): - """Client can remove User-Agent header with user_agent_header.""" - async with serve(*args) as server: - async with connect(get_uri(server), user_agent_header=None) as client: - self.assertNotIn("User-Agent", client.request.headers) + async def test_cross_origin_redirect(self): + """Client follows redirect to a secure URI on a different origin.""" - async def test_compression_is_enabled(self): - """Client enables compression by default.""" - async with serve(*args) as server: - async with connect(get_uri(server)) as client: - self.assertEqual( - [type(ext) for ext in client.protocol.extensions], - [PerMessageDeflate], - ) + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = get_uri(other_server) + return response - async def test_disable_compression(self): - """Client disables compression.""" - async with serve(*args) as server: - async with connect(get_uri(server), compression=None) as client: - self.assertEqual(client.protocol.extensions, []) + async with serve(*args, process_request=redirect) as server: + async with serve(*args) as other_server: + async with connect(get_uri(server)): + self.assertFalse(server.connections) + self.assertTrue(other_server.connections) - async def test_keepalive_is_enabled(self): - """Client enables keepalive and measures latency by default.""" - async with serve(*args) as server: - async with connect(get_uri(server), ping_interval=MS) as client: - self.assertEqual(client.latency, 0) - await asyncio.sleep(2 * MS) - self.assertGreater(client.latency, 0) + async def test_redirect_limit(self): + """Client stops following redirects after limit is reached.""" - async def test_disable_keepalive(self): - """Client disables keepalive.""" - async with serve(*args) as server: - async with connect(get_uri(server), ping_interval=None) as client: - await asyncio.sleep(2 * MS) - self.assertEqual(client.latency, 0) + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = request.path + return response - async def test_logger(self): - """Client accepts a logger argument.""" - logger = logging.getLogger("test") - async with serve(*args) as server: - async with connect(get_uri(server), logger=logger) as client: - self.assertEqual(client.logger.name, logger.name) + async with serve(*args, process_request=redirect) as server: + async with few_redirects(): + with self.assertRaises(SecurityError) as raised: + async with connect(get_uri(server)): + self.fail("did not raise") - async def test_custom_connection_factory(self): - """Client runs ClientConnection factory provided in create_connection.""" + self.assertEqual( + str(raised.exception), + "more than 2 redirects", + ) - def create_connection(*args, **kwargs): - client = ClientConnection(*args, **kwargs) - client.create_connection_ran = True - return client + async def test_redirect_with_explicit_host_port(self): + """Client follows redirect with an explicit host / port.""" - async with serve(*args) as server: + def redirect(connection, request): + if request.path == "/redirect": + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with serve(*args, process_request=redirect) as server: + host, port = get_host_port(server) async with connect( - get_uri(server), create_connection=create_connection + "ws://overridden/redirect", host=host, port=port ) as client: - self.assertTrue(client.create_connection_ran) + self.assertEqual(client.protocol.wsuri.path, "/") + + async def test_cross_origin_redirect_with_explicit_host_port(self): + """Client doesn't follow cross-origin redirect with an explicit host / port.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "ws://other/" + return response + + async with serve(*args, process_request=redirect) as server: + host, port = get_host_port(server) + with self.assertRaises(ValueError) as raised: + async with connect("ws://overridden/", host=host, port=port): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow cross-origin redirect to ws://other/ " + "with an explicit host or port", + ) + + async def test_redirect_with_existing_socket(self): + """Client doesn't follow redirect when using a pre-existing socket.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "/" + return response + + async with serve(*args, process_request=redirect) as server: + with socket.create_connection(get_host_port(server)) as sock: + with self.assertRaises(ValueError) as raised: + # Use a non-existing domain to ensure we connect to sock. + async with connect("ws://invalid/redirect", sock=sock): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow redirect to ws://invalid/ with a preexisting socket", + ) async def test_invalid_uri(self): """Client receives an invalid URI.""" @@ -336,6 +464,40 @@ async def test_reject_invalid_server_hostname(self): str(raised.exception), ) + async def test_cross_origin_redirect(self): + """Client follows redirect to a secure URI on a different origin.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = get_uri(other_server) + return response + + async with serve(*args, ssl=SERVER_CONTEXT, process_request=redirect) as server: + async with serve(*args, ssl=SERVER_CONTEXT) as other_server: + async with connect(get_uri(server), ssl=CLIENT_CONTEXT): + self.assertFalse(server.connections) + self.assertTrue(other_server.connections) + + async def test_redirect_to_insecure_uri(self): + """Client doesn't follow redirect from secure URI to non-secure URI.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = insecure_uri + return response + + async with serve(*args, ssl=SERVER_CONTEXT, process_request=redirect) as server: + with self.assertRaises(SecurityError) as raised: + secure_uri = get_uri(server) + insecure_uri = secure_uri.replace("wss://", "ws://") + async with connect(secure_uri, ssl=CLIENT_CONTEXT): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + f"cannot follow redirect to non-secure URI {insecure_uri}", + ) + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class UnixClientTests(unittest.IsolatedAsyncioTestCase): @@ -354,6 +516,25 @@ async def test_set_host_header(self): async with unix_connect(path, uri="ws://overridden/") as client: self.assertEqual(client.request.headers["Host"], "overridden") + async def test_cross_origin_redirect(self): + """Client doesn't follows redirect to a URI on a different origin.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "ws://other/" + return response + + with temp_unix_socket_path() as path: + async with unix_serve(handler, path, process_request=redirect): + with self.assertRaises(ValueError) as raised: + async with unix_connect(path): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "cannot follow cross-origin redirect to ws://other/ with a Unix socket", + ) + @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets") class SecureUnixClientTests(unittest.IsolatedAsyncioTestCase):