From e0f4e4bf677a8bb45617ce060d7b6ca42c69e70c Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 25 Sep 2024 17:30:40 +0300 Subject: [PATCH] Implement max_messages on receive_batch --- ydb/_topic_reader/topic_reader_asyncio.py | 46 +++++++++++++++++++---- ydb/_topic_reader/topic_reader_sync.py | 4 +- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 0a24a86b..d4ffbdb3 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -97,6 +97,7 @@ async def wait_message(self): async def receive_batch( self, + max_messages: typing.Union[int, None] = None, ) -> typing.Union[datatypes.PublicBatch, None]: """ Get one messages batch from reader. @@ -105,7 +106,9 @@ async def receive_batch( use asyncio.wait_for for wait with timeout. """ await self._reconnector.wait_message() - return self._reconnector.receive_batch_nowait() + return self._reconnector.receive_batch_nowait( + max_messages=max_messages, + ) async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]: """ @@ -212,8 +215,10 @@ async def wait_message(self): await self._state_changed.wait() self._state_changed.clear() - def receive_batch_nowait(self): - return self._stream_reader.receive_batch_nowait() + def receive_batch_nowait(self, max_messages: Optional[int] = None): + return self._stream_reader.receive_batch_nowait( + max_messages=max_messages, + ) def receive_message_nowait(self): return self._stream_reader.receive_message_nowait() @@ -363,17 +368,44 @@ def _get_first_batch(self) -> typing.Tuple[int, datatypes.PublicBatch]: first_id, batch = self._message_batches.popitem(last=False) return first_id, batch - def receive_batch_nowait(self): + def _cut_batch_by_max_messages( + batch: datatypes.PublicBatch, + max_messages: int, + ) -> typing.Tuple[datatypes.PublicBatch, datatypes.PublicBatch]: + initial_length = len(batch.messages) + 
one_message_size = batch._bytes_size // initial_length + + new_batch = datatypes.PublicBatch( + messages=batch.messages[:max_messages], + _partition_session=batch._partition_session, + _bytes_size=one_message_size*max_messages, + _codec=batch._codec, + ) + + batch.messages = batch.messages[max_messages:] + batch._bytes_size = one_message_size * (initial_length - max_messages) + + return new_batch, batch + + def receive_batch_nowait(self, max_messages: Optional[int] = None): if self._get_first_error(): raise self._get_first_error() if not self._message_batches: return None - _, batch = self._get_first_batch() - self._buffer_release_bytes(batch._bytes_size) + part_sess_id, batch = self._get_first_batch() + + if max_messages is None or len(batch.messages) <= max_messages: + self._buffer_release_bytes(batch._bytes_size) + return batch + + cutted_batch, remaining_batch = self._cut_batch_by_max_messages(batch, max_messages) + + self._message_batches[part_sess_id] = remaining_batch + self._buffer_release_bytes(cutted_batch._bytes_size) - return batch + return cutted_batch def receive_message_nowait(self): if self._get_first_error(): diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py index c266de82..3048d3c4 100644 --- a/ydb/_topic_reader/topic_reader_sync.py +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -103,7 +103,9 @@ def receive_batch( self._check_closed() return self._caller.safe_call_with_result( - self._async_reader.receive_batch(), + self._async_reader.receive_batch( + max_messages=max_messages, + ), timeout, )