Skip to content

Commit

Permalink
Removing loop name from asyncio.create_task (#489)
Browse files Browse the repository at this point in the history
* Update topic_writer_asyncio.py

* Update topic_reader_asyncio.py

* Update topic_writer_asyncio.py

* Update topic_reader_asyncio.py

fix linters

* Update topic_reader_asyncio.py

* Update topic_reader_asyncio.py

* Update topic_reader_asyncio.py

* Update topic_reader_asyncio.py

* add wrapper for asyncio.create_task

* fix linters

* fix linters

* fix linters

* fix linters

* fix linters

* fix linters

* fix linters

* split setting task name during function declaration stage

* fix tests

* fix tests

* fix linters

* fix linters

* fix linters

* fix tests
  • Loading branch information
alex2211-put committed Sep 26, 2024
1 parent b0cce07 commit c84538d
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 11 deletions.
13 changes: 13 additions & 0 deletions ydb/_topic_common/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import concurrent.futures
import sys
import threading
import typing
from typing import Optional
Expand Down Expand Up @@ -29,6 +30,18 @@ def wrapper(rpc_state, response_pb, driver=None):
return wrapper


if sys.hexversion < 0x03080000:

def wrap_set_name_for_asyncio_task(task: asyncio.Task, task_name: str) -> asyncio.Task:
return task

else:

def wrap_set_name_for_asyncio_task(task: asyncio.Task, task_name: str) -> asyncio.Task:
task.set_name(task_name)
return task


_shared_event_loop_lock = threading.Lock()
_shared_event_loop: Optional[asyncio.AbstractEventLoop] = None

Expand Down
15 changes: 14 additions & 1 deletion ydb/_topic_common/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import grpc
import pytest

from .common import CallFromSyncToAsync
from .common import CallFromSyncToAsync, wrap_set_name_for_asyncio_task
from .._grpc.grpcwrapper.common_utils import (
GrpcWrapperAsyncIO,
ServerStatus,
Expand Down Expand Up @@ -75,6 +75,19 @@ async def async_failed():
with pytest.raises(TestError):
await callback_from_asyncio(async_failed)

async def test_task_name_on_asyncio_task(self):
task_name = "asyncio task"
loop = asyncio.get_running_loop()

async def some_async_task():
await asyncio.sleep(0)
return 1

asyncio_task = loop.create_task(some_async_task())
wrap_set_name_for_asyncio_task(asyncio_task, task_name=task_name)

assert asyncio_task.get_name() == task_name


@pytest.mark.asyncio
class TestGrpcWrapperAsyncIO:
Expand Down
30 changes: 25 additions & 5 deletions ydb/_topic_reader/topic_reader_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import ydb
from .. import _apis, issues
from .._topic_common import common as topic_common
from .._utilities import AtomicCounter
from ..aio import Driver
from ..issues import Error as YdbError, _process_response
Expand Down Expand Up @@ -87,7 +88,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):

def __del__(self):
if not self._closed:
self._loop.create_task(self.close(flush=False), name="close reader")
task = self._loop.create_task(self.close(flush=False))
topic_common.wrap_set_name_for_asyncio_task(task, task_name="close reader")

async def wait_message(self):
"""
Expand Down Expand Up @@ -337,12 +339,30 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMess

self._update_token_event.set()

self._background_tasks.add(asyncio.create_task(self._read_messages_loop(), name="read_messages_loop"))
self._background_tasks.add(asyncio.create_task(self._decode_batches_loop(), name="decode_batches"))
self._background_tasks.add(
topic_common.wrap_set_name_for_asyncio_task(
asyncio.create_task(self._read_messages_loop()),
task_name="read_messages_loop",
),
)
self._background_tasks.add(
topic_common.wrap_set_name_for_asyncio_task(
asyncio.create_task(self._decode_batches_loop()),
task_name="decode_batches",
),
)
if self._get_token_function:
self._background_tasks.add(asyncio.create_task(self._update_token_loop(), name="update_token_loop"))
self._background_tasks.add(
topic_common.wrap_set_name_for_asyncio_task(
asyncio.create_task(self._update_token_loop()),
task_name="update_token_loop",
),
)
self._background_tasks.add(
asyncio.create_task(self._handle_background_errors(), name="handle_background_errors")
topic_common.wrap_set_name_for_asyncio_task(
asyncio.create_task(self._handle_background_errors()),
task_name="handle_background_errors",
),
)

async def wait_error(self):
Expand Down
26 changes: 21 additions & 5 deletions ydb/_topic_writer/topic_writer_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
issues,
)
from .._errors import check_retriable_error
from .._topic_common import common as topic_common
from ..retries import RetrySettings
from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec
from .._grpc.grpcwrapper.ydb_topic import (
Expand Down Expand Up @@ -231,8 +232,14 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
self._new_messages = asyncio.Queue()
self._stop_reason = self._loop.create_future()
self._background_tasks = [
asyncio.create_task(self._connection_loop(), name="connection_loop"),
asyncio.create_task(self._encode_loop(), name="encode_loop"),
topic_common.wrap_set_name_for_asyncio_task(
asyncio.create_task(self._connection_loop()),
task_name="connection_loop",
),
topic_common.wrap_set_name_for_asyncio_task(
asyncio.create_task(self._encode_loop()),
task_name="encode_loop",
),
]

self._state_changed = asyncio.Event()
Expand Down Expand Up @@ -366,8 +373,14 @@ async def _connection_loop(self):

self._stream_connected.set()

send_loop = asyncio.create_task(self._send_loop(stream_writer), name="writer send loop")
receive_loop = asyncio.create_task(self._read_loop(stream_writer), name="writer receive loop")
send_loop = topic_common.wrap_set_name_for_asyncio_task(
asyncio.create_task(self._send_loop(stream_writer)),
task_name="writer send loop",
)
receive_loop = topic_common.wrap_set_name_for_asyncio_task(
asyncio.create_task(self._read_loop(stream_writer)),
task_name="writer receive loop",
)

tasks = [send_loop, receive_loop]
done, _ = await asyncio.wait([send_loop, receive_loop], return_when=asyncio.FIRST_COMPLETED)
Expand Down Expand Up @@ -653,7 +666,10 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamWriteMes

if self._update_token_interval is not None:
self._update_token_event.set()
self._update_token_task = asyncio.create_task(self._update_token_loop(), name="update_token_loop")
self._update_token_task = topic_common.wrap_set_name_for_asyncio_task(
asyncio.create_task(self._update_token_loop()),
task_name="update_token_loop",
)

@staticmethod
def _ensure_ok(message: WriterMessagesFromServerToClient):
Expand Down

0 comments on commit c84538d

Please sign in to comment.