Skip to content

Commit

Permalink
Ability to batch messages in topic reader
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvoleg committed Sep 24, 2024
1 parent 8928b00 commit 7aa3545
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 37 deletions.
21 changes: 10 additions & 11 deletions tests/topics/test_topic_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,14 @@ async def test_read_and_commit_with_close_reader(self, driver, topic_with_messag
assert message != message2

async def test_read_and_commit_with_ack(self, driver, topic_with_messages, topic_consumer):
reader = driver.topic_client.reader(topic_with_messages, topic_consumer)
batch = await reader.receive_batch()
await reader.commit_with_ack(batch)
async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
message = await reader.receive_message()
await reader.commit_with_ack(message)

reader = driver.topic_client.reader(topic_with_messages, topic_consumer)
batch2 = await reader.receive_batch()
assert batch.messages[0] != batch2.messages[0]
async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
batch = await reader.receive_batch()

await reader.close()
assert message != batch.messages[0]

async def test_read_compressed_messages(self, driver, topic_path, topic_consumer):
async with driver.topic_client.writer(topic_path, codec=ydb.TopicCodec.GZIP) as writer:
Expand Down Expand Up @@ -147,12 +146,12 @@ def test_read_and_commit_with_close_reader(self, driver_sync, topic_with_message

def test_read_and_commit_with_ack(self, driver_sync, topic_with_messages, topic_consumer):
reader = driver_sync.topic_client.reader(topic_with_messages, topic_consumer)
batch = reader.receive_batch()
reader.commit_with_ack(batch)
message = reader.receive_message()
reader.commit_with_ack(message)

reader = driver_sync.topic_client.reader(topic_with_messages, topic_consumer)
batch2 = reader.receive_batch()
assert batch.messages[0] != batch2.messages[0]
batch = reader.receive_batch()
assert message != batch.messages[0]

def test_read_compressed_messages(self, driver_sync, topic_path, topic_consumer):
with driver_sync.topic_client.writer(topic_path, codec=ydb.TopicCodec.GZIP) as writer:
Expand Down
30 changes: 14 additions & 16 deletions tests/topics/test_topic_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,9 @@ async def test_random_producer_id(self, driver: ydb.aio.Driver, topic_path, topi
async with driver.topic_client.writer(topic_path) as writer:
await writer.write(ydb.TopicWriterMessage(data="123".encode()))

batch1 = await topic_reader.receive_batch()
batch2 = await topic_reader.receive_batch()
batch = await topic_reader.receive_batch()

assert batch1.messages[0].producer_id != batch2.messages[0].producer_id
assert batch.messages[0].producer_id != batch.messages[1].producer_id

async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path):
async with driver.topic_client.writer(
Expand Down Expand Up @@ -83,12 +82,12 @@ async def test_write_multi_message_with_ack(
assert batch.messages[0].seqno == 1
assert batch.messages[0].data == "123".encode()

# remove second recieve batch when implement batching
# https://github.com/ydb-platform/ydb-python-sdk/issues/142
batch = await topic_reader.receive_batch()
assert batch.messages[0].offset == 1
assert batch.messages[0].seqno == 2
assert batch.messages[0].data == "456".encode()
# # remove second receive batch when batching is implemented
# # https://github.com/ydb-platform/ydb-python-sdk/issues/142
# batch = await topic_reader.receive_batch()
assert batch.messages[1].offset == 1
assert batch.messages[1].seqno == 2
assert batch.messages[1].data == "456".encode()

@pytest.mark.parametrize(
"codec",
Expand Down Expand Up @@ -186,10 +185,9 @@ def test_random_producer_id(
with driver_sync.topic_client.writer(topic_path) as writer:
writer.write(ydb.TopicWriterMessage(data="123".encode()))

batch1 = topic_reader_sync.receive_batch()
batch2 = topic_reader_sync.receive_batch()
batch = topic_reader_sync.receive_batch()

assert batch1.messages[0].producer_id != batch2.messages[0].producer_id
assert batch.messages[0].producer_id != batch.messages[1].producer_id

def test_write_multi_message_with_ack(
self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader
Expand All @@ -210,10 +208,10 @@ def test_write_multi_message_with_ack(

# remove second receive batch when batching is implemented
# https://github.com/ydb-platform/ydb-python-sdk/issues/142
batch = topic_reader_sync.receive_batch()
assert batch.messages[0].offset == 1
assert batch.messages[0].seqno == 2
assert batch.messages[0].data == "456".encode()
# batch = topic_reader_sync.receive_batch()
assert batch.messages[1].offset == 1
assert batch.messages[1].seqno == 2
assert batch.messages[1].data == "456".encode()

@pytest.mark.parametrize(
"codec",
Expand Down
37 changes: 27 additions & 10 deletions ydb/_topic_reader/topic_reader_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import concurrent.futures
import gzip
import random
import typing
from asyncio import Task
from collections import deque
Expand Down Expand Up @@ -264,7 +265,7 @@ class ReaderStream:

_state_changed: asyncio.Event
_closed: bool
_message_batches: typing.Deque[datatypes.PublicBatch]
_message_batches: typing.Dict[int, datatypes.PublicBatch]
_first_error: asyncio.Future[YdbError]

_update_token_interval: Union[int, float]
Expand Down Expand Up @@ -296,7 +297,7 @@ def __init__(
self._closed = False
self._first_error = asyncio.get_running_loop().create_future()
self._batches_to_decode = asyncio.Queue()
self._message_batches = deque()
self._message_batches = dict()

self._update_token_interval = settings.update_token_interval
self._get_token_function = get_token_function
Expand Down Expand Up @@ -359,29 +360,36 @@ async def wait_messages(self):
await self._state_changed.wait()
self._state_changed.clear()

def _get_random_batch(self):
rnd_id = random.choice(list(self._message_batches.keys()))
return rnd_id, self._message_batches[rnd_id]

def receive_batch_nowait(self):
    """Return one buffered message batch without waiting.

    If the stream has recorded an error, raise it. Otherwise pop a
    randomly chosen partition session's batch from the buffer, release
    its byte size back to the read-buffer budget, and return it.
    Returns ``None`` when nothing is buffered.
    """
    # Fetch the error once instead of twice: avoids a redundant call and
    # any chance of the two reads observing different values.
    first_error = self._get_first_error()
    if first_error:
        raise first_error

    if not self._message_batches:
        return None

    part_sess_id, batch = self._get_random_batch()
    self._buffer_release_bytes(batch._bytes_size)
    del self._message_batches[part_sess_id]

    return batch

def receive_message_nowait(self):
    """Return a single buffered message without waiting, or ``None``.

    If the stream has recorded an error, raise it. Otherwise pop the
    oldest message from a randomly chosen partition session's buffered
    batch. When that batch is drained, its byte size is released back to
    the read-buffer budget and the batch is dropped from the buffer.
    """
    # Fetch the error once instead of twice: avoids a redundant call and
    # any chance of the two reads observing different values.
    first_error = self._get_first_error()
    if first_error:
        raise first_error

    if not self._message_batches:
        return None

    part_sess_id, batch = self._get_random_batch()

    # pop(0) preserves FIFO order within the partition's batch.
    message = batch.messages.pop(0)
    if not batch.messages:
        # NOTE(review): the batch's whole byte budget is released only once
        # the batch is fully consumed; a partially read batch still holds
        # all of it -- confirm this is the intended accounting.
        self._buffer_release_bytes(batch._bytes_size)
        del self._message_batches[part_sess_id]

    return message

Expand Down Expand Up @@ -605,9 +613,18 @@ async def _decode_batches_loop(self):
while True:
batch = await self._batches_to_decode.get()
await self._decode_batch_inplace(batch)
self._message_batches.append(batch)
self._add_batch_to_queue(batch)
self._state_changed.set()

def _add_batch_to_queue(self, batch: "datatypes.PublicBatch"):
    """Buffer *batch*, merging it into any batch already pending for the
    same partition session.

    Messages and byte sizes from one partition session are folded into a
    single entry, so the buffer holds at most one batch per session.
    """
    session_id = batch._partition_session.id
    if session_id not in self._message_batches:
        self._message_batches[session_id] = batch
    else:
        pending = self._message_batches[session_id]
        pending.messages.extend(batch.messages)
        pending._bytes_size += batch._bytes_size

async def _decode_batch_inplace(self, batch):
if batch._codec == Codec.CODEC_RAW:
return
Expand Down

0 comments on commit 7aa3545

Please sign in to comment.