diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py
index dfc3cdd0..1c47893f 100644
--- a/tests/topics/test_topic_writer.py
+++ b/tests/topics/test_topic_writer.py
@@ -7,6 +7,8 @@
 
 import ydb.aio
 
+from ydb._topic_common.test_helpers import wait_condition
+
 
 @pytest.mark.asyncio
 class TestTopicWriterAsyncIO:
@@ -43,6 +45,10 @@ async def test_random_producer_id(self, driver: ydb.aio.Driver, topic_path, topi
 
         batch = await topic_reader.receive_batch()
 
+        if len(batch.messages) == 1:
+            batch2 = await topic_reader.receive_batch()
+            batch.messages.extend(batch2.messages)
+
         assert batch.messages[0].producer_id != batch.messages[1].producer_id
 
     async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path):
@@ -201,14 +207,14 @@ def test_write_multi_message_with_ack(
         )
 
         batch = topic_reader_sync.receive_batch()
+        if len(batch.messages) == 1:
+            batch2 = topic_reader_sync.receive_batch()
+            batch.messages.extend(batch2.messages)
 
         assert batch.messages[0].offset == 0
         assert batch.messages[0].seqno == 1
         assert batch.messages[0].data == "123".encode()
 
-        # remove second recieve batch when implement batching
-        # https://github.com/ydb-platform/ydb-python-sdk/issues/142
-        # batch = topic_reader_sync.receive_batch()
         assert batch.messages[1].offset == 1
         assert batch.messages[1].seqno == 2
         assert batch.messages[1].data == "456".encode()
diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py
index 5004d25d..5dcae1ee 100644
--- a/ydb/_topic_reader/topic_reader_asyncio.py
+++ b/ydb/_topic_reader/topic_reader_asyncio.py
@@ -6,6 +6,7 @@
 import random
 import typing
 from asyncio import Task
+from collections import OrderedDict
 from typing import Optional, Set, Dict, Union, Callable
 
 import ydb
@@ -296,7 +297,7 @@ def __init__(
         self._closed = False
         self._first_error = asyncio.get_running_loop().create_future()
         self._batches_to_decode = asyncio.Queue()
-        self._message_batches = dict()
+        self._message_batches = OrderedDict()
 
         self._update_token_interval = settings.update_token_interval
         self._get_token_function = get_token_function
@@ -359,9 +360,9 @@ async def wait_messages(self):
             await self._state_changed.wait()
             self._state_changed.clear()
 
-    def _get_random_batch(self):
-        rnd_id = random.choice(list(self._message_batches.keys()))
-        return rnd_id, self._message_batches[rnd_id]
+    def _get_first_batch(self) -> typing.Tuple[int, datatypes.PublicBatch]:
+        first_id, batch = self._message_batches.popitem(last = False)
+        return first_id, batch
 
     def receive_batch_nowait(self):
         if self._get_first_error():
@@ -370,9 +371,8 @@ def receive_batch_nowait(self):
         if not self._message_batches:
             return None
 
-        part_sess_id, batch = self._get_random_batch()
+        _, batch = self._get_first_batch()
         self._buffer_release_bytes(batch._bytes_size)
-        del self._message_batches[part_sess_id]
 
         return batch
 
@@ -383,12 +383,15 @@ def receive_message_nowait(self):
         if not self._message_batches:
             return None
 
-        part_sess_id, batch = self._get_random_batch()
+        part_sess_id, batch = self._get_first_batch()
         message = batch.messages.pop(0)
+
         if len(batch.messages) == 0:
             self._buffer_release_bytes(batch._bytes_size)
-            del self._message_batches[part_sess_id]
+        else:
+            # TODO: we should somehow release bytes from single message as well
+            self._message_batches[part_sess_id] = batch
 
         return message
 
diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py
index 9af91b1b..f74b7d7e 100644
--- a/ydb/_topic_reader/topic_reader_asyncio_test.py
+++ b/ydb/_topic_reader/topic_reader_asyncio_test.py
@@ -4,7 +4,7 @@
 import datetime
 import gzip
 import typing
-from collections import deque
+from collections import OrderedDict
 from dataclasses import dataclass
 from unittest import mock
 
@@ -52,9 +52,9 @@ def default_executor():
     executor.shutdown()
 
 
-def stub_partition_session():
+def stub_partition_session(id: int = 0):
     return datatypes.PartitionSession(
-        id=0,
+        id=id,
         state=datatypes.PartitionSession.State.Active,
         topic_path="asd",
         partition_id=1,
@@ -212,10 +212,10 @@ def create_message(
            _commit_end_offset=partition_session._next_message_start_commit_offset + offset_delta,
         )
 
-    async def send_message(self, stream_reader, message: PublicMessage):
-        await self.send_batch(stream_reader, [message])
+    async def send_message(self, stream_reader, message: PublicMessage, new_batch=True):
+        await self.send_batch(stream_reader, [message], new_batch=new_batch)
 
-    async def send_batch(self, stream_reader, batch: typing.List[PublicMessage]):
+    async def send_batch(self, stream_reader, batch: typing.List[PublicMessage], new_batch=True):
         if len(batch) == 0:
             return
 
@@ -223,10 +223,16 @@ async def send_batch(self, stream_reader, batch: typing.List[PublicMessage]):
         for message in batch:
             assert message._partition_session is first_message._partition_session
 
+        partition_session_id = first_message._partition_session.id
+
         def batch_count():
             return len(stream_reader._message_batches)
 
+        def batch_size():
+            return len(stream_reader._message_batches[partition_session_id].messages)
+
         initial_batches = batch_count()
+        initial_batch_size = batch_size() if not new_batch else 0
 
         stream = stream_reader._stream # type: StreamMock
         stream.from_server.put_nowait(
@@ -261,7 +267,10 @@ def batch_count():
                 ),
             )
         )
