From 8983c30bc6103d51ba48cc5db4839a7f289ccbcf Mon Sep 17 00:00:00 2001 From: SergioSim Date: Tue, 14 Nov 2023 21:22:04 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8(backends)=20add=20prefetch=20option?= =?UTF-8?q?=20for=20data=20backends?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We try to align the async data backend interface with the http backends. The `prefetch` option allows the caller to greedily read `prefetch` number of records before they are yielded by the generator. --- src/ralph/backends/data/async_es.py | 14 ++- src/ralph/backends/data/async_mongo.py | 14 ++- src/ralph/backends/data/base.py | 64 +++++++++-- src/ralph/backends/data/clickhouse.py | 2 +- src/ralph/backends/data/es.py | 2 +- src/ralph/backends/data/fs.py | 2 +- src/ralph/backends/data/ldp.py | 2 +- src/ralph/backends/data/mongo.py | 2 +- src/ralph/backends/data/s3.py | 2 +- src/ralph/backends/data/swift.py | 2 +- tests/backends/data/test_async_es.py | 17 ++- tests/backends/data/test_async_mongo.py | 16 ++- tests/backends/data/test_base.py | 144 ++++++++++++++++++++++-- tests/backends/data/test_s3.py | 7 +- 14 files changed, 250 insertions(+), 40 deletions(-) diff --git a/src/ralph/backends/data/async_es.py b/src/ralph/backends/data/async_es.py index 311ddf137..8acdc66ad 100644 --- a/src/ralph/backends/data/async_es.py +++ b/src/ralph/backends/data/async_es.py @@ -115,6 +115,7 @@ async def read( # noqa: PLR0913 chunk_size: Optional[int] = None, raw_output: bool = False, ignore_errors: bool = False, + prefetch: Optional[PositiveInt] = None, max_statements: Optional[PositiveInt] = None, ) -> Union[AsyncIterator[bytes], AsyncIterator[dict]]: """Read documents matching the query in the target index and yield them. @@ -130,8 +131,11 @@ async def read( # noqa: PLR0913 raw_output (bool): Controls whether to yield dictionaries or bytes. ignore_errors (bool): No impact as encoding errors are not expected in Elasticsearch results. 
+ prefetch (int): The number of records to prefetch (queue) while yielding. + If `prefetch` is `None` it defaults to `1`, i.e. no records are + prefetched. max_statements (int): The maximum number of statements to yield. - If `None` (default), there is no maximum. + If `None` (default) or `0`, there is no maximum. Yield: bytes: The next raw document if `raw_output` is True. @@ -141,7 +145,13 @@ async def read( # noqa: PLR0913 BackendException: If a failure occurs during Elasticsearch connection. """ statements = super().read( - query, target, chunk_size, raw_output, ignore_errors, max_statements + query, + target, + chunk_size, + raw_output, + ignore_errors, + prefetch, + max_statements, ) async for statement in statements: yield statement diff --git a/src/ralph/backends/data/async_mongo.py b/src/ralph/backends/data/async_mongo.py index 811e10388..08410c426 100644 --- a/src/ralph/backends/data/async_mongo.py +++ b/src/ralph/backends/data/async_mongo.py @@ -123,6 +123,7 @@ async def read( # noqa: PLR0913 chunk_size: Optional[int] = None, raw_output: bool = False, ignore_errors: bool = False, + prefetch: Optional[PositiveInt] = None, max_statements: Optional[PositiveInt] = None, ) -> Union[AsyncIterator[bytes], AsyncIterator[dict]]: """Read documents matching the `query` from `target` collection and yield them. @@ -137,8 +138,11 @@ async def read( # noqa: PLR0913 ignore_errors (bool): If `True`, encoding errors during the read operation will be ignored and logged. If `False` (default), a `BackendException` is raised on any error. + prefetch (int): The number of records to prefetch (queue) while yielding. + If `prefetch` is `None` it defaults to `1`, i.e. no records are + prefetched. max_statements (int): The maximum number of statements to yield. - If `None` (default), there is no maximum. + If `None` (default) or `0`, there is no maximum. Yield: bytes: The next raw document if `raw_output` is True. 
@@ -150,7 +154,13 @@ async def read( # noqa: PLR0913 BackendParameterException: If the `target` is not a valid collection name. """ statements = super().read( - query, target, chunk_size, raw_output, ignore_errors, max_statements + query, + target, + chunk_size, + raw_output, + ignore_errors, + prefetch, + max_statements, ) async for statement in statements: yield statement diff --git a/src/ralph/backends/data/base.py b/src/ralph/backends/data/base.py index 308a1423b..eedcf6e82 100644 --- a/src/ralph/backends/data/base.py +++ b/src/ralph/backends/data/base.py @@ -2,6 +2,7 @@ import logging from abc import ABC, abstractmethod +from asyncio import Queue, create_task from enum import Enum, unique from functools import cached_property from io import IOBase @@ -299,7 +300,7 @@ def read( # noqa: PLR0913 will be ignored and logged. If `False` (default), a `BackendException` is raised on any error. max_statements (int): The maximum number of statements to yield. - If `None` (default), there is no maximum. + If `None` (default) or `0`, there is no maximum. Yield: dict: If `raw_output` is False. @@ -314,14 +315,15 @@ def read( # noqa: PLR0913 query = validate_backend_query(query, self.query_class, self.logger) reader = self._read_bytes if raw_output else self._read_dicts statements = reader(query, target, chunk_size, ignore_errors) - if max_statements is None: + if not max_statements: yield from statements return + max_statements -= 1 for i, statement in enumerate(statements): + yield statement if i >= max_statements: return - yield statement @abstractmethod def _read_bytes( @@ -507,6 +509,7 @@ async def read( # noqa: PLR0913 chunk_size: Optional[int] = None, raw_output: bool = False, ignore_errors: bool = False, + prefetch: Optional[PositiveInt] = None, max_statements: Optional[PositiveInt] = None, ) -> Union[AsyncIterator[bytes], AsyncIterator[dict]]: """Read records matching the `query` in the `target` container and yield them. 
@@ -526,8 +529,11 @@ async def read( # noqa: PLR0913 ignore_errors (bool): If `True`, encoding errors during the read operation will be ignored and logged. If `False` (default), a `BackendException` is raised on any error. + prefetch (int): The number of records to prefetch (queue) while yielding. + If `prefetch` is `None` it defaults to `1`, i.e. no records are + prefetched. max_statements (int): The maximum number of statements to yield. - If `None` (default), there is no maximum. + If `None` (default) or `0`, there is no maximum. Yield: dict: If `raw_output` is False. @@ -538,20 +544,50 @@ async def read( # noqa: PLR0913 during encoding records and `ignore_errors` is set to `False`. BackendParameterException: If a backend argument value is not valid. """ + prefetch = prefetch if prefetch else 1 + if prefetch < 1: + msg = "prefetch must be a strictly positive integer" + self.logger.error(msg) + raise BackendParameterException(msg) + + if prefetch > 1: + queue = Queue(prefetch - 1) + statements = self.read( + query, + target, + chunk_size, + raw_output, + ignore_errors, + None, + max_statements, + ) + task = create_task(self._queue_records(queue, statements)) + while True: + statement = await queue.get() + if statement is None: + error = task.exception() + if error: + raise error + + return + + yield statement + chunk_size = chunk_size if chunk_size else self.settings.READ_CHUNK_SIZE query = validate_backend_query(query, self.query_class, self.logger) reader = self._read_bytes if raw_output else self._read_dicts statements = reader(query, target, chunk_size, ignore_errors) - if max_statements is None: + if not max_statements: async for statement in statements: yield statement return + i = 0 async for statement in statements: - if i >= max_statements: - return yield statement i += 1 + if i >= max_statements: + return @abstractmethod async def _read_bytes( @@ -573,6 +609,20 @@ async def close(self) -> None: BackendException: If a failure occurs during the close 
operation. """ + async def _queue_records( + self, queue: Queue, records: Union[AsyncIterator[bytes], AsyncIterator[dict]] + ): + """Iterate over the `records` and put them into the `queue`.""" + try: + async for record in records: + await queue.put(record) + except Exception as error: + # None signals that the queue is done + await queue.put(None) + raise error + + await queue.put(None) + def get_backend_generic_argument( backend_class: Type[Union[BaseDataBackend, BaseAsyncDataBackend]], position: int diff --git a/src/ralph/backends/data/clickhouse.py b/src/ralph/backends/data/clickhouse.py index 1841b44c2..704fadf14 100755 --- a/src/ralph/backends/data/clickhouse.py +++ b/src/ralph/backends/data/clickhouse.py @@ -226,7 +226,7 @@ def read( # noqa: PLR0913 will be ignored and logged. If `False` (default), a `BackendException` is raised on any error. max_statements (int): The maximum number of statements to yield. - If `None` (default), there is no maximum. + If `None` (default) or `0`, there is no maximum. Yield: bytes: The next raw document if `raw_output` is True. diff --git a/src/ralph/backends/data/es.py b/src/ralph/backends/data/es.py index 0020a24fd..18b0beba2 100644 --- a/src/ralph/backends/data/es.py +++ b/src/ralph/backends/data/es.py @@ -215,7 +215,7 @@ def read( # noqa: PLR0913 ignore_errors (bool): No impact as encoding errors are not expected in Elasticsearch results. max_statements (int): The maximum number of statements to yield. - If `None` (default), there is no maximum. + If `None` (default) or `0`, there is no maximum. Yield: bytes: The next raw document if `raw_output` is True. diff --git a/src/ralph/backends/data/fs.py b/src/ralph/backends/data/fs.py index f111fbff1..228ec0e5f 100644 --- a/src/ralph/backends/data/fs.py +++ b/src/ralph/backends/data/fs.py @@ -172,7 +172,7 @@ def read( # noqa: PLR0913 will be ignored and logged. If `False` (default), a `BackendException` is raised on any error. 
max_statements (int): The maximum number of statements to yield. - If `None` (default), there is no maximum. + If `None` (default) or `0`, there is no maximum. Yield: bytes: The next chunk of the read files if `raw_output` is True. diff --git a/src/ralph/backends/data/ldp.py b/src/ralph/backends/data/ldp.py index 53b2e6d9b..3d86b52ea 100644 --- a/src/ralph/backends/data/ldp.py +++ b/src/ralph/backends/data/ldp.py @@ -170,7 +170,7 @@ def read( # noqa: PLR0913 raw_output (bool): Should always be set to `True`. ignore_errors (bool): No impact as no encoding operation is performed. max_statements (int): The maximum number of statements to yield. - If `None` (default), there is no maximum. + If `None` (default) or `0`, there is no maximum. Yield: bytes: The content of the archive matching the query. diff --git a/src/ralph/backends/data/mongo.py b/src/ralph/backends/data/mongo.py index 3f960694f..b93ed1dda 100644 --- a/src/ralph/backends/data/mongo.py +++ b/src/ralph/backends/data/mongo.py @@ -193,7 +193,7 @@ def read( # noqa: PLR0913 will be ignored and logged. If `False` (default), a `BackendException` is raised on any error. max_statements (int): The maximum number of statements to yield. - If `None` (default), there is no maximum. + If `None` (default) or `0`, there is no maximum. Yield: dict: If `raw_output` is False. diff --git a/src/ralph/backends/data/s3.py b/src/ralph/backends/data/s3.py index 3643b8ff1..15307fd0b 100644 --- a/src/ralph/backends/data/s3.py +++ b/src/ralph/backends/data/s3.py @@ -186,7 +186,7 @@ def read( # noqa: PLR0913 will be ignored and logged. If `False` (default), a `BackendException` is raised on any error. max_statements (int): The maximum number of statements to yield. - If `None` (default), there is no maximum. + If `None` (default) or `0`, there is no maximum. Yield: dict: If `raw_output` is False. 
diff --git a/src/ralph/backends/data/swift.py b/src/ralph/backends/data/swift.py index 006a8cc92..88245aedf 100644 --- a/src/ralph/backends/data/swift.py +++ b/src/ralph/backends/data/swift.py @@ -200,7 +200,7 @@ def read( # noqa: PLR0913 will be ignored and logged. If `False` (default), a `BackendException` is raised on any error. max_statements (int): The maximum number of statements to yield. - If `None` (default), there is no maximum. + If `None` (default) or `0`, there is no maximum. Yield: dict: If `raw_output` is False. diff --git a/tests/backends/data/test_async_es.py b/tests/backends/data/test_async_es.py index cf6e9f735..f430e78fb 100644 --- a/tests/backends/data/test_async_es.py +++ b/tests/backends/data/test_async_es.py @@ -398,13 +398,19 @@ async def mock_async_es_search(**kwargs): @pytest.mark.anyio -async def test_backends_data_async_es_read_with_raw_ouput(es, async_es_backend): +@pytest.mark.parametrize("prefetch", [1, 10]) +async def test_backends_data_async_es_read_with_raw_ouput( + prefetch, es, async_es_backend +): """Test the `AsyncESDataBackend.read` method with `raw_output` set to `True`.""" backend = async_es_backend() documents = [{"id": idx, "timestamp": now()} for idx in range(10)] assert await backend.write(documents) == 10 - hits = [statement async for statement in backend.read(raw_output=True)] + hits = [ + statement + async for statement in backend.read(raw_output=True, prefetch=prefetch) + ] for i, hit in enumerate(hits): assert isinstance(hit, bytes) assert json.loads(hit).get("_source") == documents[i] @@ -413,13 +419,16 @@ async def test_backends_data_async_es_read_with_raw_ouput(es, async_es_backend): @pytest.mark.anyio -async def test_backends_data_async_es_read_without_raw_ouput(es, async_es_backend): +@pytest.mark.parametrize("prefetch", [1, 10]) +async def test_backends_data_async_es_read_without_raw_ouput( + prefetch, es, async_es_backend +): """Test the `AsyncESDataBackend.read` method with `raw_output` set to `False`.""" 
backend = async_es_backend() documents = [{"id": idx, "timestamp": now()} for idx in range(10)] assert await backend.write(documents) == 10 - hits = [statement async for statement in backend.read()] + hits = [statement async for statement in backend.read(prefetch=prefetch)] for i, hit in enumerate(hits): assert isinstance(hit, dict) assert hit.get("_source") == documents[i] diff --git a/tests/backends/data/test_async_mongo.py b/tests/backends/data/test_async_mongo.py index fb3badae9..357ebc98a 100644 --- a/tests/backends/data/test_async_mongo.py +++ b/tests/backends/data/test_async_mongo.py @@ -314,9 +314,9 @@ async def test_backends_data_async_mongo_list_with_history( @pytest.mark.anyio +@pytest.mark.parametrize("prefetch", [1, 10]) async def test_backends_data_async_mongo_read_with_raw_output( - mongo, - async_mongo_backend, + prefetch, mongo, async_mongo_backend ): """Test the `AsyncMongoDataBackend.read` method with `raw_output` set to `True`.""" @@ -334,7 +334,10 @@ async def test_backends_data_async_mongo_read_with_raw_output( await backend.collection.insert_many(documents) await backend.database.foobar.insert_many(documents[:2]) - result = [statement async for statement in backend.read(raw_output=True)] + result = [ + statement + async for statement in backend.read(raw_output=True, prefetch=prefetch) + ] assert result == expected result = [ statement async for statement in backend.read(raw_output=True, target="foobar") @@ -351,8 +354,9 @@ async def test_backends_data_async_mongo_read_with_raw_output( @pytest.mark.anyio +@pytest.mark.parametrize("prefetch", [1, 10]) async def test_backends_data_async_mongo_read_without_raw_output( - mongo, async_mongo_backend + prefetch, mongo, async_mongo_backend ): """Test the `AsyncMongoDataBackend.read` method with `raw_output` set to `False`. 
@@ -372,7 +376,9 @@ async def test_backends_data_async_mongo_read_without_raw_output( await backend.collection.insert_many(documents) await backend.database.foobar.insert_many(documents[:2]) - assert [statement async for statement in backend.read()] == expected + assert [ + statement async for statement in backend.read(prefetch=prefetch) + ] == expected assert [statement async for statement in backend.read(target="foobar")] == expected[ :2 ] diff --git a/tests/backends/data/test_base.py b/tests/backends/data/test_base.py index f67f6b5cc..e7f8f355a 100644 --- a/tests/backends/data/test_base.py +++ b/tests/backends/data/test_base.py @@ -1,5 +1,6 @@ """Tests for the base data backend""" +import asyncio import logging from typing import Any, Union @@ -15,7 +16,7 @@ Writable, get_backend_generic_argument, ) -from ralph.exceptions import BackendParameterException +from ralph.exceptions import BackendException, BackendParameterException from ralph.utils import gather_with_limited_concurrency @@ -146,7 +147,9 @@ def close(self): backend = MockBaseDataBackend() assert list(backend.read()) == [{}, {}, {}, {}] + assert list(backend.read(max_statements=0)) == [{}, {}, {}, {}] assert list(backend.read(raw_output=True)) == [b"", b"", b""] + assert list(backend.read(max_statements=0, raw_output=True)) == [b"", b"", b""] assert list(backend.read(max_statements=9)) == [{}, {}, {}, {}] assert list(backend.read(max_statements=9, raw_output=True)) == [b"", b"", b""] @@ -154,8 +157,132 @@ def close(self): assert list(backend.read(max_statements=3)) == [{}, {}, {}] assert list(backend.read(max_statements=3, raw_output=True)) == [b"", b"", b""] - assert not list(backend.read(max_statements=0)) - assert not list(backend.read(max_statements=0, raw_output=True)) + +@pytest.mark.anyio +@pytest.mark.parametrize( + "prefetch,expected_consumed_items", + [ + # Given `prefetch` set to `None`, 0 or 1, the `read` method should consume data + # on demand. 
+ (None, 1), # One item read -> one item consumed. + (0, 1), + (1, 1), + # Given `prefetch>1`, the `read` method should consume `prefetch` number of + # items ahead. + (2, 3), # One item read -> one item consumed + 2 items prefetched. + (3, 4), + ], +) +async def test_backends_data_base_async_read_with_prefetch( + prefetch, expected_consumed_items +): + """Test the `BaseAsyncDataBackend.read` method with `prefetch` argument.""" + consumed_items = {"count": 0} + + class MockDataBackend(BaseAsyncDataBackend[BaseDataBackendSettings, BaseQuery]): + """A class mocking the base database class.""" + + async def _read_dicts(self, *args): + """Yield 6 chunks of `chunk_size` size.""" + for _ in range(6): + consumed_items["count"] += 1 + yield {"foo": "bar"} + + async def _read_bytes(self, *args): + pass + + async def status(self): + pass + + async def close(self): + pass + + backend = MockDataBackend() + reader = backend.read(prefetch=prefetch) + assert await reader.__anext__() == {"foo": "bar"} + await asyncio.sleep(0.2) + assert consumed_items["count"] == expected_consumed_items + assert [_ async for _ in reader] == [ + {"foo": "bar"}, + {"foo": "bar"}, + {"foo": "bar"}, + {"foo": "bar"}, + {"foo": "bar"}, + ] + + +@pytest.mark.anyio +async def test_backends_data_base_async_read_with_invalid_prefetch(caplog): + """Test the `BaseAsyncDataBackend.read` method given a `prefetch` argument + that is less than `0`, should raise a `BackendParameterException`. 
+    """
+
+    class MockDataBackend(BaseAsyncDataBackend[BaseDataBackendSettings, BaseQuery]):
+        """A class mocking the base database class."""
+
+        async def _read_dicts(self, *args):
+            pass
+
+        async def _read_bytes(self, *args):
+            pass
+
+        async def status(self):
+            pass
+
+        async def close(self):
+            pass
+
+    msg = "prefetch must be a strictly positive integer"
+    with pytest.raises(BackendParameterException, match=msg):
+        with caplog.at_level(logging.ERROR):
+            _ = [_ async for _ in MockDataBackend().read(prefetch=-1)]
+
+    assert ("tests.backends.data.test_base", logging.ERROR, msg) in caplog.record_tuples
+
+
+@pytest.mark.anyio
+async def test_backends_data_base_async_read_with_an_error_while_prefetching(caplog):
+    """Test the `BaseAsyncDataBackend.read` method given a `prefetch` argument and an
+    exception while prefetching records, should yield the remaining records before the
+    exception and once the last record is yielded, should re-raise the exception.
+    """
+    consumed_items = {"count": 0}
+
+    class MockDataBackend(BaseAsyncDataBackend[BaseDataBackendSettings, BaseQuery]):
+        """A class mocking the base database class."""
+
+        async def _read_dicts(self, *args):
+            for _ in range(3):
+                consumed_items["count"] += 1
+                yield {"foo": "bar"}
+
+            self.logger.error("connection error")
+            raise BackendException("connection error")
+
+        async def _read_bytes(self, *args):
+            pass
+
+        async def status(self):
+            pass
+
+        async def close(self):
+            pass
+
+    backend = MockDataBackend()
+    reader = backend.read(prefetch=10)
+    assert await reader.__anext__() == {"foo": "bar"}
+    await asyncio.sleep(0.2)
+    # Backend prefetched all records and caught the exception.
+    assert consumed_items["count"] == 3
+    # Reading the remaining records.
+ assert await reader.__anext__() == {"foo": "bar"} + assert await reader.__anext__() == {"foo": "bar"} + msg = "connection error" + with caplog.at_level(logging.ERROR): + with pytest.raises(BackendException, match=msg): + await reader.__anext__() + + assert ("tests.backends.data.test_base", logging.ERROR, msg) in caplog.record_tuples @pytest.mark.anyio @@ -183,7 +310,13 @@ async def close(self): backend = MockAsyncBaseDataBackend() assert [_ async for _ in backend.read()] == [{}, {}, {}, {}] + assert [_ async for _ in backend.read(max_statements=0)] == [{}, {}, {}, {}] assert [_ async for _ in backend.read(raw_output=True)] == [b"", b"", b""] + assert [_ async for _ in backend.read(max_statements=0, raw_output=True)] == [ + b"", + b"", + b"", + ] assert [_ async for _ in backend.read(max_statements=9)] == [{}, {}, {}, {}] assert [_ async for _ in backend.read(max_statements=9, raw_output=True)] == [ @@ -199,10 +332,8 @@ async def close(self): b"", ] - assert not [_ async for _ in backend.read(max_statements=0)] - assert not [_ async for _ in backend.read(max_statements=0, raw_output=True)] - +@pytest.mark.anyio @pytest.mark.parametrize( "chunk_size,concurrency,expected_item_count,expected_write_calls", [ @@ -228,7 +359,6 @@ async def close(self): (1, 20, {1}, 4), ], ) -@pytest.mark.anyio async def test_backends_data_base_async_write_with_concurrency( chunk_size, concurrency, expected_item_count, expected_write_calls, monkeypatch ): diff --git a/tests/backends/data/test_s3.py b/tests/backends/data/test_s3.py index 8f9062229..1f423a6da 100644 --- a/tests/backends/data/test_s3.py +++ b/tests/backends/data/test_s3.py @@ -278,12 +278,7 @@ def test_backends_data_s3_read_with_valid_name_should_write_to_history( "timestamp": freezed_now, } in backend.history - list( - backend.read( - query="2022-09-30.gz", - raw_output=False, - ) - ) + list(backend.read(query="2022-09-30.gz", raw_output=False)) assert { "backend": "s3",