Skip to content

Commit

Permalink
Fix connecting to npipe://, tcp://, and unix:// urls (#8632)
Browse files Browse the repository at this point in the history
Co-authored-by: Sam Bull <[email protected]>
  • Loading branch information
bdraco and Dreamsorcerer authored Aug 7, 2024
1 parent e0ff524 commit b2691f2
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGES/8632.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed connecting to ``npipe://``, ``tcp://``, and ``unix://`` urls -- by :user:`bdraco`.
17 changes: 10 additions & 7 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,13 @@
ClientWebSocketResponse,
ClientWSTimeout,
)
from .connector import BaseConnector, NamedPipeConnector, TCPConnector, UnixConnector
from .connector import (
HTTP_AND_EMPTY_SCHEMA_SET,
BaseConnector,
NamedPipeConnector,
TCPConnector,
UnixConnector,
)
from .cookiejar import CookieJar
from .helpers import (
_SENTINEL,
Expand Down Expand Up @@ -210,9 +216,6 @@ class ClientTimeout:

# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})
HTTP_SCHEMA_SET = frozenset({"http", "https", ""})
WS_SCHEMA_SET = frozenset({"ws", "wss"})
ALLOWED_PROTOCOL_SCHEMA_SET = HTTP_SCHEMA_SET | WS_SCHEMA_SET

_RetType = TypeVar("_RetType")
_CharsetResolver = Callable[[ClientResponse, bytes], str]
Expand Down Expand Up @@ -466,7 +469,8 @@ async def _request(
except ValueError as e:
raise InvalidUrlClientError(str_or_url) from e

if url.scheme not in ALLOWED_PROTOCOL_SCHEMA_SET:
assert self._connector is not None
if url.scheme not in self._connector.allowed_protocol_schema_set:
raise NonHttpUrlClientError(url)

skip_headers = set(self._skip_auto_headers)
Expand Down Expand Up @@ -597,7 +601,6 @@ async def _request(
real_timeout.connect,
ceil_threshold=real_timeout.ceil_threshold,
):
assert self._connector is not None
conn = await self._connector.connect(
req, traces=traces, timeout=real_timeout
)
Expand Down Expand Up @@ -693,7 +696,7 @@ async def _request(
) from e

scheme = parsed_redirect_url.scheme
if scheme not in HTTP_SCHEMA_SET:
if scheme not in HTTP_AND_EMPTY_SCHEMA_SET:
resp.close()
raise NonHttpUrlRedirectClientError(r_url)
elif not scheme:
Expand Down
16 changes: 16 additions & 0 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@
SSLContext = object # type: ignore[misc,assignment]


EMPTY_SCHEMA_SET = frozenset({""})
HTTP_SCHEMA_SET = frozenset({"http", "https"})
WS_SCHEMA_SET = frozenset({"ws", "wss"})

HTTP_AND_EMPTY_SCHEMA_SET = HTTP_SCHEMA_SET | EMPTY_SCHEMA_SET
HIGH_LEVEL_SCHEMA_SET = HTTP_AND_EMPTY_SCHEMA_SET | WS_SCHEMA_SET


__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")


Expand Down Expand Up @@ -190,6 +198,8 @@ class BaseConnector:
# abort transport after 2 seconds (cleanup broken connections)
_cleanup_closed_period = 2.0

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET

def __init__(
self,
*,
Expand Down Expand Up @@ -741,6 +751,8 @@ class TCPConnector(BaseConnector):
loop - Optional event loop.
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})

def __init__(
self,
*,
Expand Down Expand Up @@ -1342,6 +1354,8 @@ class UnixConnector(BaseConnector):
loop - Optional event loop.
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"unix"})

def __init__(
self,
path: str,
Expand Down Expand Up @@ -1396,6 +1410,8 @@ class NamedPipeConnector(BaseConnector):
loop - Optional event loop.
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"npipe"})

