Skip to content

Commit

Permalink
✨(backends) add prefetch option for data backends
Browse files Browse the repository at this point in the history
We try to align the async data backend interface with the HTTP
backends. The `prefetch` option allows the caller to greedily read
`prefetch` records before they are yielded by the generator.
  • Loading branch information
SergioSim committed Dec 12, 2023
1 parent 97fa64f commit 8983c30
Show file tree
Hide file tree
Showing 14 changed files with 250 additions and 40 deletions.
14 changes: 12 additions & 2 deletions src/ralph/backends/data/async_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ async def read( # noqa: PLR0913
chunk_size: Optional[int] = None,
raw_output: bool = False,
ignore_errors: bool = False,
prefetch: Optional[PositiveInt] = None,
max_statements: Optional[PositiveInt] = None,
) -> Union[AsyncIterator[bytes], AsyncIterator[dict]]:
"""Read documents matching the query in the target index and yield them.
Expand All @@ -130,8 +131,11 @@ async def read( # noqa: PLR0913
raw_output (bool): Controls whether to yield dictionaries or bytes.
ignore_errors (bool): No impact as encoding errors are not expected in
Elasticsearch results.
prefetch (int): The number of records to prefetch (queue) while yielding.
If `prefetch` is `None` it defaults to `1`, i.e. no records are
prefetched.
max_statements (int): The maximum number of statements to yield.
If `None` (default), there is no maximum.
If `None` (default) or `0`, there is no maximum.
Yield:
bytes: The next raw document if `raw_output` is True.
Expand All @@ -141,7 +145,13 @@ async def read( # noqa: PLR0913
BackendException: If a failure occurs during Elasticsearch connection.
"""
statements = super().read(
query, target, chunk_size, raw_output, ignore_errors, max_statements
query,
target,
chunk_size,
raw_output,
ignore_errors,
prefetch,
max_statements,
)
async for statement in statements:
yield statement
Expand Down
14 changes: 12 additions & 2 deletions src/ralph/backends/data/async_mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ async def read( # noqa: PLR0913
chunk_size: Optional[int] = None,
raw_output: bool = False,
ignore_errors: bool = False,
prefetch: Optional[PositiveInt] = None,
max_statements: Optional[PositiveInt] = None,
) -> Union[AsyncIterator[bytes], AsyncIterator[dict]]:
"""Read documents matching the `query` from `target` collection and yield them.
Expand All @@ -137,8 +138,11 @@ async def read( # noqa: PLR0913
ignore_errors (bool): If `True`, encoding errors during the read operation
will be ignored and logged.
If `False` (default), a `BackendException` is raised on any error.
prefetch (int): The number of records to prefetch (queue) while yielding.
If `prefetch` is `None` it defaults to `1`, i.e. no records are
prefetched.
max_statements (int): The maximum number of statements to yield.
If `None` (default), there is no maximum.
If `None` (default) or `0`, there is no maximum.
Yield:
bytes: The next raw document if `raw_output` is True.
Expand All @@ -150,7 +154,13 @@ async def read( # noqa: PLR0913
BackendParameterException: If the `target` is not a valid collection name.
"""
statements = super().read(
query, target, chunk_size, raw_output, ignore_errors, max_statements
query,
target,
chunk_size,
raw_output,
ignore_errors,
prefetch,
max_statements,
)
async for statement in statements:
yield statement
Expand Down
64 changes: 57 additions & 7 deletions src/ralph/backends/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from abc import ABC, abstractmethod
from asyncio import Queue, create_task
from enum import Enum, unique
from functools import cached_property
from io import IOBase
Expand Down Expand Up @@ -299,7 +300,7 @@ def read( # noqa: PLR0913
will be ignored and logged.
If `False` (default), a `BackendException` is raised on any error.
max_statements (int): The maximum number of statements to yield.
If `None` (default), there is no maximum.
If `None` (default) or `0`, there is no maximum.
Yield:
dict: If `raw_output` is False.
Expand All @@ -314,14 +315,15 @@ def read( # noqa: PLR0913
query = validate_backend_query(query, self.query_class, self.logger)
reader = self._read_bytes if raw_output else self._read_dicts
statements = reader(query, target, chunk_size, ignore_errors)
if max_statements is None:
if not max_statements:
yield from statements
return

max_statements -= 1
for i, statement in enumerate(statements):
yield statement
if i >= max_statements:
return
yield statement

@abstractmethod
def _read_bytes(
Expand Down Expand Up @@ -507,6 +509,7 @@ async def read( # noqa: PLR0913
chunk_size: Optional[int] = None,
raw_output: bool = False,
ignore_errors: bool = False,
prefetch: Optional[PositiveInt] = None,
max_statements: Optional[PositiveInt] = None,
) -> Union[AsyncIterator[bytes], AsyncIterator[dict]]:
"""Read records matching the `query` in the `target` container and yield them.
Expand All @@ -526,8 +529,11 @@ async def read( # noqa: PLR0913
ignore_errors (bool): If `True`, encoding errors during the read operation
will be ignored and logged.
If `False` (default), a `BackendException` is raised on any error.
prefetch (int): The number of records to prefetch (queue) while yielding.
If `prefetch` is `None` it defaults to `1`, i.e. no records are
prefetched.
max_statements (int): The maximum number of statements to yield.
If `None` (default), there is no maximum.
If `None` (default) or `0`, there is no maximum.
Yield:
dict: If `raw_output` is False.
Expand All @@ -538,20 +544,50 @@ async def read( # noqa: PLR0913
during encoding records and `ignore_errors` is set to `False`.
BackendParameterException: If a backend argument value is not valid.
"""
prefetch = prefetch if prefetch else 1
if prefetch < 1:
msg = "prefetch must be a strictly positive integer"
self.logger.error(msg)
raise BackendParameterException(msg)

if prefetch > 1:
queue = Queue(prefetch - 1)
statements = self.read(
query,
target,
chunk_size,
raw_output,
ignore_errors,
None,
max_statements,
)
task = create_task(self._queue_records(queue, statements))
while True:
statement = await queue.get()
if statement is None:
error = task.exception()
if error:
raise error

return

yield statement

chunk_size = chunk_size if chunk_size else self.settings.READ_CHUNK_SIZE
query = validate_backend_query(query, self.query_class, self.logger)
reader = self._read_bytes if raw_output else self._read_dicts
statements = reader(query, target, chunk_size, ignore_errors)
if max_statements is None:
if not max_statements:
async for statement in statements:
yield statement
return

i = 0
async for statement in statements:
if i >= max_statements:
return
yield statement
i += 1
if i >= max_statements:
return

@abstractmethod
async def _read_bytes(
Expand All @@ -573,6 +609,20 @@ async def close(self) -> None:
BackendException: If a failure occurs during the close operation.
"""

async def _queue_records(
    self, queue: Queue, records: Union[AsyncIterator[bytes], AsyncIterator[dict]]
):
    """Drain `records` into `queue`, then enqueue `None` as the end marker.

    The `None` marker is enqueued both when `records` is exhausted and when
    an `Exception` is raised while filling the queue; in the latter case the
    exception is re-raised once the marker has been queued.

    Args:
        queue (Queue): The queue receiving each record, followed by `None`.
        records (AsyncIterator): The records to put into the `queue`.

    Raises:
        Exception: Any exception raised while iterating over `records` or
            putting a record into the `queue`.
    """
    try:
        async for item in records:
            await queue.put(item)
    except Exception as exc:
        # Enqueue the end marker first so that a reader of the queue still
        # sees it, then propagate the failure.
        await queue.put(None)
        raise exc
    else:
        # Normal exhaustion: `None` tells the reader that the queue is done.
        await queue.put(None)


def get_backend_generic_argument(
backend_class: Type[Union[BaseDataBackend, BaseAsyncDataBackend]], position: int
Expand Down
2 changes: 1 addition & 1 deletion src/ralph/backends/data/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def read( # noqa: PLR0913
will be ignored and logged.
If `False` (default), a `BackendException` is raised on any error.
max_statements (int): The maximum number of statements to yield.
If `None` (default), there is no maximum.
If `None` (default) or `0`, there is no maximum.
Yield:
bytes: The next raw document if `raw_output` is True.
Expand Down
2 changes: 1 addition & 1 deletion src/ralph/backends/data/es.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def read( # noqa: PLR0913
ignore_errors (bool): No impact as encoding errors are not expected in
Elasticsearch results.
max_statements (int): The maximum number of statements to yield.
If `None` (default), there is no maximum.
If `None` (default) or `0`, there is no maximum.
Yield:
bytes: The next raw document if `raw_output` is True.
Expand Down
2 changes: 1 addition & 1 deletion src/ralph/backends/data/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def read( # noqa: PLR0913
will be ignored and logged.
If `False` (default), a `BackendException` is raised on any error.
max_statements (int): The maximum number of statements to yield.
If `None` (default), there is no maximum.
If `None` (default) or `0`, there is no maximum.
Yield:
bytes: The next chunk of the read files if `raw_output` is True.
Expand Down
2 changes: 1 addition & 1 deletion src/ralph/backends/data/ldp.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def read( # noqa: PLR0913
raw_output (bool): Should always be set to `True`.
ignore_errors (bool): No impact as no encoding operation is performed.
max_statements (int): The maximum number of statements to yield.
If `None` (default), there is no maximum.
If `None` (default) or `0`, there is no maximum.
Yield:
bytes: The content of the archive matching the query.
Expand Down
2 changes: 1 addition & 1 deletion src/ralph/backends/data/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def read( # noqa: PLR0913
will be ignored and logged.
If `False` (default), a `BackendException` is raised on any error.
max_statements (int): The maximum number of statements to yield.
If `None` (default), there is no maximum.
If `None` (default) or `0`, there is no maximum.
Yield:
dict: If `raw_output` is False.
Expand Down
2 changes: 1 addition & 1 deletion src/ralph/backends/data/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def read( # noqa: PLR0913
will be ignored and logged.
If `False` (default), a `BackendException` is raised on any error.
max_statements (int): The maximum number of statements to yield.
If `None` (default), there is no maximum.
If `None` (default) or `0`, there is no maximum.
Yield:
dict: If `raw_output` is False.
Expand Down
2 changes: 1 addition & 1 deletion src/ralph/backends/data/swift.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def read( # noqa: PLR0913
will be ignored and logged.
If `False` (default), a `BackendException` is raised on any error.
max_statements (int): The maximum number of statements to yield.
If `None` (default), there is no maximum.
If `None` (default) or `0`, there is no maximum.
Yield:
dict: If `raw_output` is False.
Expand Down
17 changes: 13 additions & 4 deletions tests/backends/data/test_async_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,13 +398,19 @@ async def mock_async_es_search(**kwargs):


@pytest.mark.anyio
async def test_backends_data_async_es_read_with_raw_ouput(es, async_es_backend):
@pytest.mark.parametrize("prefetch", [1, 10])
async def test_backends_data_async_es_read_with_raw_ouput(
prefetch, es, async_es_backend
):
"""Test the `AsyncESDataBackend.read` method with `raw_output` set to `True`."""

backend = async_es_backend()
documents = [{"id": idx, "timestamp": now()} for idx in range(10)]
assert await backend.write(documents) == 10
hits = [statement async for statement in backend.read(raw_output=True)]
hits = [
statement
async for statement in backend.read(raw_output=True, prefetch=prefetch)
]
for i, hit in enumerate(hits):
assert isinstance(hit, bytes)
assert json.loads(hit).get("_source") == documents[i]
Expand All @@ -413,13 +419,16 @@ async def test_backends_data_async_es_read_with_raw_ouput(es, async_es_backend):


@pytest.mark.anyio
async def test_backends_data_async_es_read_without_raw_ouput(es, async_es_backend):
@pytest.mark.parametrize("prefetch", [1, 10])
async def test_backends_data_async_es_read_without_raw_ouput(
prefetch, es, async_es_backend
):
"""Test the `AsyncESDataBackend.read` method with `raw_output` set to `False`."""

backend = async_es_backend()
documents = [{"id": idx, "timestamp": now()} for idx in range(10)]
assert await backend.write(documents) == 10
hits = [statement async for statement in backend.read()]
hits = [statement async for statement in backend.read(prefetch=prefetch)]
for i, hit in enumerate(hits):
assert isinstance(hit, dict)
assert hit.get("_source") == documents[i]
Expand Down
16 changes: 11 additions & 5 deletions tests/backends/data/test_async_mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,9 @@ async def test_backends_data_async_mongo_list_with_history(


@pytest.mark.anyio
@pytest.mark.parametrize("prefetch", [1, 10])
async def test_backends_data_async_mongo_read_with_raw_output(
mongo,
async_mongo_backend,
prefetch, mongo, async_mongo_backend
):
"""Test the `AsyncMongoDataBackend.read` method with `raw_output` set to `True`."""

Expand All @@ -334,7 +334,10 @@ async def test_backends_data_async_mongo_read_with_raw_output(
await backend.collection.insert_many(documents)
await backend.database.foobar.insert_many(documents[:2])

result = [statement async for statement in backend.read(raw_output=True)]
result = [
statement
async for statement in backend.read(raw_output=True, prefetch=prefetch)
]
assert result == expected
result = [
statement async for statement in backend.read(raw_output=True, target="foobar")
Expand All @@ -351,8 +354,9 @@ async def test_backends_data_async_mongo_read_with_raw_output(


@pytest.mark.anyio
@pytest.mark.parametrize("prefetch", [1, 10])
async def test_backends_data_async_mongo_read_without_raw_output(
mongo, async_mongo_backend
prefetch, mongo, async_mongo_backend
):
"""Test the `AsyncMongoDataBackend.read` method with `raw_output` set to
`False`.
Expand All @@ -372,7 +376,9 @@ async def test_backends_data_async_mongo_read_without_raw_output(
await backend.collection.insert_many(documents)
await backend.database.foobar.insert_many(documents[:2])

assert [statement async for statement in backend.read()] == expected
assert [
statement async for statement in backend.read(prefetch=prefetch)
] == expected
assert [statement async for statement in backend.read(target="foobar")] == expected[
:2
]
Expand Down
Loading

0 comments on commit 8983c30

Please sign in to comment.