Skip to content

Commit

Permalink
Merge pull request python-trio#3098 from CoolCat467/enable-flake8-ann…
Browse files Browse the repository at this point in the history
…otations

Enable flake8 annotations
  • Loading branch information
CoolCat467 authored Nov 5, 2024
2 parents c7801ae + cf7df04 commit 0ee5a69
Show file tree
Hide file tree
Showing 23 changed files with 128 additions and 75 deletions.
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ allowed-confusables = ["–"]

select = [
"A", # flake8-builtins
"ANN", # flake8-annotations
"ASYNC", # flake8-async
"B", # flake8-bugbear
"C4", # flake8-comprehensions
Expand All @@ -131,6 +132,9 @@ select = [
]
extend-ignore = [
'A002', # builtin-argument-shadowing
'ANN101', # missing-type-self
'ANN102', # missing-type-cls
'ANN401', # any-type (mypy's `disallow_any_explicit` is better)
'E402', # module-import-not-at-top-of-file (usually OS-specific)
'E501', # line-too-long
'F403', # undefined-local-with-import-star
Expand Down Expand Up @@ -160,6 +164,8 @@ extend-ignore = [
'src/trio/_abc.py' = ['A005']
'src/trio/_socket.py' = ['A005']
'src/trio/_ssl.py' = ['A005']
# Don't check annotations in notes-to-self
'notes-to-self/*.py' = ['ANN001', 'ANN002', 'ANN003', 'ANN201', 'ANN202', 'ANN204']

[tool.ruff.lint.isort]
combine-as-imports = true
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __new__( # type: ignore[misc] # "must return a subtype"
) -> tuple[MemorySendChannel[T], MemoryReceiveChannel[T]]:
return _open_memory_channel(max_buffer_size)

def __init__(self, max_buffer_size: int | float): # noqa: PYI041
def __init__(self, max_buffer_size: int | float) -> None: # noqa: PYI041
...

else:
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_core/_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Instruments(dict[str, dict[Instrument, None]]):

__slots__ = ()

def __init__(self, incoming: Sequence[Instrument]):
def __init__(self, incoming: Sequence[Instrument]) -> None:
self["_all"] = {}
for instrument in incoming:
self.add_instrument(instrument)
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_core/_mock_clock.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class MockClock(Clock):
"""

def __init__(self, rate: float = 0.0, autojump_threshold: float = inf):
def __init__(self, rate: float = 0.0, autojump_threshold: float = inf) -> None:
# when the real clock said 'real_base', the virtual time was
# 'virtual_base', and since then it's advanced at 'rate' virtual
# seconds per real second.
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_core/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,7 @@ def __init__(
parent_task: Task,
cancel_scope: CancelScope,
strict_exception_groups: bool,
):
) -> None:
self._parent_task = parent_task
self._strict_exception_groups = strict_exception_groups
parent_task._child_nurseries.append(self)
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_dtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ def challenge_for(


class _Queue(Generic[_T]):
def __init__(self, incoming_packets_buffer: int | float): # noqa: PYI041
def __init__(self, incoming_packets_buffer: int | float) -> None: # noqa: PYI041
self.s, self.r = trio.open_memory_channel[_T](incoming_packets_buffer)


Expand Down
7 changes: 6 additions & 1 deletion src/trio/_file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
)
from typing_extensions import Literal

from ._sync import CapacityLimiter

# This list is also in the docs, make sure to keep them in sync
_FILE_SYNC_ATTRS: set[str] = {
"closed",
Expand Down Expand Up @@ -241,7 +243,10 @@ def __getattr__(self, name: str) -> object:
meth = getattr(self._wrapped, name)

@async_wraps(self.__class__, self._wrapped.__class__, name)
async def wrapper(*args, **kwargs):
async def wrapper(
*args: Callable[..., T],
**kwargs: object | str | bool | CapacityLimiter | None,
) -> T:
func = partial(meth, *args, **kwargs)
return await trio.to_thread.run_sync(func)

Expand Down
4 changes: 2 additions & 2 deletions src/trio/_highlevel_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class SocketStream(HalfCloseableStream):
"""

def __init__(self, socket: SocketType):
def __init__(self, socket: SocketType) -> None:
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketStream requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
Expand Down Expand Up @@ -364,7 +364,7 @@ class SocketListener(Listener[SocketStream]):
"""

def __init__(self, socket: SocketType):
def __init__(self, socket: SocketType) -> None:
if not isinstance(socket, tsocket.SocketType):
raise TypeError("SocketListener requires a Trio socket object")
if socket.type != tsocket.SOCK_STREAM:
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TrioInteractiveConsole(InteractiveConsole):
# we make the type more specific on our subclass
locals: dict[str, object]

def __init__(self, repl_locals: dict[str, object] | None = None):
def __init__(self, repl_locals: dict[str, object] | None = None) -> None:
super().__init__(locals=repl_locals)
self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT

Expand Down
4 changes: 2 additions & 2 deletions src/trio/_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class _try_sync:
def __init__(
self,
blocking_exc_override: Callable[[BaseException], bool] | None = None,
):
) -> None:
self._blocking_exc_override = blocking_exc_override

def _is_blocking_io_error(self, exc: BaseException) -> bool:
Expand Down Expand Up @@ -782,7 +782,7 @@ async def sendmsg(


class _SocketType(SocketType):
def __init__(self, sock: _stdlib_socket.socket):
def __init__(self, sock: _stdlib_socket.socket) -> None:
if type(sock) is not _stdlib_socket.socket:
# For example, ssl.SSLSocket subclasses socket.socket, but we
# certainly don't want to blindly wrap one of those.
Expand Down
8 changes: 4 additions & 4 deletions src/trio/_subprocess_platform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ def create_pipe_from_child_output() -> tuple[ClosableReceiveStream, int]:

elif os.name == "posix":

def create_pipe_to_child_stdin():
def create_pipe_to_child_stdin() -> tuple[trio.lowlevel.FdStream, int]:
rfd, wfd = os.pipe()
return trio.lowlevel.FdStream(wfd), rfd

def create_pipe_from_child_output():
def create_pipe_from_child_output() -> tuple[trio.lowlevel.FdStream, int]:
rfd, wfd = os.pipe()
return trio.lowlevel.FdStream(rfd), wfd

Expand All @@ -106,12 +106,12 @@ def create_pipe_from_child_output():

from .._windows_pipes import PipeReceiveStream, PipeSendStream

def create_pipe_to_child_stdin():
def create_pipe_to_child_stdin() -> tuple[PipeSendStream, int]:
# for stdin, we want the write end (our end) to use overlapped I/O
rh, wh = windows_pipe(overlapped=(False, True))
return PipeSendStream(wh), msvcrt.open_osfhandle(rh, os.O_RDONLY)

def create_pipe_from_child_output():
def create_pipe_from_child_output() -> tuple[PipeReceiveStream, int]:
# for stdout/err, it's the read end that's overlapped
rh, wh = windows_pipe(overlapped=(True, False))
return PipeReceiveStream(rh), msvcrt.open_osfhandle(wh, 0)
Expand Down
6 changes: 3 additions & 3 deletions src/trio/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ class CapacityLimiter(AsyncContextManagerMixin):
"""

# total_tokens would ideally be int|Literal[math.inf] - but that's not valid typing
def __init__(self, total_tokens: int | float): # noqa: PYI041
def __init__(self, total_tokens: int | float) -> None: # noqa: PYI041
self._lot = ParkingLot()
self._borrowers: set[Task | object] = set()
# Maps tasks attempting to acquire -> borrower, to handle on-behalf-of
Expand Down Expand Up @@ -433,7 +433,7 @@ class Semaphore(AsyncContextManagerMixin):
"""

def __init__(self, initial_value: int, *, max_value: int | None = None):
def __init__(self, initial_value: int, *, max_value: int | None = None) -> None:
if not isinstance(initial_value, int):
raise TypeError("initial_value must be an int")
if initial_value < 0:
Expand Down Expand Up @@ -759,7 +759,7 @@ class Condition(AsyncContextManagerMixin):
"""

def __init__(self, lock: Lock | None = None):
def __init__(self, lock: Lock | None = None) -> None:
if lock is None:
lock = Lock()
if type(lock) is not Lock:
Expand Down
5 changes: 1 addition & 4 deletions src/trio/_tests/test_highlevel_open_tcp_listeners.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,7 @@ def setsockopt(
) -> None:
pass

async def bind(
self,
address: AddressFormat,
) -> None:
async def bind(self, address: AddressFormat) -> None:
pass

def listen(self, /, backlog: int = min(stdlib_socket.SOMAXCONN, 128)) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/trio/_tests/test_highlevel_ssl_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async def getnameinfo(
self,
sockaddr: tuple[str, int] | tuple[str, int, int, int],
flags: int,
) -> NoReturn:
) -> NoReturn: # pragma: no cover
raise NotImplementedError


Expand Down
15 changes: 11 additions & 4 deletions src/trio/_tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _frozenbind(
flags: int = 0,
) -> GetAddrInfoArgs:
sig = inspect.signature(self._orig_getaddrinfo)
bound = sig.bind(host, port, family, type, proto, flags)
bound = sig.bind(host, port, family=family, type=type, proto=proto, flags=flags)
bound.apply_defaults()
frozenbound = bound.args
assert not bound.kwargs
Expand All @@ -95,9 +95,16 @@ def set(
proto: int = 0,
flags: int = 0,
) -> None:
self._responses[self._frozenbind(host, port, family, type, proto, flags)] = (
response
)
self._responses[
self._frozenbind(
host,
port,
family=family,
type=type,
proto=proto,
flags=flags,
)
] = response

def getaddrinfo(
self,
Expand Down
63 changes: 41 additions & 22 deletions src/trio/_tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ def ssl_echo_serve_sync(
# Fixture that gives a raw socket connected to a trio-test-1 echo server
# (running in a thread). Useful for testing making connections with different
# SSLContexts.
@asynccontextmanager # type: ignore[misc] # decorated contains Any
async def ssl_echo_server_raw(**kwargs: Any) -> AsyncIterator[SocketStream]:
@asynccontextmanager
async def ssl_echo_server_raw(expect_fail: bool = False) -> AsyncIterator[SocketStream]:
a, b = stdlib_socket.socketpair()
async with trio.open_nursery() as nursery:
# Exiting the 'with a, b' context manager closes the sockets, which
Expand All @@ -178,20 +178,20 @@ async def ssl_echo_server_raw(**kwargs: Any) -> AsyncIterator[SocketStream]:
with a, b:
nursery.start_soon(
trio.to_thread.run_sync,
partial(ssl_echo_serve_sync, b, **kwargs),
partial(ssl_echo_serve_sync, b, expect_fail=expect_fail),
)

yield SocketStream(tsocket.from_stdlib_socket(a))


# Fixture that gives a properly set up SSLStream connected to a trio-test-1
# echo server (running in a thread)
@asynccontextmanager # type: ignore[misc] # decorated contains Any
@asynccontextmanager
async def ssl_echo_server(
client_ctx: SSLContext,
**kwargs: Any,
expect_fail: bool = False,
) -> AsyncIterator[SSLStream[Stream]]:
async with ssl_echo_server_raw(**kwargs) as sock:
async with ssl_echo_server_raw(expect_fail=expect_fail) as sock:
yield SSLStream(sock, client_ctx, server_hostname="trio-test-1.example.org")


Expand All @@ -201,7 +201,10 @@ async def ssl_echo_server(
# jakkdl: it seems to implement all the abstract methods (now), so I made it inherit
# from Stream for the sake of typechecking.
class PyOpenSSLEchoStream(Stream):
def __init__(self, sleeper: None = None) -> None:
def __init__(
self,
sleeper: Callable[[str], Awaitable[None]] | None = None,
) -> None:
ctx = SSL.Context(SSL.SSLv23_METHOD)
# TLS 1.3 removes renegotiation support. Which is great for them, but
# we still have to support versions before that, and that means we
Expand Down Expand Up @@ -249,6 +252,7 @@ def __init__(self, sleeper: None = None) -> None:
"simultaneous calls to PyOpenSSLEchoStream.receive_some",
)

self.sleeper: Callable[[str], Awaitable[None]]
if sleeper is None:

async def no_op_sleeper(_: object) -> None:
Expand Down Expand Up @@ -384,12 +388,12 @@ async def do_test(
await do_test("receive_some", (1,), "receive_some", (1,))


@contextmanager # type: ignore[misc] # decorated contains Any
@contextmanager
def virtual_ssl_echo_server(
client_ctx: SSLContext,
**kwargs: Any,
sleeper: Callable[[str], Awaitable[None]] | None = None,
) -> Iterator[SSLStream[PyOpenSSLEchoStream]]:
fakesock = PyOpenSSLEchoStream(**kwargs)
fakesock = PyOpenSSLEchoStream(sleeper=sleeper)
yield SSLStream(fakesock, client_ctx, server_hostname="trio-test-1.example.org")


Expand Down Expand Up @@ -424,31 +428,43 @@ def ssl_wrap_pair( # type: ignore[misc]
MemoryStapledStream: TypeAlias = StapledStream[MemorySendStream, MemoryReceiveStream]


# Explicit "Any" is not allowed
def ssl_memory_stream_pair( # type: ignore[misc]
def ssl_memory_stream_pair(
client_ctx: SSLContext,
**kwargs: Any,
client_kwargs: dict[str, str | bytes | bool | None] | None = None,
server_kwargs: dict[str, str | bytes | bool | None] | None = None,
) -> tuple[
SSLStream[MemoryStapledStream],
SSLStream[MemoryStapledStream],
]:
client_transport, server_transport = memory_stream_pair()
return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs)
return ssl_wrap_pair(
client_ctx,
client_transport,
server_transport,
client_kwargs=client_kwargs,
server_kwargs=server_kwargs,
)


MyStapledStream: TypeAlias = StapledStream[SendStream, ReceiveStream]


# Explicit "Any" is not allowed
def ssl_lockstep_stream_pair( # type: ignore[misc]
def ssl_lockstep_stream_pair(
client_ctx: SSLContext,
**kwargs: Any,
client_kwargs: dict[str, str | bytes | bool | None] | None = None,
server_kwargs: dict[str, str | bytes | bool | None] | None = None,
) -> tuple[
SSLStream[MyStapledStream],
SSLStream[MyStapledStream],
]:
client_transport, server_transport = lockstep_stream_pair()
return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs)
return ssl_wrap_pair(
client_ctx,
client_transport,
server_transport,
client_kwargs=client_kwargs,
server_kwargs=server_kwargs,
)


# Simple smoke test for handshake/send/receive/shutdown talking to a
Expand Down Expand Up @@ -1327,15 +1343,18 @@ async def test_getpeercert(client_ctx: SSLContext) -> None:


async def test_SSLListener(client_ctx: SSLContext) -> None:
# Explicit "Any" is not allowed
async def setup( # type: ignore[misc]
**kwargs: Any,
async def setup(
https_compatible: bool = False,
) -> tuple[tsocket.SocketType, SSLListener[SocketStream], SSLStream[SocketStream]]:
listen_sock = tsocket.socket()
await listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(1)
socket_listener = SocketListener(listen_sock)
ssl_listener = SSLListener(socket_listener, SERVER_CTX, **kwargs)
ssl_listener = SSLListener(
socket_listener,
SERVER_CTX,
https_compatible=https_compatible,
)

transport_client = await open_tcp_stream(*listen_sock.getsockname())
ssl_client = SSLStream(
Expand Down
Loading

0 comments on commit 0ee5a69

Please sign in to comment.