Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add file_path param to download_file() to directly download files to disk #115

Merged
merged 5 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/aleph/sdk/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,30 @@ async def get_posts_iterator(
yield post

@abstractmethod
async def download_file(
self,
file_hash: str,
) -> bytes:
async def download_file(self, file_hash: str) -> bytes:
"""
Get a file from the storage engine as raw bytes.

Warning: Downloading large files can be slow and memory intensive.
Warning: Downloading large files can be slow and memory intensive. Use `download_file_to()` to download them directly to disk instead.

:param file_hash: The hash of the file to retrieve.
"""
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")

@abstractmethod
async def download_file_to_path(
self,
file_hash: str,
path: Union[Path, str],
) -> Path:
"""
Download a file from the storage engine to given path.

:param file_hash: The hash of the file to retrieve.
:param path: The path to which the file should be saved.
"""
raise NotImplementedError()

async def download_file_ipfs(
self,
file_hash: str,
Expand Down
33 changes: 28 additions & 5 deletions src/aleph/sdk/client/http.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
import logging
import os.path
import ssl
from io import BytesIO
from pathlib import Path
from typing import Any, AsyncIterable, Dict, Iterable, List, Optional, Type, Union

import aiohttp
Expand Down Expand Up @@ -206,21 +208,42 @@ async def download_file_ipfs_to_buffer(
else:
response.raise_for_status()

async def download_file(
self,
file_hash: str,
) -> bytes:
async def download_file(self, file_hash: str) -> bytes:
"""
Get a file from the storage engine as raw bytes.

Warning: Downloading large files can be slow and memory intensive.
Warning: Downloading large files can be slow and memory intensive. Use `download_file_to()` to download them directly to disk instead.

:param file_hash: The hash of the file to retrieve.
"""
buffer = BytesIO()
await self.download_file_to_buffer(file_hash, output_buffer=buffer)
return buffer.getvalue()

async def download_file_to_path(
self,
file_hash: str,
path: Union[Path, str],
) -> Path:
"""
Download a file from the storage engine to given path.

:param file_hash: The hash of the file to retrieve.
:param path: The path to which the file should be saved.
"""
if not isinstance(path, Path):
path = Path(path)

if not os.path.exists(path):
dir_path = os.path.dirname(path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)

with open(path, "wb") as file_buffer:
await self.download_file_to_buffer(file_hash, output_buffer=file_buffer)

return path

async def download_file_ipfs(
self,
file_hash: str,
Expand Down
51 changes: 46 additions & 5 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Optional, Union
from unittest.mock import AsyncMock, MagicMock

import pytest as pytest
Expand Down Expand Up @@ -195,19 +196,59 @@
return client


def make_custom_mock_response(resp_json, status=200) -> MockResponse:
import asyncio
from functools import wraps


def async_wrap(cls):
class AsyncWrapper:
def __init__(self, *args, **kwargs):
self._instance = cls(*args, **kwargs)

def __getattr__(self, item):
attr = getattr(self._instance, item)
if callable(attr):

@wraps(attr)
async def method(*args, **kwargs):
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, attr, *args, **kwargs)

return method
return attr

Check warning on line 218 in tests/unit/conftest.py

View check run for this annotation

Codecov / codecov/patch

tests/unit/conftest.py#L218

Added line #L218 was not covered by tests

return AsyncWrapper


AsyncBytesIO = async_wrap(BytesIO)


def make_custom_mock_response(
resp: Union[Dict[str, Any], bytes], status=200
) -> MockResponse:
class CustomMockResponse(MockResponse):
content: Optional[AsyncBytesIO]

async def json(self):
return resp_json
return resp

@property
def status(self):
return status

return CustomMockResponse(sync=True)
mock = CustomMockResponse(sync=True)

try:
mock.content = AsyncBytesIO(resp)
except Exception as e:
print(e)

return mock


def make_mock_get_session(get_return_value: Dict[str, Any]) -> AlephHttpClient:
def make_mock_get_session(
get_return_value: Union[Dict[str, Any], bytes]
) -> AlephHttpClient:
class MockHttpSession(AsyncMock):
def get(self, *_args, **_kwargs):
return make_custom_mock_response(get_return_value)
Expand Down
52 changes: 43 additions & 9 deletions tests/unit/test_download.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, patch

import pytest

from aleph.sdk import AlephHttpClient
from aleph.sdk.conf import settings as sdk_settings

from .conftest import make_mock_get_session


def make_mock_download_client(item_hash: str) -> AlephHttpClient:
if item_hash == "QmeomffUNfmQy76CQGy9NdmqEnnHU9soCexBnGU3ezPHVH":
return make_mock_get_session(b"test\n")
if item_hash == "Qmdy5LaAL4eghxE7JD6Ah5o4PJGarjAV9st8az2k52i1vq":
return make_mock_get_session(bytes(5817703))
raise NotImplementedError


@pytest.mark.parametrize(
Expand All @@ -13,10 +26,30 @@
)
@pytest.mark.asyncio
async def test_download(file_hash: str, expected_size: int):
async with AlephHttpClient(api_server=sdk_settings.API_HOST) as client:
file_content = await client.download_file(file_hash) # File is 5B
file_size = len(file_content)
assert file_size == expected_size
mock_download_client = make_mock_download_client(file_hash)
async with mock_download_client:
file_content = await mock_download_client.download_file(file_hash)
file_size = len(file_content)
assert file_size == expected_size


@pytest.mark.asyncio
async def test_download_to_file():
MHHukiewitz marked this conversation as resolved.
Show resolved Hide resolved
file_hash = "QmeomffUNfmQy76CQGy9NdmqEnnHU9soCexBnGU3ezPHVH"
mock_download_client = make_mock_download_client(file_hash)
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)
download_path = temp_dir_path / "test.txt"

async with mock_download_client:
returned_path = await mock_download_client.download_file_to_path(
file_hash, str(download_path)
)

assert returned_path == download_path
assert download_path.is_file()
with open(download_path, "r") as file:
assert file.read().strip() == "test"


@pytest.mark.parametrize(
Expand All @@ -28,7 +61,8 @@ async def test_download(file_hash: str, expected_size: int):
)
@pytest.mark.asyncio
async def test_download_ipfs(file_hash: str, expected_size: int):
async with AlephHttpClient(api_server=sdk_settings.API_HOST) as client:
file_content = await client.download_file_ipfs(file_hash) # 5817703 B FILE
file_size = len(file_content)
assert file_size == expected_size
mock_download_client = make_mock_download_client(file_hash)
async with mock_download_client:
file_content = await mock_download_client.download_file_ipfs(file_hash)
file_size = len(file_content)
assert file_size == expected_size
Loading