From 9936a108242ae613af02577938656e110787b5b0 Mon Sep 17 00:00:00 2001
From: Peter Bull
Date: Fri, 29 Dec 2023 10:59:12 -0800
Subject: [PATCH] Implement sliced downloads in GSClient (#389) (#391)

* WIP: Implement sliced downloads in GSClient (#389)

* feat: Implement sliced downloads in GSClient

* fix: remove unintended import changes

* Mock transfer_manager. Test both worker types.

* Update HISTORY.md

---------

Co-authored-by: Joe O'Connor <60386246+joconnor-ecaa@users.noreply.github.com>
---
 HISTORY.md                    |  3 +++
 cloudpathlib/gs/gsclient.py   | 15 ++++++++++++++-
 tests/conftest.py             |  6 ++++++
 tests/mock_clients/mock_gs.py | 31 +++++++++++++++++++++++++++++++
 tests/test_gs_specific.py     | 10 ++++++++++
 5 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/HISTORY.md b/HISTORY.md
index 8c6c6fda..051266ba 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,5 +1,8 @@
 # cloudpathlib Changelog

+## UNRELEASED
+- Implement sliced downloads in GSClient. (Issue [#387](https://github.com/drivendataorg/cloudpathlib/issues/387), PR [#389](https://github.com/drivendataorg/cloudpathlib/pull/389))
+
 ## 0.17.0 (2023-12-21)

 - Fix `S3Client` cleanup via `Client.__del__` when `S3Client` encounters an exception during initialization. (Issue [#372](https://github.com/drivendataorg/cloudpathlib/issues/372), PR [#373](https://github.com/drivendataorg/cloudpathlib/pull/373), thanks to [@bryanwweber](https://github.com/bryanwweber))

diff --git a/cloudpathlib/gs/gsclient.py b/cloudpathlib/gs/gsclient.py
index 6b924263..d75fd714 100644
--- a/cloudpathlib/gs/gsclient.py
+++ b/cloudpathlib/gs/gsclient.py
@@ -16,6 +16,7 @@
     from google.api_core.exceptions import NotFound
     from google.auth.exceptions import DefaultCredentialsError
     from google.cloud.storage import Client as StorageClient
+    from google.cloud.storage import transfer_manager

 except ModuleNotFoundError:
@@ -39,6 +40,7 @@ def __init__(
         file_cache_mode: Optional[Union[str, FileCacheMode]] = None,
         local_cache_dir: Optional[Union[str, os.PathLike]] = None,
         content_type_method: Optional[Callable] = mimetypes.guess_type,
+        download_chunks_concurrently_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """Class constructor. Sets up a [`Storage Client`](https://googleapis.dev/python/storage/latest/client.html).
@@ -76,6 +78,9 @@ def __init__(
                 the `CLOUDPATHLIB_LOCAL_CACHE_DIR` environment variable.
             content_type_method (Optional[Callable]): Function to call to guess media type (mimetype) when
                 writing a file to the cloud. Defaults to `mimetypes.guess_type`. Must return a tuple (content type, content encoding).
+            download_chunks_concurrently_kwargs (Optional[Dict[str, Any]]): Keyword arguments to pass to
+                [`download_chunks_concurrently`](https://cloud.google.com/python/docs/reference/storage/latest/google.cloud.storage.transfer_manager#google_cloud_storage_transfer_manager_download_chunks_concurrently)
+                for sliced parallel downloads.
         """
         if application_credentials is None:
             application_credentials = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
@@ -92,6 +97,8 @@ def __init__(
             except DefaultCredentialsError:
                 self.client = StorageClient.create_anonymous_client()

+        self.download_chunks_concurrently_kwargs = download_chunks_concurrently_kwargs
+
         super().__init__(
             local_cache_dir=local_cache_dir,
             content_type_method=content_type_method,
@@ -118,7 +125,13 @@ def _download_file(self, cloud_path: GSPath, local_path: Union[str, os.PathLike]

         local_path = Path(local_path)

-        blob.download_to_filename(local_path)
+        if self.download_chunks_concurrently_kwargs is not None:
+            transfer_manager.download_chunks_concurrently(
+                blob, local_path, **self.download_chunks_concurrently_kwargs
+            )
+        else:
+            blob.download_to_filename(local_path)
+
         return local_path

     def _is_file_or_dir(self, cloud_path: GSPath) -> Optional[str]:
diff --git a/tests/conftest.py b/tests/conftest.py
index 30cebd21..76c4bad5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -30,6 +30,7 @@
 from .mock_clients.mock_gs import (
     mocked_client_class_factory as mocked_gsclient_class_factory,
     DEFAULT_GS_BUCKET_NAME,
+    MockTransferManager,
 )
 from .mock_clients.mock_s3 import mocked_session_class_factory, DEFAULT_S3_BUCKET_NAME
@@ -184,6 +185,11 @@ def gs_rig(request, monkeypatch, assets_dir):
         "StorageClient",
         mocked_gsclient_class_factory(test_dir),
     )
+    monkeypatch.setattr(
+        cloudpathlib.gs.gsclient,
+        "transfer_manager",
+        MockTransferManager,
+    )

     rig = CloudProviderTestRig(
         path_class=GSPath,
diff --git a/tests/mock_clients/mock_gs.py b/tests/mock_clients/mock_gs.py
index fea1e63f..93056d7e 100644
--- a/tests/mock_clients/mock_gs.py
+++ b/tests/mock_clients/mock_gs.py
@@ -65,6 +65,21 @@ def patch(self):
         if "updated" in self.metadata:
             (self.bucket / self.name).touch()

+    def reload(
+        self,
+        client=None,
+        projection="noAcl",
+        if_etag_match=None,
+        if_etag_not_match=None,
+        if_generation_match=None,
+        if_generation_not_match=None,
+        if_metageneration_match=None,
+        if_metageneration_not_match=None,
+        timeout=None,
+        retry=None,
+    ):
+        pass
+
     def upload_from_filename(self, filename, content_type=None):
         data = Path(filename).read_bytes()
         path = self.bucket / self.name
@@ -153,3 +168,19 @@ def __next__(self):
     @property
     def prefixes(self):
         return self.sub_directories
+
+
+class MockTransferManager:
+    @staticmethod
+    def download_chunks_concurrently(
+        blob,
+        filename,
+        chunk_size=32 * 1024 * 1024,
+        download_kwargs=None,
+        deadline=None,
+        worker_type="process",
+        max_workers=8,
+        *,
+        crc32c_checksum=True,
+    ):
+        blob.download_to_filename(filename)
diff --git a/tests/test_gs_specific.py b/tests/test_gs_specific.py
index 83775ac8..a851abb3 100644
--- a/tests/test_gs_specific.py
+++ b/tests/test_gs_specific.py
@@ -13,3 +13,13 @@ def test_gspath_properties(path_class):
     p2 = path_class("gs://mybucket/")
     assert p2.blob == ""
     assert p2.bucket == "mybucket"
+
+
+@pytest.mark.parametrize("worker_type", ["process", "thread"])
+def test_concurrent_download(gs_rig, tmp_path, worker_type):
+    client = gs_rig.client_class(download_chunks_concurrently_kwargs={"worker_type": worker_type})
+    p = gs_rig.create_cloud_path("dir_0/file0_0.txt", client=client)
+    dl_dir = tmp_path
+    assert not (dl_dir / p.name).exists()
+    p.download_to(dl_dir)
+    assert (dl_dir / p.name).is_file()