diff --git a/CHANGELOG.md b/CHANGELOG.md index de204df2..ba94d19c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Fixed topic reader hanging on unknown codec + ## 3.11.1 ## * fixed unexpected require requests module on import diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 2893aa55..362be059 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -74,6 +74,19 @@ def decode(b: bytes): batch = await reader.receive_batch() assert batch.messages[0].data.decode() == "123" + async def test_error_unknown_codec(self, driver, topic_path, topic_consumer): + codec = 10001 + + def encode(b: bytes): + return bytes(reversed(b)) + + async with driver.topic_client.writer(topic_path, codec=codec, encoders={codec: encode}) as writer: + await writer.write("123") + + async with driver.topic_client.reader(topic_path, topic_consumer) as reader: + with pytest.raises(ydb.TopicReaderUnexpectedCodecError): + await asyncio.wait_for(reader.receive_batch(), timeout=5) + async def test_read_from_two_topics(self, driver, topic_path, topic2_path, topic_consumer): async with driver.topic_client.writer(topic_path) as writer: await writer.write("1") diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 50684f7c..7b3d1cfa 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -8,6 +8,7 @@ from collections import deque from typing import Optional, Set, Dict, Union, Callable +import ydb from .. 
import _apis, issues from .._utilities import AtomicCounter from ..aio import Driver @@ -35,7 +36,7 @@ class TopicReaderError(YdbError): pass -class TopicReaderUnexpectedCodec(YdbError): +class PublicTopicReaderUnexpectedCodecError(YdbError): pass @@ -222,9 +223,7 @@ def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.Co async def close(self, flush: bool): if self._stream_reader: - if flush: - await self.flush() - await self._stream_reader.close() + await self._stream_reader.close(flush) for task in self._background_tasks: task.cancel() @@ -339,9 +338,12 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMess self._update_token_event.set() self._background_tasks.add(asyncio.create_task(self._read_messages_loop(), name="read_messages_loop")) - self._background_tasks.add(asyncio.create_task(self._decode_batches_loop())) + self._background_tasks.add(asyncio.create_task(self._decode_batches_loop(), name="decode_batches")) if self._get_token_function: self._background_tasks.add(asyncio.create_task(self._update_token_loop(), name="update_token_loop")) + self._background_tasks.add( + asyncio.create_task(self._handle_background_errors(), name="handle_background_errors") + ) async def wait_error(self): raise await self._first_error @@ -411,6 +413,17 @@ def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.Co return waiter + async def _handle_background_errors(self): + done, _ = await asyncio.wait(self._background_tasks, return_when=asyncio.FIRST_EXCEPTION) + for f in done: + f = f # type: asyncio.Future + err = f.exception() + if not isinstance(err, ydb.Error): + old_err = err + err = ydb.Error("Background process failed unexpected") + err.__cause__ = old_err + self._set_first_error(err) + async def _read_messages_loop(self): try: self._stream.write( @@ -602,7 +615,7 @@ async def _decode_batch_inplace(self, batch): try: decode_func = self._decoders[batch._codec] except KeyError: - raise 
TopicReaderUnexpectedCodec("Receive message with unexpected codec: %s" % batch._codec) + raise PublicTopicReaderUnexpectedCodecError("Receive message with unexpected codec: %s" % batch._codec) decode_data_futures = [] for message in batch.messages: @@ -628,9 +641,6 @@ def _get_first_error(self) -> Optional[YdbError]: return self._first_error.result() async def flush(self): - if self._closed: - raise RuntimeError("Flush on closed Stream") - futures = [] for session in self._partition_sessions.values(): futures.extend(w.future for w in session._ack_waiters) @@ -638,12 +648,15 @@ async def flush(self): if futures: await asyncio.wait(futures) - async def close(self): + async def close(self, flush: bool): if self._closed: return self._closed = True + if flush: + await self.flush() + self._set_first_error(TopicReaderStreamClosedError()) self._state_changed.set() self._stream.close() diff --git a/ydb/_topic_reader/topic_reader_asyncio_test.py b/ydb/_topic_reader/topic_reader_asyncio_test.py index 860676d0..9af91b1b 100644 --- a/ydb/_topic_reader/topic_reader_asyncio_test.py +++ b/ydb/_topic_reader/topic_reader_asyncio_test.py @@ -183,14 +183,14 @@ async def stream_reader(self, stream_reader_started: ReaderStream): yield stream_reader_started assert stream_reader_started._get_first_error() is None - await stream_reader_started.close() + await stream_reader_started.close(False) @pytest.fixture() async def stream_reader_finish_with_error(self, stream_reader_started: ReaderStream): yield stream_reader_started assert stream_reader_started._get_first_error() is not None - await stream_reader_started.close() + await stream_reader_started.close(False) @staticmethod def create_message( @@ -372,7 +372,7 @@ async def test_close_ack_waiters_when_close_stream_reader( self, stream_reader_started: ReaderStream, partition_session ): waiter = partition_session.add_waiter(self.partition_session_committed_offset + 1) - await wait_for_fast(stream_reader_started.close()) + await 
wait_for_fast(stream_reader_started.close(False)) with pytest.raises(topic_reader_asyncio.PublicTopicReaderPartitionExpiredError): waiter.future.result() @@ -402,7 +402,7 @@ async def test_flush(self, stream, stream_reader_started: ReaderStream, partitio # don't raises assert waiter.future.result() is None - await wait_for_fast(stream_reader_started.close()) + await wait_for_fast(stream_reader_started.close(False)) async def test_commit_ranges_for_received_messages( self, stream, stream_reader_started: ReaderStream, partition_session @@ -422,7 +422,7 @@ async def test_commit_ranges_for_received_messages( received = stream_reader_started.receive_batch_nowait().messages assert received == [m2] - await stream_reader_started.close() + await stream_reader_started.close(False) # noinspection PyTypeChecker @pytest.mark.parametrize( @@ -613,7 +613,7 @@ async def test_init_reader(self, stream, default_reader_settings): ) assert reader._session_id == "test" - await reader.close() + await reader.close(False) async def test_start_partition( self, @@ -1230,7 +1230,7 @@ async def test_update_token(self, stream): got = await wait_for_fast(stream.from_client.get()) assert expected == got - await reader.close() + await reader.close(False) async def test_read_unknown_message(self, stream, stream_reader, caplog): class TestMessage: diff --git a/ydb/topic.py b/ydb/topic.py index 2175af47..948bcff4 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -15,6 +15,7 @@ "TopicReaderMessage", "TopicReaderSelector", "TopicReaderSettings", + "TopicReaderUnexpectedCodecError", "TopicReaderPartitionExpiredError", "TopicStatWindow", "TopicWriteResult", @@ -49,6 +50,7 @@ from ._topic_reader.topic_reader_asyncio import ( PublicAsyncIOReader as TopicReaderAsyncIO, PublicTopicReaderPartitionExpiredError as TopicReaderPartitionExpiredError, + PublicTopicReaderUnexpectedCodecError as TopicReaderUnexpectedCodecError, ) from ._topic_writer.topic_writer import ( # noqa: F401