diff --git a/src/CSET/_workflow_utils/fetch_data.py b/src/CSET/_workflow_utils/fetch_data.py index 82a83f3e3..7429d9d25 100755 --- a/src/CSET/_workflow_utils/fetch_data.py +++ b/src/CSET/_workflow_utils/fetch_data.py @@ -4,6 +4,7 @@ import abc import glob +import itertools import logging import os import shutil @@ -44,7 +45,7 @@ def __exit__(self, exc_type, exc_value, traceback): logging.debug("Tearing down FileRetriever.") @abc.abstractmethod - def get_file(self, file_path: str, output_dir: str) -> None: # pragma: no cover + def get_file(self, file_path: str, output_dir: str) -> bool: # pragma: no cover """Save a file from the data source to the output directory. Not all of the given paths will exist, so FileNotFoundErrors should be @@ -60,6 +61,11 @@ def get_file(self, file_path: str, output_dir: str) -> None: # pragma: no cover like globs, which will be expanded in a system specific manner. output_dir: str Path to filesystem directory into which the file should be copied. + + Returns + ------- + bool: + True if files were transferred, otherwise False. """ raise NotImplementedError @@ -67,7 +73,7 @@ def get_file(self, file_path: str, output_dir: str) -> None: # pragma: no cover class FilesystemFileRetriever(FileRetrieverABC): """Retrieve files from the filesystem.""" - def get_file(self, file_path: str, output_dir: str) -> None: + def get_file(self, file_path: str, output_dir: str) -> bool: """Save a file from the filesystem to the output directory. Parameters @@ -77,22 +83,30 @@ def get_file(self, file_path: str, output_dir: str) -> None: like globs, which will be expanded in a system specific manner. output_dir: str Path to filesystem directory into which the file should be copied. + + Returns + ------- + bool: + True if files were transferred, otherwise False. """ file_paths = glob.glob(os.path.expanduser(file_path)) logging.debug("Copying files:\n%s", "\n".join(file_paths)) if not file_paths: logging.warning("file_path does not match any files: %s", file_path) + any_files_copied = False for file in file_paths: try: shutil.copy(file, output_dir) + any_files_copied = True except OSError as err: logging.warning("Failed to copy %s, error: %s", file, err) + return any_files_copied class HTTPFileRetriever(FileRetrieverABC): """Retrieve files via HTTP.""" - def get_file(self, file_path: str, output_dir: str) -> None: + def get_file(self, file_path: str, output_dir: str) -> bool: """Save a file from a HTTP address to the output directory. Parameters @@ -102,12 +116,18 @@ def get_file(self, file_path: str, output_dir: str) -> None: globs, which will be expanded in a system specific manner. output_dir: str Path to filesystem directory into which the file should be copied. + + Returns + ------- + bool: + True if files were transferred, otherwise False. """ ctx = ssl.create_default_context() save_path = ( f"{output_dir.removesuffix('/')}/" + urllib.parse.urlparse(file_path).path.split("/")[-1] ) + any_files_copied = False try: with urllib.request.urlopen(file_path, timeout=30, context=ctx) as response: if response.status != 200: @@ -116,8 +136,10 @@ def get_file(self, file_path: str, output_dir: str) -> None: # Read in 1 MiB chunks so data needn't fit in memory. while data := response.read(1024 * 1024): fp.write(data) + any_files_copied = True except OSError as err: logging.warning("Failed to retrieve %s, error: %s", file_path, err) + return any_files_copied def _get_needed_environment_variables() -> dict: @@ -209,6 +231,11 @@ def fetch_data(file_retriever: FileRetrieverABC = FilesystemFileRetriever): ---------- file_retriever: FileRetriever FileRetriever implementation to use. Defaults to FilesystemFileRetriever. + + Raises + ------ + FileNotFound: + If no files are found for the model, across all tried paths. """ v = _get_needed_environment_variables() @@ -230,5 +257,11 @@ def fetch_data(file_retriever: FileRetrieverABC = FilesystemFileRetriever): # Use file retriever to transfer data with multiple threads. with file_retriever() as retriever, ThreadPoolExecutor() as executor: - for path in paths: - executor.submit(retriever.get_file, path, cycle_data_dir) + files_found = any( + executor.map(retriever.get_file, paths, itertools.repeat(cycle_data_dir)) + ) + # We don't need to exhause the iterator, as all futures are submitted + # before map yields anything. Therefore they will all be resolved upon + # exiting the with block. + if not files_found: + raise FileNotFoundError("No files found for model!") diff --git a/tests/workflow_utils/test_fetch_data.py b/tests/workflow_utils/test_fetch_data.py index ffe9432fe..3b3f51d0b 100644 --- a/tests/workflow_utils/test_fetch_data.py +++ b/tests/workflow_utils/test_fetch_data.py @@ -106,6 +106,7 @@ class MockFileRetriever(fetch_data.FileRetrieverABC): def get_file(self, file_path: str, output_dir: str) -> None: nonlocal files_gotten files_gotten = True + return True monkeypatch.setattr( fetch_data,