Skip to content

Commit

Permalink
tests fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvoleg committed Sep 24, 2024
1 parent 2e6ad26 commit 0a07253
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 65 deletions.
12 changes: 9 additions & 3 deletions tests/topics/test_topic_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import ydb.aio

from ydb._topic_common.test_helpers import wait_condition


@pytest.mark.asyncio
class TestTopicWriterAsyncIO:
Expand Down Expand Up @@ -43,6 +45,10 @@ async def test_random_producer_id(self, driver: ydb.aio.Driver, topic_path, topi

batch = await topic_reader.receive_batch()

if len(batch.messages) == 1:
batch2 = await topic_reader.receive_batch()
batch.messages.extend(batch2.messages)

assert batch.messages[0].producer_id != batch.messages[1].producer_id

async def test_auto_flush_on_close(self, driver: ydb.aio.Driver, topic_path):
Expand Down Expand Up @@ -201,14 +207,14 @@ def test_write_multi_message_with_ack(
)

batch = topic_reader_sync.receive_batch()
if len(batch.messages) == 1:
batch2 = topic_reader_sync.receive_batch()
batch.messages.extend(batch2.messages)

assert batch.messages[0].offset == 0
assert batch.messages[0].seqno == 1
assert batch.messages[0].data == "123".encode()

# remove second receive batch when batching is implemented
# https://github.com/ydb-platform/ydb-python-sdk/issues/142
# batch = topic_reader_sync.receive_batch()
assert batch.messages[1].offset == 1
assert batch.messages[1].seqno == 2
assert batch.messages[1].data == "456".encode()
Expand Down
19 changes: 11 additions & 8 deletions ydb/_topic_reader/topic_reader_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import random
import typing
from asyncio import Task
from collections import OrderedDict
from typing import Optional, Set, Dict, Union, Callable

import ydb
Expand Down Expand Up @@ -296,7 +297,7 @@ def __init__(
self._closed = False
self._first_error = asyncio.get_running_loop().create_future()
self._batches_to_decode = asyncio.Queue()
self._message_batches = dict()
self._message_batches = OrderedDict()

self._update_token_interval = settings.update_token_interval
self._get_token_function = get_token_function
Expand Down Expand Up @@ -359,9 +360,9 @@ async def wait_messages(self):
await self._state_changed.wait()
self._state_changed.clear()

def _get_random_batch(self):
rnd_id = random.choice(list(self._message_batches.keys()))
return rnd_id, self._message_batches[rnd_id]
def _get_first_batch(self) -> typing.Tuple[int, datatypes.PublicBatch]:
first_id, batch = self._message_batches.popitem(last = False)
return first_id, batch

def receive_batch_nowait(self):
if self._get_first_error():
Expand All @@ -370,9 +371,8 @@ def receive_batch_nowait(self):
if not self._message_batches:
return None

part_sess_id, batch = self._get_random_batch()
_, batch = self._get_first_batch()
self._buffer_release_bytes(batch._bytes_size)
del self._message_batches[part_sess_id]

return batch

Expand All @@ -383,12 +383,15 @@ def receive_message_nowait(self):
if not self._message_batches:
return None

part_sess_id, batch = self._get_random_batch()
part_sess_id, batch = self._get_first_batch()

message = batch.messages.pop(0)

if len(batch.messages) == 0:
self._buffer_release_bytes(batch._bytes_size)
del self._message_batches[part_sess_id]
else:
# TODO: we should somehow release bytes from single message as well
self._message_batches[part_sess_id] = batch

return message

Expand Down
109 changes: 55 additions & 54 deletions ydb/_topic_reader/topic_reader_asyncio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import datetime
import gzip
import typing
from collections import deque
from collections import OrderedDict
from dataclasses import dataclass
from unittest import mock

Expand Down Expand Up @@ -52,9 +52,9 @@ def default_executor():
executor.shutdown()


def stub_partition_session():
def stub_partition_session(id: int = 0):
return datatypes.PartitionSession(
id=0,
id=id,
state=datatypes.PartitionSession.State.Active,
topic_path="asd",
partition_id=1,
Expand Down Expand Up @@ -212,21 +212,27 @@ def create_message(
_commit_end_offset=partition_session._next_message_start_commit_offset + offset_delta,
)

async def send_message(self, stream_reader, message: PublicMessage):
await self.send_batch(stream_reader, [message])
async def send_message(self, stream_reader, message: PublicMessage, new_batch=True):
    """Deliver a single message to the reader as a one-element batch.

    Thin convenience wrapper around :meth:`send_batch`; ``new_batch`` is
    forwarded unchanged so the caller controls whether the message starts a
    new batch or is appended to the current one.
    """
    single_message_batch = [message]
    await self.send_batch(stream_reader, single_message_batch, new_batch=new_batch)

async def send_batch(self, stream_reader, batch: typing.List[PublicMessage]):
async def send_batch(self, stream_reader, batch: typing.List[PublicMessage], new_batch=True):
if len(batch) == 0:
return

first_message = batch[0]
for message in batch:
assert message._partition_session is first_message._partition_session

partition_session_id = first_message._partition_session.id

def batch_count():
return len(stream_reader._message_batches)

def batch_size():
return len(stream_reader._message_batches[partition_session_id].messages)

initial_batches = batch_count()
initial_batch_size = batch_size() if not new_batch else 0

stream = stream_reader._stream # type: StreamMock
stream.from_server.put_nowait(
Expand Down Expand Up @@ -261,7 +267,10 @@ def batch_count():
),
)
)
await wait_condition(lambda: batch_count() > initial_batches)
if new_batch:
await wait_condition(lambda: batch_count() > initial_batches)
else:
await wait_condition(lambda: batch_size() > initial_batch_size)

async def test_unknown_error(self, stream, stream_reader_finish_with_error):
class TestError(Exception):
Expand Down Expand Up @@ -412,15 +421,11 @@ async def test_commit_ranges_for_received_messages(
m2._commit_start_offset = m1.offset + 1

await self.send_message(stream_reader_started, m1)
await self.send_message(stream_reader_started, m2)

await stream_reader_started.wait_messages()
received = stream_reader_started.receive_batch_nowait().messages
assert received == [m1]
await self.send_message(stream_reader_started, m2, new_batch=False)

await stream_reader_started.wait_messages()
received = stream_reader_started.receive_batch_nowait().messages
assert received == [m2]
assert received == [m1, m2]

await stream_reader_started.close(False)

Expand Down Expand Up @@ -860,7 +865,7 @@ def reader_batch_count():

assert stream_reader._buffer_size_bytes == initial_buffer_size - bytes_size

last_batch = stream_reader._message_batches[-1]
_, last_batch = stream_reader._message_batches.popitem()
assert last_batch == PublicBatch(
messages=[
PublicMessage(
Expand Down Expand Up @@ -1059,74 +1064,74 @@ async def test_read_batches(self, stream_reader, partition_session, second_parti
@pytest.mark.parametrize(
"batches_before,expected_message,batches_after",
[
([], None, []),
({}, None, {}),
(
[
PublicBatch(
{
0: PublicBatch(
messages=[stub_message(1)],
_partition_session=stub_partition_session(),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
)
],
},
stub_message(1),
[],
{},
),
(
[
PublicBatch(
{
0: PublicBatch(
messages=[stub_message(1), stub_message(2)],
_partition_session=stub_partition_session(),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
),
PublicBatch(
1: PublicBatch(
messages=[stub_message(3), stub_message(4)],
_partition_session=stub_partition_session(),
_partition_session=stub_partition_session(1),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
),
],
},
stub_message(1),
[
PublicBatch(
{
0: PublicBatch(
messages=[stub_message(2)],
_partition_session=stub_partition_session(),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
),
PublicBatch(
1: PublicBatch(
messages=[stub_message(3), stub_message(4)],
_partition_session=stub_partition_session(),
_partition_session=stub_partition_session(1),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
),
],
},
),
(
[
PublicBatch(
{
0: PublicBatch(
messages=[stub_message(1)],
_partition_session=stub_partition_session(),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
),
PublicBatch(
1: PublicBatch(
messages=[stub_message(2), stub_message(3)],
_partition_session=stub_partition_session(),
_partition_session=stub_partition_session(1),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
),
],
},
stub_message(1),
[
PublicBatch(
{
1: PublicBatch(
messages=[stub_message(2), stub_message(3)],
_partition_session=stub_partition_session(),
_partition_session=stub_partition_session(1),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
)
],
},
),
],
)
Expand All @@ -1137,11 +1142,11 @@ async def test_read_message(
expected_message: PublicMessage,
batches_after: typing.List[datatypes.PublicBatch],
):
stream_reader._message_batches = deque(batches_before)
stream_reader._message_batches = OrderedDict(batches_before)
mess = stream_reader.receive_message_nowait()

assert mess == expected_message
assert list(stream_reader._message_batches) == batches_after
assert dict(stream_reader._message_batches) == batches_after

async def test_receive_batch_nowait(self, stream, stream_reader, partition_session):
assert stream_reader.receive_batch_nowait() is None
Expand All @@ -1152,30 +1157,21 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi
await self.send_message(stream_reader, mess1)

mess2 = self.create_message(partition_session, 2, 1)
await self.send_message(stream_reader, mess2)
await self.send_message(stream_reader, mess2, new_batch=False)

assert stream_reader._buffer_size_bytes == initial_buffer_size - 2 * self.default_batch_size

received = stream_reader.receive_batch_nowait()
assert received == PublicBatch(
messages=[mess1],
messages=[mess1, mess2],
_partition_session=mess1._partition_session,
_bytes_size=self.default_batch_size,
_codec=Codec.CODEC_RAW,
)

received = stream_reader.receive_batch_nowait()
assert received == PublicBatch(
messages=[mess2],
_partition_session=mess2._partition_session,
_bytes_size=self.default_batch_size,
_bytes_size=self.default_batch_size * 2,
_codec=Codec.CODEC_RAW,
)

assert stream_reader._buffer_size_bytes == initial_buffer_size

assert StreamReadMessage.ReadRequest(self.default_batch_size) == stream.from_client.get_nowait().client_message
assert StreamReadMessage.ReadRequest(self.default_batch_size) == stream.from_client.get_nowait().client_message
assert StreamReadMessage.ReadRequest(self.default_batch_size * 2) == stream.from_client.get_nowait().client_message

with pytest.raises(asyncio.QueueEmpty):
stream.from_client.get_nowait()
Expand All @@ -1186,13 +1182,18 @@ async def test_receive_message_nowait(self, stream, stream_reader, partition_ses
initial_buffer_size = stream_reader._buffer_size_bytes

await self.send_batch(
stream_reader, [self.create_message(partition_session, 1, 1), self.create_message(partition_session, 2, 1)]
stream_reader,
[
self.create_message(partition_session, 1, 1),
self.create_message(partition_session, 2, 1),
],
)
await self.send_batch(
stream_reader,
[
self.create_message(partition_session, 10, 1),
],
new_batch=False,
)

assert stream_reader._buffer_size_bytes == initial_buffer_size - 2 * self.default_batch_size
Expand Down

0 comments on commit 0a07253

Please sign in to comment.