diff --git a/cset-workflow/app/fetch_fcst/bin/fetch-data-http.py b/cset-workflow/app/fetch_fcst/bin/fetch-data-http.py index 154504e01..39085bb44 100644 --- a/cset-workflow/app/fetch_fcst/bin/fetch-data-http.py +++ b/cset-workflow/app/fetch_fcst/bin/fetch-data-http.py @@ -2,32 +2,6 @@ """Retrieve files via HTTP.""" -import ssl -import urllib.parse -import urllib.request - -from CSET._workflow_utils.fetch_data import FileRetrieverABC, fetch_data - - -class HTTPFileRetriever(FileRetrieverABC): - """Retrieve files via HTTP.""" - - def get_file(self, file_path: str, output_dir: str) -> None: - """Save a file from a HTTP address to the output directory. - - Parameters - ---------- - file_path: str - Path of the file to copy on MASS. It may contain patterns - like globs, which will be expanded in a system specific manner. - output_dir: str - Path to filesystem directory into which the file should be copied. - """ - ctx = ssl.create_default_context() - save_path = urllib.parse.urlparse(file_path).path.split("/")[-1] - with urllib.request.urlopen(file_path, output_dir, context=ctx) as response: - with open(save_path, "wb") as fp: - fp.write(response.read()) - +from CSET._workflow_utils.fetch_data import HTTPFileRetriever, fetch_data fetch_data(HTTPFileRetriever) diff --git a/src/CSET/_workflow_utils/fetch_data.py b/src/CSET/_workflow_utils/fetch_data.py index 8392cb835..ce4ecb71e 100755 --- a/src/CSET/_workflow_utils/fetch_data.py +++ b/src/CSET/_workflow_utils/fetch_data.py @@ -7,7 +7,10 @@ import logging import os import shutil +import ssl import sys +import urllib.parse +import urllib.request from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta from typing import Literal @@ -20,7 +23,7 @@ class FileRetrieverABC(abc.ABC): - """Abstract class for retrieving files from a data source. + """Abstract base class for retrieving files from a data source. The `get_file` method must be defined. Optionally the __enter__ and __exit__ methods maybe be overridden to add setup or cleanup code. @@ -86,6 +89,29 @@ def get_file(self, file_path: str, output_dir: str) -> None: logging.warning("Failed to copy %s, error: %s", file, err) +class HTTPFileRetriever(FileRetrieverABC): + """Retrieve files via HTTP.""" + + def get_file(self, file_path: str, output_dir: str) -> None: + """Save a file from a HTTP address to the output directory. + + Parameters + ---------- + file_path: str + Path of the file to copy on MASS. It may contain patterns like + globs, which will be expanded in a system specific manner. + output_dir: str + Path to filesystem directory into which the file should be copied. + """ + ctx = ssl.create_default_context() + save_path = urllib.parse.urlparse(file_path).path.split("/")[-1] + with urllib.request.urlopen(file_path, output_dir, context=ctx) as response: + with open(save_path, "wb") as fp: + # Read in 1 MiB chunks so data doesn't all have to be in memory. + while data := response.read(1024 * 1024): + fp.write(data) + + def _get_needed_environment_variables() -> dict: """Load the needed variables from the environment.""" # Python 3.10 and older don't fully support ISO 8601 datetime formats. diff --git a/tests/workflow_utils/test_fetch_data.py b/tests/workflow_utils/test_fetch_data.py index 6fbe55779..e7b000504 100644 --- a/tests/workflow_utils/test_fetch_data.py +++ b/tests/workflow_utils/test_fetch_data.py @@ -15,6 +15,7 @@ """Tests for fetch_data workflow utility.""" import datetime +import hashlib from pathlib import Path import pytest @@ -186,3 +187,16 @@ def test_FilesystemFileRetriever_copy_error(caplog): log_record = caplog.records[0] assert log_record.levelname == "WARNING" assert log_record.message.startswith("Failed to copy") + + +def test_HTTPFileRetriever(tmp_path): + """Test retrieving a file via HTTP.""" + url = "https://github.com/MetOffice/CSET/raw/48dc1d29846604aacb8d370b82bca31405931c87/tests/test_data/exeter_em01.nc" + with fetch_data.HTTPFileRetriever() as hfr: + hfr.get_file(url, str(tmp_path)) + file = tmp_path / "exeter_em01.nc" + assert file.is_file() + # Check file hash is correct, indicating a non-corrupt download. + expected_hash = "67899970eeca75b9378f0275ce86db3d1d613f2bc7a178540912848dc8a69ca7" + actual_hash = hashlib.sha256(file.read_bytes()).hexdigest() + assert actual_hash == expected_hash