diff --git a/ydb/_topic_common/common.py b/ydb/_topic_common/common.py index 7a97336e..8241dda4 100644 --- a/ydb/_topic_common/common.py +++ b/ydb/_topic_common/common.py @@ -1,5 +1,6 @@ import asyncio import concurrent.futures +import sys import threading import typing from typing import Optional @@ -29,6 +30,18 @@ def wrapper(rpc_state, response_pb, driver=None): return wrapper +if sys.hexversion < 0x03080000: + + def wrap_set_name_for_asyncio_task(task: asyncio.Task, task_name: str) -> asyncio.Task: + return task + +else: + + def wrap_set_name_for_asyncio_task(task: asyncio.Task, task_name: str) -> asyncio.Task: + task.set_name(task_name) + return task + + _shared_event_loop_lock = threading.Lock() _shared_event_loop: Optional[asyncio.AbstractEventLoop] = None diff --git a/ydb/_topic_common/common_test.py b/ydb/_topic_common/common_test.py index b31f9af9..32261520 100644 --- a/ydb/_topic_common/common_test.py +++ b/ydb/_topic_common/common_test.py @@ -6,7 +6,7 @@ import grpc import pytest -from .common import CallFromSyncToAsync +from .common import CallFromSyncToAsync, wrap_set_name_for_asyncio_task from .._grpc.grpcwrapper.common_utils import ( GrpcWrapperAsyncIO, ServerStatus, @@ -75,6 +75,19 @@ async def async_failed(): with pytest.raises(TestError): await callback_from_asyncio(async_failed) + async def test_task_name_on_asyncio_task(self): + task_name = "asyncio task" + loop = asyncio.get_running_loop() + + async def some_async_task(): + await asyncio.sleep(0) + return 1 + + asyncio_task = loop.create_task(some_async_task()) + wrap_set_name_for_asyncio_task(asyncio_task, task_name=task_name) + + assert asyncio_task.get_name() == task_name + @pytest.mark.asyncio class TestGrpcWrapperAsyncIO: diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 81c6d9f4..752e0a1f 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -10,6 +10,7 @@ import ydb from .. import _apis, issues +from .._topic_common import common as topic_common from .._utilities import AtomicCounter from ..aio import Driver from ..issues import Error as YdbError, _process_response @@ -87,7 +88,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): def __del__(self): if not self._closed: - self._loop.create_task(self.close(flush=False), name="close reader") + task = self._loop.create_task(self.close(flush=False)) + topic_common.wrap_set_name_for_asyncio_task(task, task_name="close reader") async def wait_message(self): """ @@ -337,12 +339,30 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMess self._update_token_event.set() - self._background_tasks.add(asyncio.create_task(self._read_messages_loop(), name="read_messages_loop")) - self._background_tasks.add(asyncio.create_task(self._decode_batches_loop(), name="decode_batches")) + self._background_tasks.add( + topic_common.wrap_set_name_for_asyncio_task( + asyncio.create_task(self._read_messages_loop()), + task_name="read_messages_loop", + ), + ) + self._background_tasks.add( + topic_common.wrap_set_name_for_asyncio_task( + asyncio.create_task(self._decode_batches_loop()), + task_name="decode_batches", + ), + ) if self._get_token_function: - self._background_tasks.add(asyncio.create_task(self._update_token_loop(), name="update_token_loop")) + self._background_tasks.add( + topic_common.wrap_set_name_for_asyncio_task( + asyncio.create_task(self._update_token_loop()), + task_name="update_token_loop", + ), + ) self._background_tasks.add( - asyncio.create_task(self._handle_background_errors(), name="handle_background_errors") + topic_common.wrap_set_name_for_asyncio_task( + asyncio.create_task(self._handle_background_errors()), + task_name="handle_background_errors", + ), ) async def wait_error(self): diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 585e88ab..c7f88a42 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -28,6 +28,7 @@ issues, ) from .._errors import check_retriable_error +from .._topic_common import common as topic_common from ..retries import RetrySettings from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .._grpc.grpcwrapper.ydb_topic import ( @@ -231,8 +232,14 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): self._new_messages = asyncio.Queue() self._stop_reason = self._loop.create_future() self._background_tasks = [ - asyncio.create_task(self._connection_loop(), name="connection_loop"), - asyncio.create_task(self._encode_loop(), name="encode_loop"), + topic_common.wrap_set_name_for_asyncio_task( + asyncio.create_task(self._connection_loop()), + task_name="connection_loop", + ), + topic_common.wrap_set_name_for_asyncio_task( + asyncio.create_task(self._encode_loop()), + task_name="encode_loop", + ), ] self._state_changed = asyncio.Event() @@ -366,8 +373,14 @@ async def _connection_loop(self): self._stream_connected.set() - send_loop = asyncio.create_task(self._send_loop(stream_writer), name="writer send loop") - receive_loop = asyncio.create_task(self._read_loop(stream_writer), name="writer receive loop") + send_loop = topic_common.wrap_set_name_for_asyncio_task( + asyncio.create_task(self._send_loop(stream_writer)), + task_name="writer send loop", + ) + receive_loop = topic_common.wrap_set_name_for_asyncio_task( + asyncio.create_task(self._read_loop(stream_writer)), + task_name="writer receive loop", + ) tasks = [send_loop, receive_loop] done, _ = await asyncio.wait([send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED) @@ -653,7 +666,10 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMes if self._update_token_interval is not None: self._update_token_event.set() - self._update_token_task = asyncio.create_task(self._update_token_loop(), name="update_token_loop") + self._update_token_task = topic_common.wrap_set_name_for_asyncio_task( + asyncio.create_task(self._update_token_loop()), + task_name="update_token_loop", + ) @staticmethod def _ensure_ok(message: WriterMessagesFromServerToClient):