diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py
index 362be059..74c8bccd 100644
--- a/tests/topics/test_topic_reader.py
+++ b/tests/topics/test_topic_reader.py
@@ -40,15 +40,14 @@ async def test_read_and_commit_with_close_reader(self, driver, topic_with_messag
         assert message != message2
 
     async def test_read_and_commit_with_ack(self, driver, topic_with_messages, topic_consumer):
-        reader = driver.topic_client.reader(topic_with_messages, topic_consumer)
-        batch = await reader.receive_batch()
-        await reader.commit_with_ack(batch)
+        async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
+            message = await reader.receive_message()
+            await reader.commit_with_ack(message)
 
-        reader = driver.topic_client.reader(topic_with_messages, topic_consumer)
-        batch2 = await reader.receive_batch()
-        assert batch.messages[0] != batch2.messages[0]
+        async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
+            batch = await reader.receive_batch()
 
-        await reader.close()
+        assert message != batch.messages[0]
 
     async def test_read_compressed_messages(self, driver, topic_path, topic_consumer):
         async with driver.topic_client.writer(topic_path, codec=ydb.TopicCodec.GZIP) as writer:
@@ -147,12 +146,12 @@ def test_read_and_commit_with_close_reader(self, driver_sync, topic_with_message
 
     def test_read_and_commit_with_ack(self, driver_sync, topic_with_messages, topic_consumer):
         reader = driver_sync.topic_client.reader(topic_with_messages, topic_consumer)
-        batch = reader.receive_batch()
-        reader.commit_with_ack(batch)
+        message = reader.receive_message()
+        reader.commit_with_ack(message)
 
         reader = driver_sync.topic_client.reader(topic_with_messages, topic_consumer)
-        batch2 = reader.receive_batch()
-        assert batch.messages[0] != batch2.messages[0]
+        batch = reader.receive_batch()
+        assert message != batch.messages[0]
 
     def test_read_compressed_messages(self, driver_sync, topic_path, topic_consumer):
         with driver_sync.topic_client.writer(topic_path, codec=ydb.TopicCodec.GZIP) as writer:
diff --git a/tests/topics/test_topic_writer.py b/tests/topics/test_topic_writer.py
index 3817e34d..dfc3cdd0 100644
--- a/tests/topics/test_topic_writer.py
+++ b/tests/topics/test_topic_writer.py
@@ -41,10 +41,9 @@ async def test_random_producer_id(self, driver: ydb.aio.Driver, topic_path, topi
         async with driver.topic_client.writer(topic_path) as writer:
             await writer.write(ydb.TopicWriterMessage(data="123".encode()))
 
-        batch1 = await topic_reader.receive_batch()
-        batch2 = await topic_reader.receive_batch()
+        batch = await topic_reader.receive_batch()
 
-        assert batch1.messages[0].producer_id != batch2.messages[0].producer_id
+        assert batch.messages[0].producer_id != batch.messages[1].producer_id
 
     async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path):
         async with driver.topic_client.writer(
@@ -83,12 +82,12 @@ async def test_write_multi_message_with_ack(
         assert batch.messages[0].seqno == 1
         assert batch.messages[0].data == "123".encode()
 
-        # remove second recieve batch when implement batching
-        # https://github.com/ydb-platform/ydb-python-sdk/issues/142
-        batch = await topic_reader.receive_batch()
-        assert batch.messages[0].offset == 1
-        assert batch.messages[0].seqno == 2
-        assert batch.messages[0].data == "456".encode()
+        # # remove second recieve batch when implement batching
+        # # https://github.com/ydb-platform/ydb-python-sdk/issues/142
+        # batch = await topic_reader.receive_batch()
+        assert batch.messages[1].offset == 1
+        assert batch.messages[1].seqno == 2
+        assert batch.messages[1].data == "456".encode()
 
     @pytest.mark.parametrize(
         "codec",
@@ -186,10 +185,9 @@ def test_random_producer_id(
         with driver_sync.topic_client.writer(topic_path) as writer:
             writer.write(ydb.TopicWriterMessage(data="123".encode()))
 
-        batch1 = topic_reader_sync.receive_batch()
-        batch2 = topic_reader_sync.receive_batch()
+        batch = topic_reader_sync.receive_batch()
 
-        assert batch1.messages[0].producer_id != batch2.messages[0].producer_id
+        assert batch.messages[0].producer_id != batch.messages[1].producer_id
 
     def test_write_multi_message_with_ack(
         self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader
@@ -210,10 +208,10 @@ def test_write_multi_message_with_ack(
 
         # remove second recieve batch when implement batching
         # https://github.com/ydb-platform/ydb-python-sdk/issues/142
-        batch = topic_reader_sync.receive_batch()
-        assert batch.messages[0].offset == 1
-        assert batch.messages[0].seqno == 2
-        assert batch.messages[0].data == "456".encode()
+        # batch = topic_reader_sync.receive_batch()
+        assert batch.messages[1].offset == 1
+        assert batch.messages[1].seqno == 2
+        assert batch.messages[1].data == "456".encode()
 
     @pytest.mark.parametrize(
         "codec",
diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py
index 81c6d9f4..5004d25d 100644
--- a/ydb/_topic_reader/topic_reader_asyncio.py
+++ b/ydb/_topic_reader/topic_reader_asyncio.py
@@ -3,9 +3,9 @@
 import asyncio
 import concurrent.futures
 import gzip
+import random
 import typing
 from asyncio import Task
-from collections import deque
 from typing import Optional, Set, Dict, Union, Callable
 
 import ydb
@@ -264,7 +264,7 @@ class ReaderStream:
 
     _state_changed: asyncio.Event
     _closed: bool
-    _message_batches: typing.Deque[datatypes.PublicBatch]
+    _message_batches: typing.Dict[int, datatypes.PublicBatch]
     _first_error: asyncio.Future[YdbError]
 
     _update_token_interval: Union[int, float]
@@ -296,7 +296,7 @@ def __init__(
         self._closed = False
         self._first_error = asyncio.get_running_loop().create_future()
         self._batches_to_decode = asyncio.Queue()
-        self._message_batches = deque()
+        self._message_batches = dict()
 
         self._update_token_interval = settings.update_token_interval
         self._get_token_function = get_token_function
@@ -359,6 +359,10 @@ async def wait_messages(self):
             await self._state_changed.wait()
             self._state_changed.clear()
 
+    def _get_random_batch(self):
+        rnd_id = random.choice(list(self._message_batches.keys()))
+        return rnd_id, self._message_batches[rnd_id]
+
     def receive_batch_nowait(self):
         if self._get_first_error():
             raise self._get_first_error()
@@ -366,22 +370,25 @@ def receive_batch_nowait(self):
         if not self._message_batches:
             return None
 
-        batch = self._message_batches.popleft()
+        part_sess_id, batch = self._get_random_batch()
         self._buffer_release_bytes(batch._bytes_size)
+        del self._message_batches[part_sess_id]
+
         return batch
 
     def receive_message_nowait(self):
         if self._get_first_error():
             raise self._get_first_error()
 
-        try:
-            batch = self._message_batches[0]
-            message = batch.pop_message()
-        except IndexError:
+        if not self._message_batches:
             return None
 
-        if batch.empty():
-            self.receive_batch_nowait()
+        part_sess_id, batch = self._get_random_batch()
+
+        message = batch.messages.pop(0)
+        if len(batch.messages) == 0:
+            self._buffer_release_bytes(batch._bytes_size)
+            del self._message_batches[part_sess_id]
 
         return message
 
@@ -605,9 +612,18 @@ async def _decode_batches_loop(self):
         while True:
             batch = await self._batches_to_decode.get()
             await self._decode_batch_inplace(batch)
-            self._message_batches.append(batch)
+            self._add_batch_to_queue(batch)
             self._state_changed.set()
 
+    def _add_batch_to_queue(self, batch: datatypes.PublicBatch):
+        part_sess_id = batch._partition_session.id
+        if part_sess_id in self._message_batches:
+            self._message_batches[part_sess_id].messages.extend(batch.messages)
+            self._message_batches[part_sess_id]._bytes_size += batch._bytes_size
+            return
+
+        self._message_batches[part_sess_id] = batch
+
     async def _decode_batch_inplace(self, batch):
         if batch._codec == Codec.CODEC_RAW:
             return