Skip to content

Commit

Permalink
Fix typing of ssl parameter (#7335)
Browse files Browse the repository at this point in the history
`True` is not an allowed value.
  • Loading branch information
Dreamsorcerer authored Jul 4, 2023
1 parent 726fe21 commit cff007e
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGES/7335.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed annotation of ``ssl`` parameter to disallow ``True``. -- by :user:`Dreamsorcerer`
7 changes: 4 additions & 3 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Generic,
Iterable,
List,
Literal,
Mapping,
Optional,
Set,
Expand Down Expand Up @@ -344,7 +345,7 @@ async def _request(
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
timeout: Union[ClientTimeout, _SENTINEL, None] = sentinel,
ssl: Optional[Union[SSLContext, bool, Fingerprint]] = None,
ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None,
proxy_headers: Optional[LooseHeaders] = None,
trace_request_ctx: Optional[SimpleNamespace] = None,
read_bufsize: Optional[int] = None,
Expand Down Expand Up @@ -677,7 +678,7 @@ def ws_connect(
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, bool, None, Fingerprint] = None,
ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None,
proxy_headers: Optional[LooseHeaders] = None,
compress: int = 0,
max_msg_size: int = 4 * 1024 * 1024,
Expand Down Expand Up @@ -723,7 +724,7 @@ async def _ws_connect(
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, bool, None, Fingerprint] = None,
ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None,
proxy_headers: Optional[LooseHeaders] = None,
compress: int = 0,
max_msg_size: int = 4 * 1024 * 1024,
Expand Down
7 changes: 4 additions & 3 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
Tuple,
Expand Down Expand Up @@ -156,7 +157,7 @@ class ConnectionKey:
host: str
port: Optional[int]
is_ssl: bool
ssl: Union[SSLContext, None, bool, Fingerprint]
ssl: Union[SSLContext, None, Literal[False], Fingerprint]
proxy: Optional[URL]
proxy_auth: Optional[BasicAuth]
proxy_headers_hash: Optional[int] # hash(CIMultiDict)
Expand Down Expand Up @@ -210,7 +211,7 @@ def __init__(
proxy_auth: Optional[BasicAuth] = None,
timer: Optional[BaseTimerContext] = None,
session: Optional["ClientSession"] = None,
ssl: Union[SSLContext, bool, Fingerprint, None] = None,
ssl: Union[SSLContext, Literal[False], Fingerprint, None] = None,
proxy_headers: Optional[LooseHeaders] = None,
traces: Optional[List["Trace"]] = None,
trust_env: bool = False,
Expand Down Expand Up @@ -270,7 +271,7 @@ def is_ssl(self) -> bool:
return self.url.scheme in ("https", "wss")

@property
def ssl(self) -> Union["SSLContext", None, bool, Fingerprint]:
def ssl(self) -> Union["SSLContext", None, Literal[False], Fingerprint]:
return self._ssl

@property
Expand Down
27 changes: 14 additions & 13 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Dict,
Iterator,
List,
Literal,
Optional,
Set,
Tuple,
Expand Down Expand Up @@ -464,7 +465,7 @@ def _available_connections(self, key: "ConnectionKey") -> int:
return available

async def connect(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> Connection:
"""Get from pool or create new connection."""
key = req.connection_key
Expand Down Expand Up @@ -659,7 +660,7 @@ def _release(
)

async def _create_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> ResponseHandler:
raise NotImplementedError()

Expand Down Expand Up @@ -734,7 +735,7 @@ def __init__(
use_dns_cache: bool = True,
ttl_dns_cache: Optional[int] = 10,
family: int = 0,
ssl: Union[None, bool, Fingerprint, SSLContext] = None,
ssl: Union[None, Literal[False], Fingerprint, SSLContext] = None,
local_addr: Optional[Tuple[str, int]] = None,
resolver: Optional[AbstractResolver] = None,
keepalive_timeout: Union[None, float, _SENTINEL] = sentinel,
Expand Down Expand Up @@ -870,7 +871,7 @@ async def _resolve_host(
return self._cached_hosts.next_addrs(key)

async def _create_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> ResponseHandler:
"""Create connection.
Expand Down Expand Up @@ -906,7 +907,7 @@ def _make_ssl_context(verified: bool) -> SSLContext:
sslcontext.set_default_verify_paths()
return sslcontext

def _get_ssl_context(self, req: "ClientRequest") -> Optional[SSLContext]:
def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]:
"""Logic to get the correct SSL context
0. if req.ssl is false, return None
Expand Down Expand Up @@ -939,7 +940,7 @@ def _get_ssl_context(self, req: "ClientRequest") -> Optional[SSLContext]:
else:
return None

def _get_fingerprint(self, req: "ClientRequest") -> Optional["Fingerprint"]:
def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
ret = req.ssl
if isinstance(ret, Fingerprint):
return ret
Expand All @@ -951,7 +952,7 @@ def _get_fingerprint(self, req: "ClientRequest") -> Optional["Fingerprint"]:
async def _wrap_create_connection(
self,
*args: Any,
req: "ClientRequest",
req: ClientRequest,
timeout: "ClientTimeout",
client_error: Type[Exception] = ClientConnectorError,
**kwargs: Any,
Expand All @@ -973,7 +974,7 @@ async def _wrap_create_connection(
def _warn_about_tls_in_tls(
self,
underlying_transport: asyncio.Transport,
req: "ClientRequest",
req: ClientRequest,
) -> None:
"""Issue a warning if the requested URL has HTTPS scheme."""
if req.request_info.url.scheme != "https":
Expand Down Expand Up @@ -1010,7 +1011,7 @@ def _warn_about_tls_in_tls(
async def _start_tls_connection(
self,
underlying_transport: asyncio.Transport,
req: "ClientRequest",
req: ClientRequest,
timeout: "ClientTimeout",
client_error: Type[Exception] = ClientConnectorError,
) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
Expand Down Expand Up @@ -1070,7 +1071,7 @@ async def _start_tls_connection(

async def _create_direct_connection(
self,
req: "ClientRequest",
req: ClientRequest,
traces: List["Trace"],
timeout: "ClientTimeout",
*,
Expand Down Expand Up @@ -1146,7 +1147,7 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
raise last_exc

async def _create_proxy_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> Tuple[asyncio.BaseTransport, ResponseHandler]:
headers: Dict[str, str] = {}
if req.proxy_headers is not None:
Expand Down Expand Up @@ -1285,7 +1286,7 @@ def path(self) -> str:
return self._path

async def _create_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> ResponseHandler:
try:
async with ceil_timeout(
Expand Down Expand Up @@ -1345,7 +1346,7 @@ def path(self) -> str:
return self._path

async def _create_connection(
self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
) -> ResponseHandler:
try:
async with ceil_timeout(
Expand Down

0 comments on commit cff007e

Please sign in to comment.