diff --git a/newsfragments/3121.misc.rst b/newsfragments/3121.misc.rst new file mode 100644 index 000000000..731232877 --- /dev/null +++ b/newsfragments/3121.misc.rst @@ -0,0 +1 @@ +Improve type annotations in several places by removing `Any` usage. diff --git a/pyproject.toml b/pyproject.toml index cd059d312..cde205b81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -189,6 +189,7 @@ warn_return_any = true # Avoid subtle backsliding disallow_any_decorated = true +disallow_any_explicit = true disallow_any_generics = true disallow_any_unimported = true disallow_incomplete_defs = true diff --git a/src/trio/_core/_concat_tb.py b/src/trio/_core/_concat_tb.py index 5d84118cd..82e525137 100644 --- a/src/trio/_core/_concat_tb.py +++ b/src/trio/_core/_concat_tb.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import TracebackType -from typing import Any, ClassVar, cast +from typing import TYPE_CHECKING, ClassVar, cast ################################################################ # concat_tb @@ -88,7 +88,7 @@ def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackT # cpython/pypy in current type checkers. def controller( # type: ignore[no-any-unimported] operation: tputil.ProxyOperation, - ) -> Any | None: + ) -> TracebackType | None: # Rationale for pragma: I looked fairly carefully and tried a few # things, and AFAICT it's not actually possible to get any # 'opname' that isn't __getattr__ or __getattribute__. So there's @@ -101,9 +101,10 @@ def controller( # type: ignore[no-any-unimported] "__getattr__", } and operation.args[0] == "tb_next" - ): # pragma: no cover + ) or TYPE_CHECKING: # pragma: no cover return tb_next - return operation.delegate() # Delegate is reverting to original behaviour + # Delegate is reverting to original behaviour + return operation.delegate() # type: ignore[no-any-return] return cast( TracebackType, diff --git a/src/trio/_core/_entry_queue.py b/src/trio/_core/_entry_queue.py index 332441a3a..0691de351 100644 --- a/src/trio/_core/_entry_queue.py +++ b/src/trio/_core/_entry_queue.py @@ -16,7 +16,8 @@ PosArgsT = TypeVarTuple("PosArgsT") -Function = Callable[..., object] +# Explicit "Any" is not allowed +Function = Callable[..., object] # type: ignore[misc] Job = tuple[Function, tuple[object, ...]] diff --git a/src/trio/_core/_instrumentation.py b/src/trio/_core/_instrumentation.py index bbab2acd7..40bddd1a2 100644 --- a/src/trio/_core/_instrumentation.py +++ b/src/trio/_core/_instrumentation.py @@ -3,7 +3,7 @@ import logging import types from collections.abc import Callable, Sequence -from typing import Any, TypeVar +from typing import TypeVar from .._abc import Instrument @@ -11,12 +11,14 @@ INSTRUMENT_LOGGER = logging.getLogger("trio.abc.Instrument") -F = TypeVar("F", bound=Callable[..., Any]) +# Explicit "Any" is not allowed +F = TypeVar("F", bound=Callable[..., object]) # type: ignore[misc] # Decorator to mark methods public. This does nothing by itself, but # trio/_tools/gen_exports.py looks for it. -def _public(fn: F) -> F: +# Explicit "Any" is not allowed +def _public(fn: F) -> F: # type: ignore[misc] return fn @@ -92,7 +94,7 @@ def remove_instrument(self, instrument: Instrument) -> None: def call( self, hookname: str, - *args: Any, + *args: object, ) -> None: """Call hookname(*args) on each applicable instrument. diff --git a/src/trio/_core/_io_windows.py b/src/trio/_core/_io_windows.py index 80b62d477..c4325a286 100644 --- a/src/trio/_core/_io_windows.py +++ b/src/trio/_core/_io_windows.py @@ -7,8 +7,8 @@ from contextlib import contextmanager from typing import ( TYPE_CHECKING, - Any, Literal, + Protocol, TypeVar, cast, ) @@ -24,6 +24,7 @@ AFDPollFlags, CData, CompletionModes, + CType, ErrorCodes, FileFlags, Handle, @@ -249,13 +250,28 @@ class AFDWaiters: current_op: AFDPollOp | None = None +# Just used for internal type checking. +class _AFDHandle(Protocol): + Handle: Handle + Status: int + Events: int + + +# Just used for internal type checking. +class _AFDPollInfo(Protocol): + Timeout: int + NumberOfHandles: int + Exclusive: int + Handles: list[_AFDHandle] + + # We also need to bundle up all the info for a single op into a standalone # object, because we need to keep all these objects alive until the operation # finishes, even if we're throwing it away. @attrs.frozen(eq=False) class AFDPollOp: lpOverlapped: CData - poll_info: Any + poll_info: _AFDPollInfo waiters: AFDWaiters afd_group: AFDGroup @@ -684,7 +700,7 @@ def _refresh_afd(self, base_handle: Handle) -> None: lpOverlapped = ffi.new("LPOVERLAPPED") - poll_info: Any = ffi.new("AFD_POLL_INFO *") + poll_info = cast(_AFDPollInfo, ffi.new("AFD_POLL_INFO *")) poll_info.Timeout = 2**63 - 1 # INT64_MAX poll_info.NumberOfHandles = 1 poll_info.Exclusive = 0 @@ -697,9 +713,9 @@ def _refresh_afd(self, base_handle: Handle) -> None: kernel32.DeviceIoControl( afd_group.handle, IoControlCodes.IOCTL_AFD_POLL, - poll_info, + cast(CType, poll_info), ffi.sizeof("AFD_POLL_INFO"), - poll_info, + cast(CType, poll_info), ffi.sizeof("AFD_POLL_INFO"), ffi.NULL, lpOverlapped, diff --git a/src/trio/_core/_ki.py b/src/trio/_core/_ki.py index 672501f75..46a7fdf70 100644 --- a/src/trio/_core/_ki.py +++ b/src/trio/_core/_ki.py @@ -4,7 +4,7 @@ import sys import types import weakref -from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar +from typing import TYPE_CHECKING, Generic, Protocol, TypeVar import attrs @@ -85,7 +85,12 @@ class _IdRef(weakref.ref[_T]): __slots__ = ("_hash",) _hash: int - def __new__(cls, ob: _T, callback: Callable[[Self], Any] | None = None, /) -> Self: + def __new__( + cls, + ob: _T, + callback: Callable[[Self], object] | None = None, + /, + ) -> Self: self: Self = weakref.ref.__new__(cls, ob, callback) self._hash = object.__hash__(ob) return self diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index f3ee0eb7e..21bdecd44 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -82,14 +82,13 @@ StatusT = TypeVar("StatusT") StatusT_contra = TypeVar("StatusT_contra", contravariant=True) -FnT = TypeVar("FnT", bound="Callable[..., Any]") RetT = TypeVar("RetT") DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: Final = 1000 # Passed as a sentinel -_NO_SEND: Final[Outcome[Any]] = cast("Outcome[Any]", object()) +_NO_SEND: Final[Outcome[object]] = cast("Outcome[object]", object()) # Used to track if an exceptiongroup can be collapsed NONSTRICT_EXCEPTIONGROUP_NOTE = 'This is a "loose" ExceptionGroup, and may be collapsed by Trio if it only contains one exception - typically after `Cancelled` has been stripped from it. Note this has consequences for exception handling, and strict_exception_groups=True is recommended.' @@ -102,7 +101,7 @@ class _NoStatus(metaclass=NoPublicConstructor): # Decorator to mark methods public. This does nothing by itself, but # trio/_tools/gen_exports.py looks for it. -def _public(fn: FnT) -> FnT: +def _public(fn: RetT) -> RetT: return fn @@ -1172,7 +1171,11 @@ def _check_nursery_closed(self) -> None: self._parent_waiting_in_aexit = False GLOBAL_RUN_CONTEXT.runner.reschedule(self._parent_task) - def _child_finished(self, task: Task, outcome: Outcome[Any]) -> None: + def _child_finished( + self, + task: Task, + outcome: Outcome[object], + ) -> None: self._children.remove(task) if isinstance(outcome, Error): self._add_exc(outcome.error) @@ -1278,12 +1281,14 @@ def start_soon( """ GLOBAL_RUN_CONTEXT.runner.spawn_impl(async_fn, args, self, name) - async def start( + # Typing changes blocked by https://github.com/python/mypy/pull/17512 + # Explicit "Any" is not allowed + async def start( # type: ignore[misc] self, async_fn: Callable[..., Awaitable[object]], *args: object, name: object = None, - ) -> Any: + ) -> Any | None: r"""Creates and initializes a child task. Like :meth:`start_soon`, but blocks until the new task has @@ -1334,7 +1339,10 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): # set strict_exception_groups = True to make sure we always unwrap # *this* nursery's exceptiongroup async with open_nursery(strict_exception_groups=True) as old_nursery: - task_status: _TaskStatus[Any] = _TaskStatus(old_nursery, self) + task_status: _TaskStatus[object | None] = _TaskStatus( + old_nursery, + self, + ) thunk = functools.partial(async_fn, task_status=task_status) task = GLOBAL_RUN_CONTEXT.runner.spawn_impl( thunk, @@ -1375,9 +1383,10 @@ def __del__(self) -> None: @final @attrs.define(eq=False, repr=False) -class Task(metaclass=NoPublicConstructor): +class Task(metaclass=NoPublicConstructor): # type: ignore[misc] _parent_nursery: Nursery | None - coro: Coroutine[Any, Outcome[object], Any] + # Explicit "Any" is not allowed + coro: Coroutine[Any, Outcome[object], Any] # type: ignore[misc] _runner: Runner name: str context: contextvars.Context @@ -1395,10 +1404,11 @@ class Task(metaclass=NoPublicConstructor): # tracebacks with extraneous frames. # - for scheduled tasks, custom_sleep_data is None # Tasks start out unscheduled. - _next_send_fn: Callable[[Any], object] | None = None - _next_send: Outcome[Any] | None | BaseException = None + # Explicit "Any" is not allowed + _next_send_fn: Callable[[Any], object] | None = None # type: ignore[misc] + _next_send: Outcome[Any] | None | BaseException = None # type: ignore[misc] _abort_func: Callable[[_core.RaiseCancelT], Abort] | None = None - custom_sleep_data: Any = None + custom_sleep_data: Any = None # type: ignore[misc] # For introspection and nursery.start() _child_nurseries: list[Nursery] = attrs.Factory(list) @@ -1466,7 +1476,7 @@ def print_stack_for_task(task): """ # Ignore static typing as we're doing lots of dynamic introspection - coro: Any = self.coro + coro: Any = self.coro # type: ignore[misc] while coro is not None: if hasattr(coro, "cr_frame"): # A real coroutine @@ -1611,13 +1621,16 @@ class RunStatistics: @attrs.define(eq=False) -class GuestState: +# Explicit "Any" is not allowed +class GuestState: # type: ignore[misc] runner: Runner run_sync_soon_threadsafe: Callable[[Callable[[], object]], object] run_sync_soon_not_threadsafe: Callable[[Callable[[], object]], object] - done_callback: Callable[[Outcome[Any]], object] + # Explicit "Any" is not allowed + done_callback: Callable[[Outcome[Any]], object] # type: ignore[misc] unrolled_run_gen: Generator[float, EventResult, None] - unrolled_run_next_send: Outcome[Any] = attrs.Factory(lambda: Value(None)) + # Explicit "Any" is not allowed + unrolled_run_next_send: Outcome[Any] = attrs.Factory(lambda: Value(None)) # type: ignore[misc] def guest_tick(self) -> None: prev_library, sniffio_library.name = sniffio_library.name, "trio" @@ -1662,7 +1675,8 @@ def in_main_thread() -> None: @attrs.define(eq=False) -class Runner: +# Explicit "Any" is not allowed +class Runner: # type: ignore[misc] clock: Clock instruments: Instruments io_manager: TheIOManager @@ -1670,7 +1684,8 @@ class Runner: strict_exception_groups: bool # Run-local values, see _local.py - _locals: dict[_core.RunVar[Any], Any] = attrs.Factory(dict) + # Explicit "Any" is not allowed + _locals: dict[_core.RunVar[Any], object] = attrs.Factory(dict) # type: ignore[misc] runq: deque[Task] = attrs.Factory(deque) tasks: set[Task] = attrs.Factory(set) @@ -1681,7 +1696,7 @@ class Runner: system_nursery: Nursery | None = None system_context: contextvars.Context = attrs.field(kw_only=True) main_task: Task | None = None - main_task_outcome: Outcome[Any] | None = None + main_task_outcome: Outcome[object] | None = None entry_queue: EntryQueue = attrs.Factory(EntryQueue) trio_token: TrioToken | None = None @@ -1774,11 +1789,7 @@ def current_root_task(self) -> Task | None: ################ @_public - def reschedule( - self, - task: Task, - next_send: Outcome[object] = _NO_SEND, - ) -> None: + def reschedule(self, task: Task, next_send: Outcome[object] = _NO_SEND) -> None: """Reschedule the given task with the given :class:`outcome.Outcome`. @@ -1889,7 +1900,7 @@ async def python_wrapper(orig_coro: Awaitable[RetT]) -> RetT: self.reschedule(task, None) # type: ignore[arg-type] return task - def task_exited(self, task: Task, outcome: Outcome[Any]) -> None: + def task_exited(self, task: Task, outcome: Outcome[object]) -> None: # break parking lots associated with the exiting task if task in GLOBAL_PARKING_LOT_BREAKER: for lot in GLOBAL_PARKING_LOT_BREAKER[task]: @@ -2101,7 +2112,8 @@ def _deliver_ki_cb(self) -> None: # sortedcontainers doesn't have types, and is reportedly very hard to type: # https://github.com/grantjenks/python-sortedcontainers/issues/68 - waiting_for_idle: Any = attrs.Factory(SortedDict) + # Explicit "Any" is not allowed + waiting_for_idle: Any = attrs.Factory(SortedDict) # type: ignore[misc] @_public async def wait_all_tasks_blocked(self, cushion: float = 0.0) -> None: @@ -2402,7 +2414,8 @@ def run( raise AssertionError(runner.main_task_outcome) -def start_guest_run( +# Explicit .../"Any" not allowed +def start_guest_run( # type: ignore[misc] async_fn: Callable[..., Awaitable[RetT]], *args: object, run_sync_soon_threadsafe: Callable[[Callable[[], object]], object], @@ -2706,7 +2719,7 @@ def unrolled_run( next_send_fn = task._next_send_fn next_send = task._next_send task._next_send_fn = task._next_send = None - final_outcome: Outcome[Any] | None = None + final_outcome: Outcome[object] | None = None try: # We used to unwrap the Outcome object here and send/throw # its contents in directly, but it turns out that .throw() @@ -2815,15 +2828,15 @@ def unrolled_run( ################################################################ -class _TaskStatusIgnored(TaskStatus[Any]): +class _TaskStatusIgnored(TaskStatus[object]): def __repr__(self) -> str: return "TASK_STATUS_IGNORED" - def started(self, value: Any = None) -> None: + def started(self, value: object = None) -> None: pass -TASK_STATUS_IGNORED: Final[TaskStatus[Any]] = _TaskStatusIgnored() +TASK_STATUS_IGNORED: Final[TaskStatus[object]] = _TaskStatusIgnored() def current_task() -> Task: diff --git a/src/trio/_core/_tests/test_guest_mode.py b/src/trio/_core/_tests/test_guest_mode.py index 71fcf3da3..678874fec 100644 --- a/src/trio/_core/_tests/test_guest_mode.py +++ b/src/trio/_core/_tests/test_guest_mode.py @@ -11,14 +11,14 @@ import time import traceback import warnings -from collections.abc import AsyncGenerator, Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence from functools import partial from math import inf from typing import ( TYPE_CHECKING, - Any, NoReturn, TypeVar, + cast, ) import pytest @@ -26,7 +26,7 @@ import trio import trio.testing -from trio.abc import Instrument +from trio.abc import Clock, Instrument from ..._util import signal_raise from .tutil import gc_collect_harder, restore_unraisablehook @@ -37,7 +37,7 @@ from trio._channel import MemorySendChannel T = TypeVar("T") -InHost: TypeAlias = Callable[[object], None] +InHost: TypeAlias = Callable[[Callable[[], object]], None] # The simplest possible "host" loop. @@ -47,12 +47,16 @@ # - final result is returned # - any unhandled exceptions cause an immediate crash def trivial_guest_run( - trio_fn: Callable[..., Awaitable[T]], + trio_fn: Callable[[InHost], Awaitable[T]], *, in_host_after_start: Callable[[], None] | None = None, - **start_guest_run_kwargs: Any, + host_uses_signal_set_wakeup_fd: bool = False, + clock: Clock | None = None, + instruments: Sequence[Instrument] = (), + restrict_keyboard_interrupt_to_checkpoints: bool = False, + strict_exception_groups: bool = True, ) -> T: - todo: queue.Queue[tuple[str, Outcome[T] | Callable[..., object]]] = queue.Queue() + todo: queue.Queue[tuple[str, Outcome[T] | Callable[[], object]]] = queue.Queue() host_thread = threading.current_thread() @@ -86,7 +90,11 @@ def done_callback(outcome: Outcome[T]) -> None: run_sync_soon_threadsafe=run_sync_soon_threadsafe, run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe, done_callback=done_callback, - **start_guest_run_kwargs, + host_uses_signal_set_wakeup_fd=host_uses_signal_set_wakeup_fd, + clock=clock, + instruments=instruments, + restrict_keyboard_interrupt_to_checkpoints=restrict_keyboard_interrupt_to_checkpoints, + strict_exception_groups=strict_exception_groups, ) if in_host_after_start is not None: in_host_after_start() @@ -170,10 +178,16 @@ async def early_task() -> None: assert res == "ok" assert set(record) == {"system task ran", "main task ran", "run_sync_soon cb ran"} - class BadClock: + class BadClock(Clock): def start_clock(self) -> NoReturn: raise ValueError("whoops") + def current_time(self) -> float: + raise NotImplementedError() + + def deadline_to_sleep_time(self, deadline: float) -> float: + raise NotImplementedError() + def after_start_never_runs() -> None: # pragma: no cover pytest.fail("shouldn't get here") @@ -431,14 +445,20 @@ async def abandoned_main(in_host: InHost) -> None: def aiotrio_run( - trio_fn: Callable[..., Awaitable[T]], + trio_fn: Callable[[], Awaitable[T]], *, pass_not_threadsafe: bool = True, - **start_guest_run_kwargs: Any, + run_sync_soon_not_threadsafe: InHost | None = None, + host_uses_signal_set_wakeup_fd: bool = False, + clock: Clock | None = None, + instruments: Sequence[Instrument] = (), + restrict_keyboard_interrupt_to_checkpoints: bool = False, + strict_exception_groups: bool = True, ) -> T: loop = asyncio.new_event_loop() async def aio_main() -> T: + nonlocal run_sync_soon_not_threadsafe trio_done_fut = loop.create_future() def trio_done_callback(main_outcome: Outcome[object]) -> None: @@ -446,13 +466,18 @@ def trio_done_callback(main_outcome: Outcome[object]) -> None: trio_done_fut.set_result(main_outcome) if pass_not_threadsafe: - start_guest_run_kwargs["run_sync_soon_not_threadsafe"] = loop.call_soon + run_sync_soon_not_threadsafe = cast(InHost, loop.call_soon) trio.lowlevel.start_guest_run( trio_fn, run_sync_soon_threadsafe=loop.call_soon_threadsafe, done_callback=trio_done_callback, - **start_guest_run_kwargs, + run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe, + host_uses_signal_set_wakeup_fd=host_uses_signal_set_wakeup_fd, + clock=clock, + instruments=instruments, + restrict_keyboard_interrupt_to_checkpoints=restrict_keyboard_interrupt_to_checkpoints, + strict_exception_groups=strict_exception_groups, ) return (await trio_done_fut).unwrap() # type: ignore[no-any-return] diff --git a/src/trio/_core/_tests/test_ki.py b/src/trio/_core/_tests/test_ki.py index d403cfa7a..823c7aab0 100644 --- a/src/trio/_core/_tests/test_ki.py +++ b/src/trio/_core/_tests/test_ki.py @@ -678,7 +678,10 @@ async def _consume_async_generator(agen: AsyncGenerator[None, None]) -> None: await agen.aclose() -def _consume_function_for_coverage(fn: Callable[..., object]) -> None: +# Explicit .../"Any" is not allowed +def _consume_function_for_coverage( # type: ignore[misc] + fn: Callable[..., object], +) -> None: result = fn() if inspect.isasyncgen(result): result = _consume_async_generator(result) diff --git a/src/trio/_core/_tests/test_parking_lot.py b/src/trio/_core/_tests/test_parking_lot.py index d9afee83d..809fb2824 100644 --- a/src/trio/_core/_tests/test_parking_lot.py +++ b/src/trio/_core/_tests/test_parking_lot.py @@ -304,9 +304,10 @@ async def test_parking_lot_breaker_registration() -> None: # registering a task as breaker on an already broken lot is fine lot.break_lot() - child_task = None + child_task: _core.Task | None = None async with trio.open_nursery() as nursery: child_task = await nursery.start(dummy_task) + assert isinstance(child_task, _core.Task) add_parking_lot_breaker(child_task, lot) nursery.cancel_scope.cancel() assert lot.broken_by == [task, child_task] @@ -339,6 +340,9 @@ async def test_parking_lot_multiple_breakers_exit() -> None: child_task1 = await nursery.start(dummy_task) child_task2 = await nursery.start(dummy_task) child_task3 = await nursery.start(dummy_task) + assert isinstance(child_task1, _core.Task) + assert isinstance(child_task2, _core.Task) + assert isinstance(child_task3, _core.Task) add_parking_lot_breaker(child_task1, lot) add_parking_lot_breaker(child_task2, lot) add_parking_lot_breaker(child_task3, lot) @@ -350,9 +354,11 @@ async def test_parking_lot_multiple_breakers_exit() -> None: async def test_parking_lot_breaker_register_exited_task() -> None: lot = ParkingLot() - child_task = None + child_task: _core.Task | None = None async with trio.open_nursery() as nursery: - child_task = await nursery.start(dummy_task) + value = await nursery.start(dummy_task) + assert isinstance(value, _core.Task) + child_task = value nursery.cancel_scope.cancel() # trying to register an exited task as lot breaker errors with pytest.raises( diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index f04c95161..7b3958804 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -10,7 +10,7 @@ import weakref from contextlib import ExitStack, contextmanager, suppress from math import inf, nan -from typing import TYPE_CHECKING, Any, NoReturn, TypeVar, cast +from typing import TYPE_CHECKING, NoReturn, TypeVar import outcome import pytest @@ -823,7 +823,9 @@ async def task3(task_status: _core.TaskStatus[_core.CancelScope]) -> None: await sleep_forever() async with _core.open_nursery() as nursery: - scope: _core.CancelScope = await nursery.start(task3) + value = await nursery.start(task3) + assert isinstance(value, _core.CancelScope) + scope: _core.CancelScope = value with pytest.raises(RuntimeError, match="from unrelated"): scope.__exit__(None, None, None) scope.cancel() @@ -1646,7 +1648,10 @@ async def func1(expected: str) -> None: async def func2() -> None: # pragma: no cover pass - async def check(spawn_fn: Callable[..., object]) -> None: + # Explicit .../"Any" is not allowed + async def check( # type: ignore[misc] + spawn_fn: Callable[..., object], + ) -> None: spawn_fn(func1, "func1") spawn_fn(func1, "func2", name=func2) spawn_fn(func1, "func3", name="func3") @@ -1681,13 +1686,14 @@ async def test_current_effective_deadline(mock_clock: _core.MockClock) -> None: def test_nice_error_on_bad_calls_to_run_or_spawn() -> None: - def bad_call_run( + # Explicit .../"Any" is not allowed + def bad_call_run( # type: ignore[misc] func: Callable[..., Awaitable[object]], *args: tuple[object, ...], ) -> None: _core.run(func, *args) - def bad_call_spawn( + def bad_call_spawn( # type: ignore[misc] func: Callable[..., Awaitable[object]], *args: tuple[object, ...], ) -> None: @@ -1959,7 +1965,9 @@ async def sleeping_children( # Cancelling the setup_nursery just *before* calling started() async with _core.open_nursery() as nursery: - target_nursery: _core.Nursery = await nursery.start(setup_nursery) + value = await nursery.start(setup_nursery) + assert isinstance(value, _core.Nursery) + target_nursery: _core.Nursery = value await target_nursery.start( sleeping_children, target_nursery.cancel_scope.cancel, @@ -1967,7 +1975,9 @@ async def sleeping_children( # Cancelling the setup_nursery just *after* calling started() async with _core.open_nursery() as nursery: - target_nursery = await nursery.start(setup_nursery) + value = await nursery.start(setup_nursery) + assert isinstance(value, _core.Nursery) + target_nursery = value await target_nursery.start(sleeping_children, lambda: None) target_nursery.cancel_scope.cancel() @@ -2285,7 +2295,8 @@ async def detachable_coroutine( await sleep(0) nonlocal task, pdco_outcome task = _core.current_task() - pdco_outcome = await outcome.acapture( + # `No overload variant of "acapture" matches argument types "Callable[[Outcome[object]], Coroutine[Any, Any, object]]", "Outcome[None]"` + pdco_outcome = await outcome.acapture( # type: ignore[call-overload] _core.permanently_detach_coroutine_object, task_outcome, ) @@ -2298,10 +2309,11 @@ async def detachable_coroutine( # is still iterable. At that point anything can be sent into the coroutine, so the .coro type # is wrong. assert pdco_outcome is None - assert not_none(task).coro.send(cast(Any, "be free!")) == "I'm free!" + # `Argument 1 to "send" of "Coroutine" has incompatible type "str"; expected "Outcome[object]"` + assert not_none(task).coro.send("be free!") == "I'm free!" # type: ignore[arg-type] assert pdco_outcome == outcome.Value("be free!") with pytest.raises(StopIteration): - not_none(task).coro.send(cast(Any, None)) + not_none(task).coro.send(None) # type: ignore[arg-type] # Check the exception paths too task = None @@ -2314,7 +2326,7 @@ async def detachable_coroutine( assert not_none(task).coro.throw(throw_in) == "uh oh" assert pdco_outcome == outcome.Error(throw_in) with pytest.raises(StopIteration): - task.coro.send(cast(Any, None)) + task.coro.send(None) async def bad_detach() -> None: async with _core.open_nursery(): @@ -2366,9 +2378,10 @@ def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: # pragma: no cover await wait_all_tasks_blocked() # Okay, it's detached. Here's our coroutine runner: - assert not_none(task).coro.send(cast(Any, "not trio!")) == 1 - assert not_none(task).coro.send(cast(Any, None)) == 2 - assert not_none(task).coro.send(cast(Any, None)) == "byebye" + # `Argument 1 to "send" of "Coroutine" has incompatible type "str"; expected "Outcome[object]"` + assert not_none(task).coro.send("not trio!") == 1 # type: ignore[arg-type] + assert not_none(task).coro.send(None) == 2 # type: ignore[arg-type] + assert not_none(task).coro.send(None) == "byebye" # type: ignore[arg-type] # Now it's been reattached, and we can leave the nursery @@ -2398,7 +2411,8 @@ def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: await wait_all_tasks_blocked() assert task is not None nursery.cancel_scope.cancel() - task.coro.send(cast(Any, None)) + # `Argument 1 to "send" of "Coroutine" has incompatible type "None"; expected "Outcome[object]"` + task.coro.send(None) # type: ignore[arg-type] assert abort_fn_called diff --git a/src/trio/_core/_thread_cache.py b/src/trio/_core/_thread_cache.py index c61222269..189d5a583 100644 --- a/src/trio/_core/_thread_cache.py +++ b/src/trio/_core/_thread_cache.py @@ -7,10 +7,13 @@ from functools import partial from itertools import count from threading import Lock, Thread -from typing import Any, Callable, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar import outcome +if TYPE_CHECKING: + from collections.abc import Callable + RetT = TypeVar("RetT") @@ -126,6 +129,8 @@ def darwin_namefunc( class WorkerThread(Generic[RetT]): + __slots__ = ("_default_name", "_job", "_thread", "_thread_cache", "_worker_lock") + def __init__(self, thread_cache: ThreadCache) -> None: self._job: ( tuple[ @@ -207,8 +212,11 @@ def _work(self) -> None: class ThreadCache: + __slots__ = ("_idle_workers",) + def __init__(self) -> None: - self._idle_workers: dict[WorkerThread[Any], None] = {} + # Explicit "Any" not allowed + self._idle_workers: dict[WorkerThread[Any], None] = {} # type: ignore[misc] def start_thread_soon( self, diff --git a/src/trio/_core/_traps.py b/src/trio/_core/_traps.py index fc31a182a..bef77b768 100644 --- a/src/trio/_core/_traps.py +++ b/src/trio/_core/_traps.py @@ -4,7 +4,10 @@ import enum import types -from typing import TYPE_CHECKING, Any, Callable, NoReturn +from collections.abc import Awaitable + +# Jedi gets mad in test_static_tool_sees_class_members if we use collections Callable +from typing import TYPE_CHECKING, Any, Callable, NoReturn, Union, cast import attrs import outcome @@ -12,10 +15,40 @@ from . import _run if TYPE_CHECKING: + from collections.abc import Generator + from typing_extensions import TypeAlias from ._run import Task +RaiseCancelT: TypeAlias = Callable[[], NoReturn] + + +# This class object is used as a singleton. +# Not exported in the trio._core namespace, but imported directly by _run. +class CancelShieldedCheckpoint: + __slots__ = () + + +# Not exported in the trio._core namespace, but imported directly by _run. +@attrs.frozen(slots=False) +class WaitTaskRescheduled: + abort_func: Callable[[RaiseCancelT], Abort] + + +# Not exported in the trio._core namespace, but imported directly by _run. +@attrs.frozen(slots=False) +class PermanentlyDetachCoroutineObject: + final_outcome: outcome.Outcome[object] + + +MessageType: TypeAlias = Union[ + type[CancelShieldedCheckpoint], + WaitTaskRescheduled, + PermanentlyDetachCoroutineObject, + object, +] + # Helper for the bottommost 'yield'. You can't use 'yield' inside an async # function, but you can inside a generator, and if you decorate your generator @@ -25,14 +58,18 @@ # tracking machinery. Since our traps are public APIs, we make them real async # functions, and then this helper takes care of the actual yield: @types.coroutine -def _async_yield(obj: Any) -> Any: # type: ignore[misc] +def _real_async_yield( + obj: MessageType, +) -> Generator[MessageType, None, None]: return (yield obj) -# This class object is used as a singleton. -# Not exported in the trio._core namespace, but imported directly by _run. -class CancelShieldedCheckpoint: - pass +# Real yield value is from trio's main loop, but type checkers can't +# understand that, so we cast it to make type checkers understand. +_async_yield = cast( + Callable[[MessageType], Awaitable[outcome.Outcome[object]]], + _real_async_yield, +) async def cancel_shielded_checkpoint() -> None: @@ -66,18 +103,10 @@ class Abort(enum.Enum): FAILED = 2 -# Not exported in the trio._core namespace, but imported directly by _run. -@attrs.frozen(slots=False) -class WaitTaskRescheduled: - abort_func: Callable[[RaiseCancelT], Abort] - - -RaiseCancelT: TypeAlias = Callable[[], NoReturn] - - # Should always return the type a Task "expects", unless you willfully reschedule it # with a bad value. -async def wait_task_rescheduled( +# Explicit "Any" is not allowed +async def wait_task_rescheduled( # type: ignore[misc] abort_func: Callable[[RaiseCancelT], Abort], ) -> Any: """Put the current task to sleep, with cancellation support. @@ -181,15 +210,9 @@ def abort(inner_raise_cancel): return (await _async_yield(WaitTaskRescheduled(abort_func))).unwrap() -# Not exported in the trio._core namespace, but imported directly by _run. -@attrs.frozen(slots=False) -class PermanentlyDetachCoroutineObject: - final_outcome: outcome.Outcome[Any] - - async def permanently_detach_coroutine_object( - final_outcome: outcome.Outcome[Any], -) -> Any: + final_outcome: outcome.Outcome[object], +) -> object: """Permanently detach the current task from the Trio scheduler. Normally, a Trio task doesn't exit until its coroutine object exits. When @@ -222,7 +245,7 @@ async def permanently_detach_coroutine_object( async def temporarily_detach_coroutine_object( abort_func: Callable[[RaiseCancelT], Abort], -) -> Any: +) -> object: """Temporarily detach the current coroutine object from the Trio scheduler. diff --git a/src/trio/_dtls.py b/src/trio/_dtls.py index 70115d18d..c971471c0 100644 --- a/src/trio/_dtls.py +++ b/src/trio/_dtls.py @@ -19,7 +19,6 @@ from itertools import count from typing import ( TYPE_CHECKING, - Any, Generic, TypeVar, Union, @@ -1220,7 +1219,9 @@ def __init__( # as a peer provides a valid cookie, we can immediately tear down the # old connection. # {remote address: DTLSChannel} - self._streams: WeakValueDictionary[Any, DTLSChannel] = WeakValueDictionary() + self._streams: WeakValueDictionary[AddressFormat, DTLSChannel] = ( + WeakValueDictionary() + ) self._listening_context: SSL.Context | None = None self._listening_key: bytes | None = None self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) diff --git a/src/trio/_file_io.py b/src/trio/_file_io.py index 1f8202e7a..5a612ffb0 100644 --- a/src/trio/_file_io.py +++ b/src/trio/_file_io.py @@ -449,7 +449,7 @@ async def open_file( newline: str | None = None, closefd: bool = True, opener: _Opener | None = None, -) -> AsyncIOWrapper[Any]: +) -> AsyncIOWrapper[object]: """Asynchronous version of :func:`open`. Returns: diff --git a/src/trio/_highlevel_open_tcp_stream.py b/src/trio/_highlevel_open_tcp_stream.py index b7c0468e0..d4ec98355 100644 --- a/src/trio/_highlevel_open_tcp_stream.py +++ b/src/trio/_highlevel_open_tcp_stream.py @@ -134,16 +134,9 @@ def close_all() -> Generator[set[SocketType], None, None]: raise BaseExceptionGroup("", errs) -def reorder_for_rfc_6555_section_5_4( - targets: list[ - tuple[ - AddressFamily, - SocketKind, - int, - str, - Any, - ] - ], +# Explicit "Any" is not allowed +def reorder_for_rfc_6555_section_5_4( # type: ignore[misc] + targets: list[tuple[AddressFamily, SocketKind, int, str, Any]], ) -> None: # RFC 6555 section 5.4 says that if getaddrinfo returns multiple address # families (e.g. IPv4 and IPv6), then you should make sure that your first diff --git a/src/trio/_highlevel_serve_listeners.py b/src/trio/_highlevel_serve_listeners.py index 0a85c8ecb..9b17f8d53 100644 --- a/src/trio/_highlevel_serve_listeners.py +++ b/src/trio/_highlevel_serve_listeners.py @@ -25,7 +25,8 @@ StreamT = TypeVar("StreamT", bound=trio.abc.AsyncResource) -ListenerT = TypeVar("ListenerT", bound=trio.abc.Listener[Any]) +# Explicit "Any" is not allowed +ListenerT = TypeVar("ListenerT", bound=trio.abc.Listener[Any]) # type: ignore[misc] Handler = Callable[[StreamT], Awaitable[object]] @@ -67,7 +68,8 @@ async def _serve_one_listener( # https://github.com/python/typing/issues/548 -async def serve_listeners( +# Explicit "Any" is not allowed +async def serve_listeners( # type: ignore[misc] handler: Handler[StreamT], listeners: list[ListenerT], *, diff --git a/src/trio/_path.py b/src/trio/_path.py index 2c9dfff29..a58136b75 100644 --- a/src/trio/_path.py +++ b/src/trio/_path.py @@ -30,8 +30,9 @@ T = TypeVar("T") -def _wraps_async( - wrapped: Callable[..., Any], +# Explicit .../"Any" is not allowed +def _wraps_async( # type: ignore[misc] + wrapped: Callable[..., object], ) -> Callable[[Callable[P, T]], Callable[P, Awaitable[T]]]: def decorator(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]: async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: diff --git a/src/trio/_socket.py b/src/trio/_socket.py index fe7e1502b..ce57aab9c 100644 --- a/src/trio/_socket.py +++ b/src/trio/_socket.py @@ -50,7 +50,8 @@ # most users, so currently we just specify it as `Any`. Otherwise we would write: # `AddressFormat = TypeVar("AddressFormat")` # but instead we simply do: -AddressFormat: TypeAlias = Any +# Explicit "Any" is not allowed +AddressFormat: TypeAlias = Any # type: ignore[misc] # Usage: @@ -714,7 +715,7 @@ def recvmsg( __bufsize: int, __ancbufsize: int = 0, __flags: int = 0, - ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, object]]: raise NotImplementedError if sys.platform != "win32" or ( @@ -726,7 +727,7 @@ def recvmsg_into( __buffers: Iterable[Buffer], __ancbufsize: int = 0, __flags: int = 0, - ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, object]]: raise NotImplementedError def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: @@ -747,7 +748,7 @@ async def sendto( __address: tuple[object, ...] | str | Buffer, ) -> int: ... - async def sendto(self, *args: Any) -> int: + async def sendto(self, *args: object) -> int: raise NotImplementedError if sys.platform != "win32" or ( @@ -1195,7 +1196,7 @@ def recvmsg( __bufsize: int, __ancbufsize: int = 0, __flags: int = 0, - ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: ... + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, object]]: ... recvmsg = _make_simple_sock_method_wrapper( _stdlib_socket.socket.recvmsg, @@ -1217,7 +1218,7 @@ def recvmsg_into( __buffers: Iterable[Buffer], __ancbufsize: int = 0, __flags: int = 0, - ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: ... + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, object]]: ... recvmsg_into = _make_simple_sock_method_wrapper( _stdlib_socket.socket.recvmsg_into, @@ -1257,8 +1258,8 @@ async def sendto( __address: tuple[object, ...] | str | Buffer, ) -> int: ... - @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc] - async def sendto(self, *args: Any) -> int: + @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) + async def sendto(self, *args: object) -> int: """Similar to :meth:`socket.socket.sendto`, but async.""" # args is: data[, flags], address # and kwargs are not accepted diff --git a/src/trio/_ssl.py b/src/trio/_ssl.py index df1cbc37b..0a0419fbc 100644 --- a/src/trio/_ssl.py +++ b/src/trio/_ssl.py @@ -16,6 +16,10 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable + from typing_extensions import TypeVarTuple, Unpack + + Ts = TypeVarTuple("Ts") + # General theory of operation: # # We implement an API that closely mirrors the stdlib ssl module's blocking @@ -219,7 +223,13 @@ class NeedHandshakeError(Exception): class _Once: - def __init__(self, afn: Callable[..., Awaitable[object]], *args: object) -> None: + __slots__ = ("_afn", "_args", "_done", "started") + + def __init__( + self, + afn: Callable[[*Ts], Awaitable[object]], + *args: Unpack[Ts], + ) -> None: self._afn = afn self._args = args self.started = False @@ -413,7 +423,11 @@ def __init__( "version", } - def __getattr__(self, name: str) -> Any: + # Explicit "Any" is not allowed + def __getattr__( # type: ignore[misc] + self, + name: str, + ) -> Any: if name in self._forwarded: if name in self._after_handshake and not self._handshook.done: raise NeedHandshakeError(f"call do_handshake() before calling {name!r}") @@ -447,8 +461,8 @@ def _check_status(self) -> None: # too. async def _retry( self, - fn: Callable[..., T], - *args: object, + fn: Callable[[*Ts], T], + *args: Unpack[Ts], ignore_want_read: bool = False, is_handshake: bool = False, ) -> T | None: diff --git a/src/trio/_tests/test_deprecate_strict_exception_groups_false.py b/src/trio/_tests/test_deprecate_strict_exception_groups_false.py index 7e575aa92..1b02c9ee7 100644 --- a/src/trio/_tests/test_deprecate_strict_exception_groups_false.py +++ b/src/trio/_tests/test_deprecate_strict_exception_groups_false.py @@ -32,7 +32,7 @@ async def foo_loose_nursery() -> None: async with trio.open_nursery(strict_exception_groups=False): ... - def helper(fun: Callable[..., Awaitable[None]], num: int) -> None: + def helper(fun: Callable[[], Awaitable[None]], num: int) -> None: with pytest.warns( trio.TrioDeprecationWarning, match="strict_exception_groups=False", diff --git a/src/trio/_tests/test_exports.py b/src/trio/_tests/test_exports.py index d82481402..de2449755 100644 --- a/src/trio/_tests/test_exports.py +++ b/src/trio/_tests/test_exports.py @@ -390,11 +390,13 @@ def lookup_symbol(symbol: str) -> dict[str, str]: assert "node" in cached_type_info node = cached_type_info["node"] - static_names = no_hidden(k for k in node["names"] if not k.startswith(".")) + static_names = no_hidden( + k for k in node.get("names", ()) if not k.startswith(".") + ) for symbol in node["mro"][1:]: node = lookup_symbol(symbol)["node"] static_names |= no_hidden( - k for k in node["names"] if not k.startswith(".") + k for k in node.get("names", ()) if not k.startswith(".") ) static_names -= ignore_names diff --git a/src/trio/_tests/test_highlevel_open_tcp_listeners.py b/src/trio/_tests/test_highlevel_open_tcp_listeners.py index 72fe0382f..61abd43f0 100644 --- a/src/trio/_tests/test_highlevel_open_tcp_listeners.py +++ b/src/trio/_tests/test_highlevel_open_tcp_listeners.py @@ -4,7 +4,7 @@ import socket as stdlib_socket import sys from socket import AddressFamily, SocketKind -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, cast, overload import attrs import pytest @@ -30,7 +30,7 @@ from typing_extensions import Buffer - from .._socket import AddressFormat + from trio._socket import AddressFormat async def test_open_tcp_listeners_basic() -> None: @@ -312,7 +312,9 @@ async def handler(stream: SendStream) -> None: async with trio.open_nursery() as nursery: # nursery.start is incorrectly typed, awaiting #2773 - listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0) + value = await nursery.start(serve_tcp, handler, 0) + assert isinstance(value, list) + listeners = cast(list[SocketListener], value) stream = await open_stream_to_socket_listener(listeners[0]) async with stream: assert await stream.receive_some(1) == b"x" diff --git a/src/trio/_tests/test_highlevel_open_tcp_stream.py b/src/trio/_tests/test_highlevel_open_tcp_stream.py index c84b30644..98adf7efe 100644 --- a/src/trio/_tests/test_highlevel_open_tcp_stream.py +++ b/src/trio/_tests/test_highlevel_open_tcp_stream.py @@ -373,9 +373,6 @@ async def run_scenario( trio.socket.set_custom_socket_factory(scenario) try: - # Type ignore is for the fact that there are multiple - # keyword arguments that accept separate types, but - # str | float | None is not the same as str | None and float | None stream = await open_tcp_stream( "test.example.com", port, diff --git a/src/trio/_tests/test_highlevel_serve_listeners.py b/src/trio/_tests/test_highlevel_serve_listeners.py index 0ce82e784..013d13078 100644 --- a/src/trio/_tests/test_highlevel_serve_listeners.py +++ b/src/trio/_tests/test_highlevel_serve_listeners.py @@ -2,7 +2,7 @@ import errno from functools import partial -from typing import TYPE_CHECKING, NoReturn +from typing import TYPE_CHECKING, NoReturn, cast import attrs @@ -96,11 +96,13 @@ async def do_tests(parent_nursery: Nursery) -> None: parent_nursery.cancel_scope.cancel() async with trio.open_nursery() as nursery: - l2: list[MemoryListener] = await nursery.start( + value = await nursery.start( trio.serve_listeners, handler, listeners, ) + assert isinstance(value, list) + l2 = cast(list[MemoryListener], value) assert l2 == listeners # This is just split into another function because gh-136 isn't # implemented yet @@ -172,7 +174,9 @@ async def connection_watcher( # the exception is wrapped twice because we open two nested nurseries with RaisesGroup(RaisesGroup(Done)): async with trio.open_nursery() as nursery: - handler_nursery: trio.Nursery = await nursery.start(connection_watcher) + value = await nursery.start(connection_watcher) + assert isinstance(value, trio.Nursery) + handler_nursery: trio.Nursery = value await nursery.start( partial( trio.serve_listeners, diff --git a/src/trio/_tests/test_highlevel_ssl_helpers.py b/src/trio/_tests/test_highlevel_ssl_helpers.py index 841232850..68aa47846 100644 --- a/src/trio/_tests/test_highlevel_ssl_helpers.py +++ b/src/trio/_tests/test_highlevel_ssl_helpers.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import partial -from typing import TYPE_CHECKING, NoReturn +from typing import TYPE_CHECKING, NoReturn, cast import attrs import pytest @@ -10,11 +10,13 @@ import trio.testing from trio.socket import AF_INET, IPPROTO_TCP, SOCK_STREAM +from .._highlevel_socket import SocketListener from .._highlevel_ssl_helpers import ( open_ssl_over_tcp_listeners, open_ssl_over_tcp_stream, serve_ssl_over_tcp, ) +from .._ssl import SSLListener # using noqa because linters don't understand how pytest fixtures work. from .test_ssl import SERVER_CTX, client_ctx # noqa: F401 @@ -25,9 +27,6 @@ from trio.abc import Stream - from .._highlevel_socket import SocketListener - from .._ssl import SSLListener - async def echo_handler(stream: Stream) -> None: async with stream: @@ -68,7 +67,8 @@ async def getaddrinfo( async def getnameinfo( self, - *args: tuple[str, int] | tuple[str, int, int, int] | int, + sockaddr: tuple[str, int] | tuple[str, int, int, int], + flags: int, ) -> NoReturn: # pragma: no cover raise NotImplementedError @@ -82,17 +82,17 @@ async def test_open_ssl_over_tcp_stream_and_everything_else( # TODO: this function wraps an SSLListener around a SocketListener, this is illegal # according to current type hints, and probably for good reason. But there should # maybe be a different wrapper class/function that could be used instead? - res: list[SSLListener[SocketListener]] = ( # type: ignore[type-var] - await nursery.start( - partial( - serve_ssl_over_tcp, - echo_handler, - 0, - SERVER_CTX, - host="127.0.0.1", - ), - ) + value = await nursery.start( + partial( + serve_ssl_over_tcp, + echo_handler, + 0, + SERVER_CTX, + host="127.0.0.1", + ), ) + assert isinstance(value, list) + res = cast(list[SSLListener[SocketListener]], value) # type: ignore[type-var] (listener,) = res async with listener: # listener.transport_listener is of type Listener[Stream] diff --git a/src/trio/_tests/test_socket.py b/src/trio/_tests/test_socket.py index 491476a51..3f68b285b 100644 --- a/src/trio/_tests/test_socket.py +++ b/src/trio/_tests/test_socket.py @@ -8,7 +8,7 @@ import tempfile from pathlib import Path from socket import AddressFamily, SocketKind -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Union, cast import attrs import pytest @@ -33,9 +33,18 @@ Union[tuple[str, int], tuple[str, int, int, int]], ] GetAddrInfoResponse: TypeAlias = list[GaiTuple] + GetAddrInfoArgs: TypeAlias = tuple[ + Union[str, bytes, None], + Union[str, bytes, int, None], + int, + int, + int, + int, + ] else: GaiTuple: object GetAddrInfoResponse = object + GetAddrInfoArgs = object ################################################################ # utils @@ -43,21 +52,32 @@ class MonkeypatchedGAI: - def __init__(self, orig_getaddrinfo: Callable[..., GetAddrInfoResponse]) -> None: + __slots__ = ("_orig_getaddrinfo", "_responses", "record") + + def __init__( + self, + orig_getaddrinfo: Callable[ + [str | bytes | None, str | bytes | int | None, int, int, int, int], + GetAddrInfoResponse, + ], + ) -> None: self._orig_getaddrinfo = orig_getaddrinfo - self._responses: dict[tuple[Any, ...], GetAddrInfoResponse | str] = {} - self.record: list[tuple[Any, ...]] = [] + self._responses: dict[ + GetAddrInfoArgs, + GetAddrInfoResponse | str, + ] = {} + self.record: list[GetAddrInfoArgs] = [] # get a normalized getaddrinfo argument tuple def _frozenbind( self, - host: bytes | str | None, - port: bytes | str | int | None, + host: str | bytes | None, + port: str | bytes | int | None, family: int = 0, type: int = 0, proto: int = 0, flags: int = 0, - ) -> tuple[Any, ...]: + ) -> GetAddrInfoArgs: sig = inspect.signature(self._orig_getaddrinfo) bound = sig.bind(host, port, family=family, type=type, proto=proto, flags=flags) bound.apply_defaults() @@ -68,8 +88,8 @@ def _frozenbind( def set( self, response: GetAddrInfoResponse | str, - host: bytes | str | None, - port: bytes | str | int | None, + host: str | bytes | None, + port: str | bytes | int | None, family: int = 0, type: int = 0, proto: int = 0, @@ -88,33 +108,19 @@ def set( def getaddrinfo( self, - host: bytes | str | None, - port: bytes | str | int | None, + host: str | bytes | None, + port: str | bytes | int | None, family: int = 0, type: int = 0, proto: int = 0, flags: int = 0, ) -> GetAddrInfoResponse | str: - bound = self._frozenbind( - host, - port, - family=family, - type=type, - proto=proto, - flags=flags, - ) + bound = self._frozenbind(host, port, family, type, proto, flags) self.record.append(bound) if bound in self._responses: return self._responses[bound] - elif bound[-1] & stdlib_socket.AI_NUMERICHOST: - return self._orig_getaddrinfo( - host, - port, - family=family, - type=type, - proto=proto, - flags=flags, - ) + elif flags & stdlib_socket.AI_NUMERICHOST: + return self._orig_getaddrinfo(host, port, family, type, proto, flags) else: raise RuntimeError(f"gai called with unexpected arguments {bound}") @@ -640,11 +646,13 @@ async def res( | tuple[str, str, int] | tuple[str, str, int, int] ), - ) -> AddressFormat: - return await sock._resolve_address_nocp( + ) -> tuple[str | int, ...]: + value = await sock._resolve_address_nocp( args, local=local, # noqa: B023 # local is not bound in function definition ) + assert isinstance(value, tuple) + return cast(tuple[Union[str, int], ...], value) assert_eq(await res((addrs.arbitrary, "http")), (addrs.arbitrary, 80)) if v6: @@ -839,7 +847,10 @@ async def test_SocketType_connect_paths() -> None: # nose -- and then swap it back out again before we hit # wait_socket_writable, which insists on a real socket. class CancelSocket(stdlib_socket.socket): - def connect(self, *args: AddressFormat) -> None: + def connect( + self, + address: AddressFormat, + ) -> None: # accessing private method only available in _SocketType assert isinstance(sock, _SocketType) @@ -849,7 +860,7 @@ def connect(self, *args: AddressFormat) -> None: self.family, self.type, ) - sock._sock.connect(*args) + sock._sock.connect(address) # If connect *doesn't* raise, then pretend it did raise BlockingIOError # pragma: no cover @@ -896,15 +907,17 @@ async def test_resolve_address_exception_in_connect_closes_socket() -> None: with tsocket.socket() as sock: async def _resolve_address_nocp( - self: _SocketType, - *args: AddressFormat, - **kwargs: bool, + address: AddressFormat, + *, + local: bool, ) -> None: + assert address == "" + assert not local cancel_scope.cancel() await _core.checkpoint() assert isinstance(sock, _SocketType) - sock._resolve_address_nocp = _resolve_address_nocp # type: ignore[method-assign, assignment] + sock._resolve_address_nocp = _resolve_address_nocp # type: ignore[method-assign] with assert_checkpoints(): with pytest.raises(_core.Cancelled): await sock.connect("") diff --git a/src/trio/_tests/test_ssl.py b/src/trio/_tests/test_ssl.py index 9c30d4fc7..2a16a0cd1 100644 --- a/src/trio/_tests/test_ssl.py +++ b/src/trio/_tests/test_ssl.py @@ -397,7 +397,8 @@ def virtual_ssl_echo_server( yield SSLStream(fakesock, client_ctx, server_hostname="trio-test-1.example.org") -def ssl_wrap_pair( +# Explicit "Any" is not allowed +def ssl_wrap_pair( # type: ignore[misc] client_ctx: SSLContext, client_transport: T_Stream, server_transport: T_Stream, diff --git a/src/trio/_tests/test_subprocess.py b/src/trio/_tests/test_subprocess.py index 1067c39a1..88623a430 100644 --- a/src/trio/_tests/test_subprocess.py +++ b/src/trio/_tests/test_subprocess.py @@ -108,7 +108,9 @@ async def run_process_in_nursery( ) -> AsyncIterator[Process]: async with _core.open_nursery() as nursery: kwargs.setdefault("check", False) - proc: Process = await nursery.start(partial(run_process, *args, **kwargs)) + value = await nursery.start(partial(run_process, *args, **kwargs)) + assert isinstance(value, Process) + proc: Process = value yield proc nursery.cancel_scope.cancel() @@ -119,7 +121,11 @@ async def run_process_in_nursery( ids=["open_process", "run_process in nursery"], ) -BackgroundProcessType: TypeAlias = Callable[..., AbstractAsyncContextManager[Process]] +# Explicit .../"Any" is not allowed +BackgroundProcessType: TypeAlias = Callable[ # type: ignore[misc] + ..., + AbstractAsyncContextManager[Process], +] @background_process_param @@ -636,7 +642,9 @@ async def test_warn_on_cancel_SIGKILL_escalation( async def test_run_process_background_fail() -> None: with RaisesGroup(subprocess.CalledProcessError): async with _core.open_nursery() as nursery: - proc: Process = await nursery.start(run_process, EXIT_FALSE) + value = await nursery.start(run_process, EXIT_FALSE) + assert isinstance(value, Process) + proc: Process = value assert proc.returncode == 1 diff --git a/src/trio/_tests/test_testing_raisesgroup.py b/src/trio/_tests/test_testing_raisesgroup.py index 17eb6afcc..7b2fe4417 100644 --- a/src/trio/_tests/test_testing_raisesgroup.py +++ b/src/trio/_tests/test_testing_raisesgroup.py @@ -3,7 +3,6 @@ import re import sys from types import TracebackType -from typing import Any import pytest @@ -235,7 +234,10 @@ def test_RaisesGroup_matches() -> None: def test_message() -> None: - def check_message(message: str, body: RaisesGroup[Any]) -> None: + def check_message( + message: str, + body: RaisesGroup[BaseException], + ) -> None: with pytest.raises( AssertionError, match=f"^DID NOT RAISE any exception, expected {re.escape(message)}$", diff --git a/src/trio/_tests/test_threads.py b/src/trio/_tests/test_threads.py index df9d2e74a..641454d2e 100644 --- a/src/trio/_tests/test_threads.py +++ b/src/trio/_tests/test_threads.py @@ -55,7 +55,8 @@ async def test_do_in_trio_thread() -> None: trio_thread = threading.current_thread() - async def check_case( + # Explicit "Any" is not allowed + async def check_case( # type: ignore[misc] do_in_trio_thread: Callable[..., threading.Thread], fn: Callable[..., T | Awaitable[T]], expected: tuple[str, T], diff --git a/src/trio/_tests/test_util.py b/src/trio/_tests/test_util.py index 7584c2e80..23d16fe84 100644 --- a/src/trio/_tests/test_util.py +++ b/src/trio/_tests/test_util.py @@ -7,7 +7,6 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator, Coroutine, Generator - import pytest import trio @@ -30,6 +29,9 @@ ) from ..testing import wait_all_tasks_blocked +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + T = TypeVar("T") diff --git a/src/trio/_threads.py b/src/trio/_threads.py index 4cd460078..7afd7b612 100644 --- a/src/trio/_threads.py +++ b/src/trio/_threads.py @@ -146,8 +146,9 @@ class ThreadPlaceholder: # Types for the to_thread_run_sync message loop @attrs.frozen(eq=False, slots=False) -class Run(Generic[RetT]): - afn: Callable[..., Awaitable[RetT]] +# Explicit .../"Any" is not allowed +class Run(Generic[RetT]): # type: ignore[misc] + afn: Callable[..., Awaitable[RetT]] # type: ignore[misc] args: tuple[object, ...] context: contextvars.Context = attrs.field( init=False, @@ -205,8 +206,9 @@ def in_trio_thread() -> None: @attrs.frozen(eq=False, slots=False) -class RunSync(Generic[RetT]): - fn: Callable[..., RetT] +# Explicit .../"Any" is not allowed +class RunSync(Generic[RetT]): # type: ignore[misc] + fn: Callable[..., RetT] # type: ignore[misc] args: tuple[object, ...] context: contextvars.Context = attrs.field( init=False, @@ -522,7 +524,8 @@ def _send_message_to_trio( return message_to_trio.queue.get().unwrap() -def from_thread_run( +# Explicit "Any" is not allowed +def from_thread_run( # type: ignore[misc] afn: Callable[..., Awaitable[RetT]], *args: object, trio_token: TrioToken | None = None, @@ -566,7 +569,8 @@ def from_thread_run( return _send_message_to_trio(trio_token, Run(afn, args)) -def from_thread_run_sync( +# Explicit "Any" is not allowed +def from_thread_run_sync( # type: ignore[misc] fn: Callable[..., RetT], *args: object, trio_token: TrioToken | None = None, diff --git a/src/trio/_util.py b/src/trio/_util.py index 641c22f3a..b354fbf5d 100644 --- a/src/trio/_util.py +++ b/src/trio/_util.py @@ -22,7 +22,8 @@ import trio -CallT = TypeVar("CallT", bound=Callable[..., Any]) +# Explicit "Any" is not allowed +CallT = TypeVar("CallT", bound=Callable[..., Any]) # type: ignore[misc] T = TypeVar("T") RetT = TypeVar("RetT") @@ -232,14 +233,16 @@ def __exit__( self._held = False -def async_wraps( +# Explicit "Any" is not allowed +def async_wraps( # type: ignore[misc] cls: type[object], wrapped_cls: type[object], attr_name: str, ) -> Callable[[CallT], CallT]: """Similar to wraps, but for async wrappers of non-async functions.""" - def decorator(func: CallT) -> CallT: + # Explicit "Any" is not allowed + def decorator(func: CallT) -> CallT: # type: ignore[misc] func.__name__ = attr_name func.__qualname__ = f"{cls.__qualname__}.{attr_name}" @@ -302,7 +305,11 @@ def open_memory_channel(max_buffer_size: int) -> Tuple[ but at least it becomes possible to write those. """ - def __init__(self, fn: Callable[..., RetT]) -> None: + # Explicit .../"Any" is not allowed + def __init__( # type: ignore[misc] + self, + fn: Callable[..., RetT], + ) -> None: update_wrapper(self, fn) self._fn = fn @@ -395,9 +402,11 @@ def name_asyncgen(agen: AsyncGeneratorType[object, NoReturn]) -> str: # work around a pyright error if TYPE_CHECKING: - Fn = TypeVar("Fn", bound=Callable[..., object]) + # Explicit .../"Any" is not allowed + Fn = TypeVar("Fn", bound=Callable[..., object]) # type: ignore[misc] - def wraps( + # Explicit .../"Any" is not allowed + def wraps( # type: ignore[misc] wrapped: Callable[..., object], assigned: Sequence[str] = ..., updated: Sequence[str] = ..., diff --git a/src/trio/testing/_fake_net.py b/src/trio/testing/_fake_net.py index c37c489bd..6ec7f68b6 100644 --- a/src/trio/testing/_fake_net.py +++ b/src/trio/testing/_fake_net.py @@ -359,7 +359,12 @@ async def _recvmsg_into( buffers: Iterable[Buffer], ancbufsize: int = 0, flags: int = 0, - ) -> tuple[int, list[tuple[int, int, bytes]], int, Any]: + ) -> tuple[ + int, + list[tuple[int, int, bytes]], + int, + tuple[str, int] | tuple[str, int, int, int], + ]: if ancbufsize != 0: raise NotImplementedError("FakeNet doesn't support ancillary data") if flags != 0: @@ -502,7 +507,11 @@ async def sendto( __address: tuple[object, ...] | str | None | Buffer, ) -> int: ... - async def sendto(self, *args: Any) -> int: + # Explicit "Any" is not allowed + async def sendto( # type: ignore[misc] + self, + *args: Any, + ) -> int: data: Buffer flags: int address: tuple[object, ...] | str | Buffer @@ -523,7 +532,11 @@ async def recv_into(self, buf: Buffer, nbytes: int = 0, flags: int = 0) -> int: got_bytes, _address = await self.recvfrom_into(buf, nbytes, flags) return got_bytes - async def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[bytes, Any]: + async def recvfrom( + self, + bufsize: int, + flags: int = 0, + ) -> tuple[bytes, AddressFormat]: data, _ancdata, _msg_flags, address = await self._recvmsg(bufsize, flags) return data, address @@ -532,7 +545,7 @@ async def recvfrom_into( buf: Buffer, nbytes: int = 0, flags: int = 0, - ) -> tuple[int, Any]: + ) -> tuple[int, AddressFormat]: if nbytes != 0 and nbytes != memoryview(buf).nbytes: raise NotImplementedError("partial recvfrom_into") got_nbytes, _ancdata, _msg_flags, address = await self._recvmsg_into( @@ -547,7 +560,7 @@ async def _recvmsg( bufsize: int, ancbufsize: int = 0, flags: int = 0, - ) -> tuple[bytes, list[tuple[int, int, bytes]], int, Any]: + ) -> tuple[bytes, list[tuple[int, int, bytes]], int, AddressFormat]: buf = bytearray(bufsize) got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into( [buf], diff --git a/test-requirements.in b/test-requirements.in index add7798d0..1c30d8982 100644 --- a/test-requirements.in +++ b/test-requirements.in @@ -11,7 +11,8 @@ cryptography>=41.0.0 # cryptography<41 segfaults on pypy3.10 # Tools black; implementation_name == "cpython" -mypy +mypy # Would use mypy[faster-cache], but orjson has build issues on pypy +orjson; implementation_name == "cpython" ruff >= 0.6.6 astor # code generation uv >= 0.2.24 diff --git a/test-requirements.txt b/test-requirements.txt index 23ed36462..9c0f87b82 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -80,6 +80,8 @@ mypy-extensions==1.0.0 # mypy nodeenv==1.9.1 # via pyright +orjson==3.10.10 ; implementation_name == 'cpython' + # via -r test-requirements.in outcome==1.3.0.post0 # via -r test-requirements.in packaging==24.1