Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement max_messages on receive_batch #494

Merged
merged 7 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions ydb/_topic_reader/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,23 @@ def _extend(self, batch: PublicBatch) -> None:
def _pop(self) -> Tuple[List[PublicMessage], bool]:
msgs_left = True if len(self.messages) > 1 else False
return self.messages.pop(0), msgs_left

def _pop_batch(self, message_count: int) -> PublicBatch:
initial_length = len(self.messages)

if message_count >= initial_length:
raise ValueError("Pop batch with size >= actual size is not supported.")

one_message_size = self._bytes_size // initial_length

new_batch = PublicBatch(
messages=self.messages[:message_count],
_partition_session=self._partition_session,
_bytes_size=one_message_size * message_count,
_codec=self._codec,
)

self.messages = self.messages[message_count:]
self._bytes_size = self._bytes_size - new_batch._bytes_size

return new_batch
27 changes: 20 additions & 7 deletions ydb/_topic_reader/topic_reader_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ async def wait_message(self):

async def receive_batch(
self,
max_messages: typing.Union[int, None] = None,
) -> typing.Union[datatypes.PublicBatch, None]:
"""
Get one messages batch from reader.
Expand All @@ -107,7 +108,9 @@ async def receive_batch(
use asyncio.wait_for for wait with timeout.
"""
await self._reconnector.wait_message()
return self._reconnector.receive_batch_nowait()
return self._reconnector.receive_batch_nowait(
max_messages=max_messages,
)

async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]:
"""
Expand Down Expand Up @@ -214,8 +217,10 @@ async def wait_message(self):
await self._state_changed.wait()
self._state_changed.clear()

def receive_batch_nowait(self):
return self._stream_reader.receive_batch_nowait()
def receive_batch_nowait(self, max_messages: Optional[int] = None):
return self._stream_reader.receive_batch_nowait(
max_messages=max_messages,
)

def receive_message_nowait(self):
return self._stream_reader.receive_message_nowait()
Expand Down Expand Up @@ -383,17 +388,25 @@ def _get_first_batch(self) -> typing.Tuple[int, datatypes.PublicBatch]:
partition_session_id, batch = self._message_batches.popitem(last=False)
return partition_session_id, batch

def receive_batch_nowait(self):
def receive_batch_nowait(self, max_messages: Optional[int] = None):
if self._get_first_error():
raise self._get_first_error()

if not self._message_batches:
return None

_, batch = self._get_first_batch()
self._buffer_release_bytes(batch._bytes_size)
part_sess_id, batch = self._get_first_batch()

if max_messages is None or len(batch.messages) <= max_messages:
self._buffer_release_bytes(batch._bytes_size)
return batch

cutted_batch = batch._pop_batch(message_count=max_messages)

self._message_batches[part_sess_id] = batch
self._buffer_release_bytes(cutted_batch._bytes_size)

return batch
return cutted_batch

def receive_message_nowait(self):
if self._get_first_error():
Expand Down
90 changes: 90 additions & 0 deletions ydb/_topic_reader/topic_reader_asyncio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,6 +1148,96 @@ async def test_read_message(
assert mess == expected_message
assert dict(stream_reader._message_batches) == batches_after

@pytest.mark.parametrize(
"batches_before,max_messages,actual_messages,batches_after",
[
(
{
0: PublicBatch(
messages=[stub_message(1)],
_partition_session=stub_partition_session(),
_bytes_size=4,
_codec=Codec.CODEC_RAW,
)
},
None,
1,
{},
),
(
{
0: PublicBatch(
messages=[stub_message(1), stub_message(2)],
_partition_session=stub_partition_session(),
_bytes_size=4,
_codec=Codec.CODEC_RAW,
),
1: PublicBatch(
messages=[stub_message(3), stub_message(4)],
_partition_session=stub_partition_session(1),
_bytes_size=4,
_codec=Codec.CODEC_RAW,
),
},
1,
1,
{
1: PublicBatch(
messages=[stub_message(3), stub_message(4)],
_partition_session=stub_partition_session(1),
_bytes_size=4,
_codec=Codec.CODEC_RAW,
),
0: PublicBatch(
messages=[stub_message(2)],
_partition_session=stub_partition_session(),
_bytes_size=2,
_codec=Codec.CODEC_RAW,
),
},
),
(
{
0: PublicBatch(
messages=[stub_message(1)],
_partition_session=stub_partition_session(),
_bytes_size=4,
_codec=Codec.CODEC_RAW,
),
1: PublicBatch(
messages=[stub_message(2), stub_message(3)],
_partition_session=stub_partition_session(1),
_bytes_size=4,
_codec=Codec.CODEC_RAW,
),
},
100,
1,
{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • tests for bytes size split

1: PublicBatch(
messages=[stub_message(2), stub_message(3)],
_partition_session=stub_partition_session(1),
_bytes_size=4,
_codec=Codec.CODEC_RAW,
)
},
),
],
)
async def test_read_batch_max_messages(
    self,
    stream_reader,
    batches_before: typing.Dict[int, datatypes.PublicBatch],
    max_messages: typing.Optional[int],
    actual_messages: int,
    batches_after: typing.Dict[int, datatypes.PublicBatch],
):
    """
    Check that receive_batch_nowait honours ``max_messages``.

    ``batches_before`` and ``batches_after`` map a partition session id to
    the batch queued for it; their insertion order matters because the
    reader state is compared as an OrderedDict.
    """
    stream_reader._message_batches = OrderedDict(batches_before)
    batch = stream_reader.receive_batch_nowait(max_messages=max_messages)

    # Every parametrized case queues at least one batch, so None is a failure.
    assert batch is not None
    assert len(batch.messages) == actual_messages
    # OrderedDict equality is order-sensitive, so this also checks queue order.
    assert stream_reader._message_batches == OrderedDict(batches_after)

async def test_receive_batch_nowait(self, stream, stream_reader, partition_session):
assert stream_reader.receive_batch_nowait() is None

Expand Down
4 changes: 3 additions & 1 deletion ydb/_topic_reader/topic_reader_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def receive_batch(
self._check_closed()

return self._caller.safe_call_with_result(
self._async_reader.receive_batch(),
self._async_reader.receive_batch(
max_messages=max_messages,
),
timeout,
)

Expand Down
Loading