Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent asyncio.wait_for swallowing task cancellation #1698

Merged
merged 3 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion asyncua/client/ha/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from itertools import chain, islice

from asyncua.common.utils import wait_for

_logger = logging.getLogger(__name__)

Expand All @@ -15,7 +16,7 @@ class ClientNotFound(Exception):

async def event_wait(evt, timeout) -> bool:
try:
await asyncio.wait_for(evt.wait(), timeout)
await wait_for(evt.wait(), timeout)
except asyncio.TimeoutError:
pass
return evt.is_set()
Expand Down
7 changes: 4 additions & 3 deletions asyncua/client/ua_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from asyncua import ua
from asyncua.common.session_interface import AbstractSession
from ..common.utils import wait_for
from ..ua.ua_binary import struct_from_binary, uatcp_to_binary, struct_to_binary, nodeid_from_binary, header_from_binary
from ..ua.uaerrors import BadTimeout, BadNoSubscription, BadSessionClosed, BadUserAccessDenied, UaStructParsingError
from ..ua.uaprotocol_auto import OpenSecureChannelResult, SubscriptionAcknowledgement
Expand Down Expand Up @@ -165,7 +166,7 @@ async def send_request(self, request, timeout: Optional[float] = None, message_t
# time out then.
await self.pre_request_hook()
try:
data = await asyncio.wait_for(self._send_request(request, timeout, message_type), timeout if timeout else None)
data = await wait_for(self._send_request(request, timeout, message_type), timeout if timeout else None)
except Exception:
if self.state != self.OPEN:
raise ConnectionError("Connection is closed") from None
Expand Down Expand Up @@ -221,7 +222,7 @@ async def send_hello(self, url, max_messagesize: int = 0, max_chunkcount: int =
self._callbackmap[0] = ack
if self.transport is not None:
self.transport.write(uatcp_to_binary(ua.MessageType.Hello, hello))
return await asyncio.wait_for(ack, self.timeout)
return await wait_for(ack, self.timeout)

async def open_secure_channel(self, params) -> OpenSecureChannelResult:
self.logger.info("open_secure_channel")
Expand All @@ -230,7 +231,7 @@ async def open_secure_channel(self, params) -> OpenSecureChannelResult:
if self._open_secure_channel_exchange is not None:
raise RuntimeError('Two Open Secure Channel requests can not happen too close to each other. ' 'The response must be processed and returned before the next request can be sent.')
self._open_secure_channel_exchange = params
await asyncio.wait_for(self._send_request(request, message_type=ua.MessageType.SecureOpen), self.timeout)
await wait_for(self._send_request(request, message_type=ua.MessageType.SecureOpen), self.timeout)
_return = self._open_secure_channel_exchange.Parameters # type: ignore[union-attr]
self._open_secure_channel_exchange = None
return _return
Expand Down
26 changes: 23 additions & 3 deletions asyncua/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
Helper function and classes that do not rely on asyncua library.
Helper function and classes depending on ua object are in ua_utils.py
"""

import os
import asyncio
import logging
import os
import sys
from dataclasses import Field, fields
from typing import get_type_hints, Dict, Tuple, Any, Optional
from typing import Any, Awaitable, Dict, get_type_hints, Optional, Tuple, TypeVar, Union

from ..ua.uaerrors import UaError

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -132,3 +133,22 @@ def fields_with_resolved_types(
pass

return fields_


_T = TypeVar('_T')


async def wait_for(aw: Awaitable[_T], timeout: Union[int, float, None]) -> _T:
"""
Wrapped version of asyncio.wait_for that does not swallow cancellations

There is a bug in asyncio.wait_for before Python version 3.12 that prevents the inner awaitable from being cancelled
when the task is cancelled from the outside.

See https://github.com/python/cpython/issues/87555 and https://github.com/python/cpython/issues/86296
"""
if sys.version_info >= (3, 12):
return await asyncio.wait_for(aw, timeout)

import wait_for2
return await wait_for2.wait_for(aw, timeout)
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ max-line-length = 160
disable_error_code = misc, arg-type, assignment, var-annotated
show_error_codes = True
check_untyped_defs = False
mypy_path = ./stubs
[mypy-asyncua.ua.uaprotocol_auto.*]
# Autogenerated file
disable_error_code = literal-required
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
author="Olivier Roulet-Dubonnet",
author_email="[email protected]",
url='http://freeopcua.github.io/',
packages=find_packages(exclude=["tests"]),
packages=find_packages(exclude=["tests", "stubs"]),
provides=["asyncua"],
license="GNU Lesser General Public License v3 or later",
install_requires=["aiofiles", "aiosqlite", "python-dateutil", "pytz", "cryptography>42.0.0", "sortedcontainers", "importlib-metadata;python_version<'3.8'", "pyOpenSSL>23.2.0", "typing-extensions"],
install_requires=["aiofiles", "aiosqlite", "python-dateutil", "pytz", "cryptography>42.0.0", "sortedcontainers", "importlib-metadata;python_version<'3.8'", "pyOpenSSL>23.2.0", "typing-extensions", 'wait_for2==0.3.2'],
classifiers=[
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
Expand Down
14 changes: 14 additions & 0 deletions stubs/wait_for2/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import asyncio
from typing import Any, Awaitable, Callable, TypeVar, Union

_T = TypeVar('_T')


async def wait_for(
fut: Awaitable[_T],
timeout: Union[int, float, None],
*,
loop: asyncio.AbstractEventLoop = None,
race_handler: Callable[[Union[_T, BaseException], bool], Any] = None,
):
...