From cff007e2565528ff3221a85b548d8f8bdbe011c4 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Tue, 4 Jul 2023 15:00:41 +0100 Subject: [PATCH] Fix typing of ssl parameter (#7335) `True` is not an allowed value. --- CHANGES/7335.misc | 1 + aiohttp/client.py | 7 ++++--- aiohttp/client_reqrep.py | 7 ++++--- aiohttp/connector.py | 27 ++++++++++++++------------- 4 files changed, 23 insertions(+), 19 deletions(-) create mode 100644 CHANGES/7335.misc diff --git a/CHANGES/7335.misc b/CHANGES/7335.misc new file mode 100644 index 00000000000..9ccad2ed9d5 --- /dev/null +++ b/CHANGES/7335.misc @@ -0,0 +1 @@ +Fixed annotation of ``ssl`` parameter to disallow ``True``. -- by :user:`Dreamsorcerer` diff --git a/aiohttp/client.py b/aiohttp/client.py index f7cccdc24ca..9050cc4120c 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -22,6 +22,7 @@ Generic, Iterable, List, + Literal, Mapping, Optional, Set, @@ -344,7 +345,7 @@ async def _request( proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, timeout: Union[ClientTimeout, _SENTINEL, None] = sentinel, - ssl: Optional[Union[SSLContext, bool, Fingerprint]] = None, + ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, proxy_headers: Optional[LooseHeaders] = None, trace_request_ctx: Optional[SimpleNamespace] = None, read_bufsize: Optional[int] = None, @@ -677,7 +678,7 @@ def ws_connect( headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, - ssl: Union[SSLContext, bool, None, Fingerprint] = None, + ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None, proxy_headers: Optional[LooseHeaders] = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, @@ -723,7 +724,7 @@ async def _ws_connect( headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, - ssl: Union[SSLContext, bool, None, Fingerprint] = None, + ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None, proxy_headers: Optional[LooseHeaders] = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index edaf7087fc5..3d7a90d6d2d 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -17,6 +17,7 @@ Dict, Iterable, List, + Literal, Mapping, Optional, Tuple, @@ -156,7 +157,7 @@ class ConnectionKey: host: str port: Optional[int] is_ssl: bool - ssl: Union[SSLContext, None, bool, Fingerprint] + ssl: Union[SSLContext, None, Literal[False], Fingerprint] proxy: Optional[URL] proxy_auth: Optional[BasicAuth] proxy_headers_hash: Optional[int] # hash(CIMultiDict) @@ -210,7 +211,7 @@ def __init__( proxy_auth: Optional[BasicAuth] = None, timer: Optional[BaseTimerContext] = None, session: Optional["ClientSession"] = None, - ssl: Union[SSLContext, bool, Fingerprint, None] = None, + ssl: Union[SSLContext, Literal[False], Fingerprint, None] = None, proxy_headers: Optional[LooseHeaders] = None, traces: Optional[List["Trace"]] = None, trust_env: bool = False, @@ -270,7 +271,7 @@ def is_ssl(self) -> bool: return self.url.scheme in ("https", "wss") @property - def ssl(self) -> Union["SSLContext", None, bool, Fingerprint]: + def ssl(self) -> Union["SSLContext", None, Literal[False], Fingerprint]: return self._ssl @property diff --git a/aiohttp/connector.py b/aiohttp/connector.py index f523787f16c..ff54a2beadc 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -22,6 +22,7 @@ Dict, Iterator, List, + Literal, Optional, Set, Tuple, @@ -464,7 +465,7 @@ def _available_connections(self, key: "ConnectionKey") -> int: return available async def connect( - self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> Connection: """Get from pool or create new connection.""" key = req.connection_key @@ -659,7 +660,7 @@ def _release( ) async def _create_connection( - self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: raise NotImplementedError() @@ -734,7 +735,7 @@ def __init__( use_dns_cache: bool = True, ttl_dns_cache: Optional[int] = 10, family: int = 0, - ssl: Union[None, bool, Fingerprint, SSLContext] = None, + ssl: Union[None, Literal[False], Fingerprint, SSLContext] = None, local_addr: Optional[Tuple[str, int]] = None, resolver: Optional[AbstractResolver] = None, keepalive_timeout: Union[None, float, _SENTINEL] = sentinel, @@ -870,7 +871,7 @@ async def _resolve_host( return self._cached_hosts.next_addrs(key) async def _create_connection( - self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: """Create connection. @@ -906,7 +907,7 @@ def _make_ssl_context(verified: bool) -> SSLContext: sslcontext.set_default_verify_paths() return sslcontext - def _get_ssl_context(self, req: "ClientRequest") -> Optional[SSLContext]: + def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: """Logic to get the correct SSL context 0. if req.ssl is false, return None @@ -939,7 +940,7 @@ def _get_ssl_context(self, req: "ClientRequest") -> Optional[SSLContext]: else: return None - def _get_fingerprint(self, req: "ClientRequest") -> Optional["Fingerprint"]: + def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: ret = req.ssl if isinstance(ret, Fingerprint): return ret @@ -951,7 +952,7 @@ def _get_fingerprint(self, req: "ClientRequest") -> Optional["Fingerprint"]: async def _wrap_create_connection( self, *args: Any, - req: "ClientRequest", + req: ClientRequest, timeout: "ClientTimeout", client_error: Type[Exception] = ClientConnectorError, **kwargs: Any, @@ -973,7 +974,7 @@ async def _wrap_create_connection( def _warn_about_tls_in_tls( self, underlying_transport: asyncio.Transport, - req: "ClientRequest", + req: ClientRequest, ) -> None: """Issue a warning if the requested URL has HTTPS scheme.""" if req.request_info.url.scheme != "https": @@ -1010,7 +1011,7 @@ def _warn_about_tls_in_tls( async def _start_tls_connection( self, underlying_transport: asyncio.Transport, - req: "ClientRequest", + req: ClientRequest, timeout: "ClientTimeout", client_error: Type[Exception] = ClientConnectorError, ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: @@ -1070,7 +1071,7 @@ async def _start_tls_connection( async def _create_direct_connection( self, - req: "ClientRequest", + req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout", *, @@ -1146,7 +1147,7 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: raise last_exc async def _create_proxy_connection( - self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: headers: Dict[str, str] = {} if req.proxy_headers is not None: @@ -1285,7 +1286,7 @@ def path(self) -> str: return self._path async def _create_connection( - self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: try: async with ceil_timeout( @@ -1345,7 +1346,7 @@ def path(self) -> str: return self._path async def _create_connection( - self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: try: async with ceil_timeout(