def __init__(
self,
path: str,
Expand Down
71 changes: 67 additions & 4 deletions tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from aiohttp.client import ClientSession
from aiohttp.client_proto import ResponseHandler
from aiohttp.client_reqrep import ClientRequest, ConnectionKey
from aiohttp.connector import BaseConnector, Connection, TCPConnector
from aiohttp.connector import BaseConnector, Connection, TCPConnector, UnixConnector
from aiohttp.pytest_plugin import AiohttpClient
from aiohttp.test_utils import make_mocked_coro
from aiohttp.tracing import Trace
Expand Down Expand Up @@ -536,15 +536,78 @@ async def test_ws_connect_allowed_protocols(
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.url = URL(f"{protocol}://example.com")
resp.url = URL(f"{protocol}://example")
resp.cookies = SimpleCookie()
resp.start = mock.AsyncMock()

req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
req_factory = mock.Mock(return_value=req)
req.send = mock.AsyncMock(return_value=resp)
# BaseConnector allows all high level protocols by default
connector = BaseConnector()

session = await create_session(request_class=req_factory)
session = await create_session(connector=connector, request_class=req_factory)

connections = []
assert session._connector is not None
original_connect = session._connector.connect

async def connect(
req: ClientRequest, traces: List[Trace], timeout: aiohttp.ClientTimeout
) -> Connection:
conn = await original_connect(req, traces, timeout)
connections.append(conn)
return conn

async def create_connection(
req: object, traces: object, timeout: object
) -> ResponseHandler:
return create_mocked_conn()

connector = session._connector
with mock.patch.object(connector, "connect", connect), mock.patch.object(
connector, "_create_connection", create_connection
), mock.patch.object(connector, "_release"), mock.patch(
"aiohttp.client.os"
) as m_os:
m_os.urandom.return_value = key_data
await session.ws_connect(f"{protocol}://example")

# normally called during garbage collection. triggers an exception
# if the connection wasn't already closed
for c in connections:
c.close()
c.__del__()

await session.close()


@pytest.mark.parametrize("protocol", ["http", "https", "ws", "wss", "unix"])
async def test_ws_connect_unix_socket_allowed_protocols(
create_session: Callable[..., Awaitable[ClientSession]],
create_mocked_conn: Callable[[], ResponseHandler],
protocol: str,
ws_key: bytes,
key_data: bytes,
) -> None:
resp = mock.create_autospec(aiohttp.ClientResponse)
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.url = URL(f"{protocol}://example")
resp.cookies = SimpleCookie()
resp.start = mock.AsyncMock()

req = mock.create_autospec(aiohttp.ClientRequest, spec_set=True)
req_factory = mock.Mock(return_value=req)
req.send = mock.AsyncMock(return_value=resp)
# UnixConnector allows all high level protocols by default and unix sockets
session = await create_session(
connector=UnixConnector(path=""), request_class=req_factory
)

connections = []
assert session._connector is not None
Expand All @@ -569,7 +632,7 @@ async def create_connection(
"aiohttp.client.os"
) as m_os:
m_os.urandom.return_value = key_data
await session.ws_connect(f"{protocol}://example.com")
await session.ws_connect(f"{protocol}://example")

# normally called during garbage collection. triggers an exception
# if the connection wasn't already closed
Expand Down
34 changes: 34 additions & 0 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1636,6 +1636,11 @@ async def test_tcp_connector_ctor(loop: asyncio.AbstractEventLoop) -> None:
assert conn.family == 0


async def test_tcp_connector_allowed_protocols(loop: asyncio.AbstractEventLoop) -> None:
conn = aiohttp.TCPConnector()
assert conn.allowed_protocol_schema_set == {"", "tcp", "http", "https", "ws", "wss"}


async def test_invalid_ssl_param() -> None:
with pytest.raises(TypeError):
aiohttp.TCPConnector(ssl=object()) # type: ignore[arg-type]
Expand Down Expand Up @@ -1819,6 +1824,19 @@ async def test_ctor_with_default_loop(loop: asyncio.AbstractEventLoop) -> None:
assert loop is conn._loop


async def test_base_connector_allows_high_level_protocols(
loop: asyncio.AbstractEventLoop,
) -> None:
conn = aiohttp.BaseConnector()
assert conn.allowed_protocol_schema_set == {
"",
"http",
"https",
"ws",
"wss",
}


async def test_connect_with_limit(
loop: asyncio.AbstractEventLoop, key: ConnectionKey
) -> None:
Expand Down Expand Up @@ -2621,6 +2639,14 @@ async def handler(request: web.Request) -> web.Response:

connector = aiohttp.UnixConnector(unix_sockname)
assert unix_sockname == connector.path
assert connector.allowed_protocol_schema_set == {
"",
"http",
"https",
"ws",
"wss",
"unix",
}

session = ClientSession(connector=connector)
r = await session.get(url)
Expand Down Expand Up @@ -2648,6 +2674,14 @@ async def handler(request: web.Request) -> web.Response:

connector = aiohttp.NamedPipeConnector(pipe_name)
assert pipe_name == connector.path
assert connector.allowed_protocol_schema_set == {
"",
"http",
"https",
"ws",
"wss",
"npipe",
}

session = ClientSession(connector=connector)
r = await session.get(url)
Expand Down

0 comments on commit b2691f2

Please sign in to comment.