-        await wait_condition(lambda: batch_count() > initial_batches)
+        if new_batch:
+            await wait_condition(lambda: batch_count() > initial_batches)
+        else:
+            await wait_condition(lambda: batch_size() > initial_batch_size)
 
     async def test_unknown_error(self, stream, stream_reader_finish_with_error):
         class TestError(Exception):
@@ -412,15 +421,11 @@ async def test_commit_ranges_for_received_messages(
         m2._commit_start_offset = m1.offset + 1
 
         await self.send_message(stream_reader_started, m1)
-        await self.send_message(stream_reader_started, m2)
-
-        await stream_reader_started.wait_messages()
-        received = stream_reader_started.receive_batch_nowait().messages
-        assert received == [m1]
+        await self.send_message(stream_reader_started, m2, new_batch=False)
 
         await stream_reader_started.wait_messages()
         received = stream_reader_started.receive_batch_nowait().messages
-        assert received == [m2]
+        assert received == [m1, m2]
 
         await stream_reader_started.close(False)
 
@@ -860,7 +865,7 @@ def reader_batch_count():
 
         assert stream_reader._buffer_size_bytes == initial_buffer_size - bytes_size
 
-        last_batch = stream_reader._message_batches[-1]
+        _, last_batch = stream_reader._message_batches.popitem()
         assert last_batch == PublicBatch(
             messages=[
                 PublicMessage(
@@ -1059,74 +1064,74 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti
     @pytest.mark.parametrize(
         "batches_before,expected_message,batches_after",
         [
-            ([], None, []),
+            ({}, None, {}),
             (
-                [
-                    PublicBatch(
+                {
+                    0: PublicBatch(
                         messages=[stub_message(1)],
                         _partition_session=stub_partition_session(),
                         _bytes_size=0,
                         _codec=Codec.CODEC_RAW,
                     )
-                ],
+                },
                 stub_message(1),
-                [],
+                {},
             ),
             (
-                [
-                    PublicBatch(
+                {
+                    0: PublicBatch(
                         messages=[stub_message(1), stub_message(2)],
                         _partition_session=stub_partition_session(),
                         _bytes_size=0,
                         _codec=Codec.CODEC_RAW,
                     ),
-                    PublicBatch(
+                    1: PublicBatch(
                         messages=[stub_message(3), stub_message(4)],
-                        _partition_session=stub_partition_session(),
+                        _partition_session=stub_partition_session(1),
                         _bytes_size=0,
                         _codec=Codec.CODEC_RAW,
                     ),
-                ],
+                },
                 stub_message(1),
-                [
-                    PublicBatch(
+                {
+                    0: PublicBatch(
                         messages=[stub_message(2)],
                         _partition_session=stub_partition_session(),
                         _bytes_size=0,
                         _codec=Codec.CODEC_RAW,
                     ),
-                    PublicBatch(
+                    1: PublicBatch(
                         messages=[stub_message(3), stub_message(4)],
-                        _partition_session=stub_partition_session(),
+                        _partition_session=stub_partition_session(1),
                         _bytes_size=0,
                         _codec=Codec.CODEC_RAW,
                     ),
-                ],
+                },
             ),
             (
-                [
-                    PublicBatch(
+                {
+                    0: PublicBatch(
                         messages=[stub_message(1)],
                         _partition_session=stub_partition_session(),
                         _bytes_size=0,
                         _codec=Codec.CODEC_RAW,
                     ),
-                    PublicBatch(
+                    1: PublicBatch(
                         messages=[stub_message(2), stub_message(3)],
-                        _partition_session=stub_partition_session(),
+                        _partition_session=stub_partition_session(1),
                         _bytes_size=0,
                         _codec=Codec.CODEC_RAW,
                     ),
-                ],
+                },
                 stub_message(1),
-                [
-                    PublicBatch(
+                {
+                    1: PublicBatch(
                         messages=[stub_message(2), stub_message(3)],
-                        _partition_session=stub_partition_session(),
+                        _partition_session=stub_partition_session(1),
                         _bytes_size=0,
                         _codec=Codec.CODEC_RAW,
                     )
-                ],
+                },
             ),
         ],
     )
@@ -1137,11 +1142,11 @@ async def test_read_message(
         expected_message: PublicMessage,
         batches_after: typing.List[datatypes.PublicBatch],
     ):
-        stream_reader._message_batches = deque(batches_before)
+        stream_reader._message_batches = OrderedDict(batches_before)
         mess = stream_reader.receive_message_nowait()
 
         assert mess == expected_message
-        assert list(stream_reader._message_batches) == batches_after
+        assert dict(stream_reader._message_batches) == batches_after
 
     async def test_receive_batch_nowait(self, stream, stream_reader, partition_session):
         assert stream_reader.receive_batch_nowait() is None
@@ -1152,30 +1157,21 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi
         await self.send_message(stream_reader, mess1)
 
         mess2 = self.create_message(partition_session, 2, 1)
-        await self.send_message(stream_reader, mess2)
+        await self.send_message(stream_reader, mess2, new_batch=False)
 
         assert stream_reader._buffer_size_bytes == initial_buffer_size - 2 * self.default_batch_size
 
         received = stream_reader.receive_batch_nowait()
         assert received == PublicBatch(
-            messages=[mess1],
+            messages=[mess1, mess2],
             _partition_session=mess1._partition_session,
-            _bytes_size=self.default_batch_size,
-            _codec=Codec.CODEC_RAW,
-        )
-
-        received = stream_reader.receive_batch_nowait()
-        assert received == PublicBatch(
-            messages=[mess2],
-            _partition_session=mess2._partition_session,
-            _bytes_size=self.default_batch_size,
+            _bytes_size=self.default_batch_size * 2,
             _codec=Codec.CODEC_RAW,
         )
 
         assert stream_reader._buffer_size_bytes == initial_buffer_size
 
-        assert StreamReadMessage.ReadRequest(self.default_batch_size) == stream.from_client.get_nowait().client_message
-        assert StreamReadMessage.ReadRequest(self.default_batch_size) == stream.from_client.get_nowait().client_message
+        assert StreamReadMessage.ReadRequest(self.default_batch_size * 2) == stream.from_client.get_nowait().client_message
 
         with pytest.raises(asyncio.QueueEmpty):
             stream.from_client.get_nowait()
@@ -1186,13 +1182,18 @@ async def test_receive_message_nowait(self, stream, stream_reader, partition_ses
         initial_buffer_size = stream_reader._buffer_size_bytes
 
         await self.send_batch(
-            stream_reader, [self.create_message(partition_session, 1, 1), self.create_message(partition_session, 2, 1)]
+            stream_reader,
+            [
+                self.create_message(partition_session, 1, 1),
+                self.create_message(partition_session, 2, 1),
+            ],
         )
         await self.send_batch(
             stream_reader,
             [
                 self.create_message(partition_session, 10, 1),
             ],
+            new_batch=False,
         )
 
         assert stream_reader._buffer_size_bytes == initial_buffer_size - 2 * self.default_batch_size