From b2691f211511dc476ce4fb77f2293e4082a8a357 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 7 Aug 2024 15:38:45 -0500 Subject: [PATCH] Fix connecting to npipe://, tcp://, and unix:// urls (#8632) Co-authored-by: Sam Bull --- CHANGES/8632.bugfix.rst | 1 + aiohttp/client.py | 17 +++++---- aiohttp/connector.py | 16 ++++++++ tests/test_client_session.py | 71 ++++++++++++++++++++++++++++++++++-- tests/test_connector.py | 34 +++++++++++++++++ 5 files changed, 128 insertions(+), 11 deletions(-) create mode 100644 CHANGES/8632.bugfix.rst diff --git a/CHANGES/8632.bugfix.rst b/CHANGES/8632.bugfix.rst new file mode 100644 index 00000000000..c6da81d7ab3 --- /dev/null +++ b/CHANGES/8632.bugfix.rst @@ -0,0 +1 @@ +Fixed connecting to ``npipe://``, ``tcp://``, and ``unix://`` urls -- by :user:`bdraco`. diff --git a/aiohttp/client.py b/aiohttp/client.py index 3256fa82bb2..299201a24c7 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -80,7 +80,13 @@ ClientWebSocketResponse, ClientWSTimeout, ) -from .connector import BaseConnector, NamedPipeConnector, TCPConnector, UnixConnector +from .connector import ( + HTTP_AND_EMPTY_SCHEMA_SET, + BaseConnector, + NamedPipeConnector, + TCPConnector, + UnixConnector, +) from .cookiejar import CookieJar from .helpers import ( _SENTINEL, @@ -210,9 +216,6 @@ class ClientTimeout: # https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2 IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"}) -HTTP_SCHEMA_SET = frozenset({"http", "https", ""}) -WS_SCHEMA_SET = frozenset({"ws", "wss"}) -ALLOWED_PROTOCOL_SCHEMA_SET = HTTP_SCHEMA_SET | WS_SCHEMA_SET _RetType = TypeVar("_RetType") _CharsetResolver = Callable[[ClientResponse, bytes], str] @@ -466,7 +469,8 @@ async def _request( except ValueError as e: raise InvalidUrlClientError(str_or_url) from e - if url.scheme not in ALLOWED_PROTOCOL_SCHEMA_SET: + assert self._connector is not None + if url.scheme not in self._connector.allowed_protocol_schema_set: raise NonHttpUrlClientError(url) skip_headers = set(self._skip_auto_headers) @@ -597,7 +601,6 @@ async def _request( real_timeout.connect, ceil_threshold=real_timeout.ceil_threshold, ): - assert self._connector is not None conn = await self._connector.connect( req, traces=traces, timeout=real_timeout ) @@ -693,7 +696,7 @@ async def _request( ) from e scheme = parsed_redirect_url.scheme - if scheme not in HTTP_SCHEMA_SET: + if scheme not in HTTP_AND_EMPTY_SCHEMA_SET: resp.close() raise NonHttpUrlRedirectClientError(r_url) elif not scheme: diff --git a/aiohttp/connector.py b/aiohttp/connector.py index fdb76c3e931..c86855a361f 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -64,6 +64,14 @@ SSLContext = object # type: ignore[misc,assignment] +EMPTY_SCHEMA_SET = frozenset({""}) +HTTP_SCHEMA_SET = frozenset({"http", "https"}) +WS_SCHEMA_SET = frozenset({"ws", "wss"}) + +HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET | EMPTY_SCHEMA_SET +HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET | WS_SCHEMA_SET + + __all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector") @@ -190,6 +198,8 @@ class BaseConnector: # abort transport after 2 seconds (cleanup broken connections) _cleanup_closed_period = 2.0 + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET + def __init__( self, *, @@ -741,6 +751,8 @@ class TCPConnector(BaseConnector): loop - Optional event loop. """ + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) + def __init__( self, *, @@ -1342,6 +1354,8 @@ class UnixConnector(BaseConnector): loop - Optional event loop. """ + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"}) + def __init__( self, path: str, @@ -1396,6 +1410,8 @@ class NamedPipeConnector(BaseConnector): loop - Optional event loop. """ + allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"}) + def __init__( self, path: str, diff --git a/tests/test_client_session.py b/tests/test_client_session.py index f7506c880d8..e90f4266503 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -28,7 +28,7 @@ from aiohttp.client import ClientSession from aiohttp.client_proto import ResponseHandler from aiohttp.client_reqrep import ClientRequest, ConnectionKey -from aiohttp.connector import BaseConnector, Connection, TCPConnector +from aiohttp.connector import BaseConnector, Connection, TCPConnector, UnixConnector from aiohttp.pytest_plugin import AiohttpClient from aiohttp.test_utils import make_mocked_coro from aiohttp.tracing import Trace @@ -536,15 +536,78 @@ async def test_ws_connect_allowed_protocols( hdrs.CONNECTION: "upgrade", hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, } - resp.url = URL(f"{protocol}://example.com") + resp.url = URL(f"{protocol}://example") resp.cookies = SimpleCookie() resp.start = mock.AsyncMock() req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) req_factory = mock.Mock(return_value=req) req.send = mock.AsyncMock(return_value=resp) + # BaseConnector allows all high level protocols by default + connector = BaseConnector() - session = await create_session(request_class=req_factory) + session = await create_session(connector=connector, request_class=req_factory) + + connections = [] + assert session._connector is not None + original_connect = session._connector.connect + + async def connect( + req: ClientRequest, traces: List[Trace], timeout: aiohttp.ClientTimeout + ) -> Connection: + conn = await original_connect(req, traces, timeout) + connections.append(conn) + return conn + + async def create_connection( + req: object, traces: object, timeout: object + ) -> ResponseHandler: + return create_mocked_conn() + + connector = session._connector + with mock.patch.object(connector, "connect", connect), mock.patch.object( + connector, "_create_connection", create_connection + ), mock.patch.object(connector, "_release"), mock.patch( + "aiohttp.client.os" + ) as m_os: + m_os.urandom.return_value = key_data + await session.ws_connect(f"{protocol}://example") + + # normally called during garbage collection. triggers an exception + # if the connection wasn't already closed + for c in connections: + c.close() + c.__del__() + + await session.close() + + +@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss", "unix"]) +async def test_ws_connect_unix_socket_allowed_protocols( + create_session: Callable[..., Awaitable[ClientSession]], + create_mocked_conn: Callable[[], ResponseHandler], + protocol: str, + ws_key: bytes, + key_data: bytes, +) -> None: + resp = mock.create_autospec(aiohttp.ClientResponse) + resp.status = 101 + resp.headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: ws_key, + } + resp.url = URL(f"{protocol}://example") + resp.cookies = SimpleCookie() + resp.start = mock.AsyncMock() + + req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True) + req_factory = mock.Mock(return_value=req) + req.send = mock.AsyncMock(return_value=resp) + # UnixConnector allows all high level protocols by default and unix sockets + session = await create_session( + connector=UnixConnector(path=""), request_class=req_factory + ) connections = [] assert session._connector is not None @@ -569,7 +632,7 @@ async def create_connection( "aiohttp.client.os" ) as m_os: m_os.urandom.return_value = key_data - await session.ws_connect(f"{protocol}://example.com") + await session.ws_connect(f"{protocol}://example") # normally called during garbage collection. triggers an exception # if the connection wasn't already closed diff --git a/tests/test_connector.py b/tests/test_connector.py index 2087e0d0cbd..335e2a1ebc0 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1636,6 +1636,11 @@ async def test_tcp_connector_ctor(loop: asyncio.AbstractEventLoop) -> None: assert conn.family == 0 +async def test_tcp_connector_allowed_protocols(loop: asyncio.AbstractEventLoop) -> None: + conn = aiohttp.TCPConnector() + assert conn.allowed_protocol_schema_set == {"", "tcp", "http", "https", "ws", "wss"} + + async def test_invalid_ssl_param() -> None: with pytest.raises(TypeError): aiohttp.TCPConnector(ssl=object()) # type: ignore[arg-type] @@ -1819,6 +1824,19 @@ async def test_ctor_with_default_loop(loop: asyncio.AbstractEventLoop) -> None: assert loop is conn._loop +async def test_base_connector_allows_high_level_protocols( + loop: asyncio.AbstractEventLoop, +) -> None: + conn = aiohttp.BaseConnector() + assert conn.allowed_protocol_schema_set == { + "", + "http", + "https", + "ws", + "wss", + } + + async def test_connect_with_limit( loop: asyncio.AbstractEventLoop, key: ConnectionKey ) -> None: @@ -2621,6 +2639,14 @@ async def handler(request: web.Request) -> web.Response: connector = aiohttp.UnixConnector(unix_sockname) assert unix_sockname == connector.path + assert connector.allowed_protocol_schema_set == { + "", + "http", + "https", + "ws", + "wss", + "unix", + } session = ClientSession(connector=connector) r = await session.get(url) @@ -2648,6 +2674,14 @@ async def handler(request: web.Request) -> web.Response: connector = aiohttp.NamedPipeConnector(pipe_name) assert pipe_name == connector.path + assert connector.allowed_protocol_schema_set == { + "", + "http", + "https", + "ws", + "wss", + "npipe", + } session = ClientSession(connector=connector) r = await session.get(url)