Skip to content

Commit

Permalink
Move HTTPFileRetriever into fetch_data module
Browse files Browse the repository at this point in the history
Also add a test.
  • Loading branch information
jfrost-mo committed Sep 16, 2024
1 parent 78a5abd commit 5f6e3f5
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 28 deletions.
28 changes: 1 addition & 27 deletions cset-workflow/app/fetch_fcst/bin/fetch-data-http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,6 @@

"""Retrieve files via HTTP."""

import ssl
import urllib.parse
import urllib.request

from CSET._workflow_utils.fetch_data import FileRetrieverABC, fetch_data


class HTTPFileRetriever(FileRetrieverABC):
    """Retrieve files via HTTP."""

    def get_file(self, file_path: str, output_dir: str) -> None:
        """Save a file from a HTTP address to the output directory.

        Parameters
        ----------
        file_path: str
            URL of the file to download. The last component of the URL's
            path is used as the name of the saved file.
        output_dir: str
            Path to filesystem directory into which the file should be saved.
        """
        # Local import: this script does not import os at module level.
        import os

        ctx = ssl.create_default_context()
        # Name the local copy after the final component of the URL path and
        # place it inside output_dir rather than the current directory.
        file_name = urllib.parse.urlparse(file_path).path.split("/")[-1]
        save_path = os.path.join(output_dir, file_name)
        # Only the URL is passed positionally: urlopen's second positional
        # parameter is the request body (data), not a destination.
        with urllib.request.urlopen(file_path, context=ctx) as response:
            with open(save_path, "wb") as fp:
                # Read in 1 MiB chunks so data doesn't all have to be in memory.
                while data := response.read(1024 * 1024):
                    fp.write(data)

from CSET._workflow_utils.fetch_data import HTTPFileRetriever, fetch_data

# Call fetch_data with the HTTP retriever class (the class itself, not an instance).
fetch_data(HTTPFileRetriever)
28 changes: 27 additions & 1 deletion src/CSET/_workflow_utils/fetch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import logging
import os
import shutil
import ssl
import sys
import urllib.parse
import urllib.request
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from typing import Literal
Expand All @@ -20,7 +23,7 @@


class FileRetrieverABC(abc.ABC):
"""Abstract class for retrieving files from a data source.
"""Abstract base class for retrieving files from a data source.
The `get_file` method must be defined. Optionally the __enter__ and __exit__
methods may be overridden to add setup or cleanup code.
Expand Down Expand Up @@ -86,6 +89,29 @@ def get_file(self, file_path: str, output_dir: str) -> None:
logging.warning("Failed to copy %s, error: %s", file, err)


class HTTPFileRetriever(FileRetrieverABC):
    """Retrieve files via HTTP."""

    def get_file(self, file_path: str, output_dir: str) -> None:
        """Save a file from a HTTP address to the output directory.

        Parameters
        ----------
        file_path: str
            URL of the file to download. The last component of the URL's
            path is used as the name of the saved file.
        output_dir: str
            Path to filesystem directory into which the file should be saved.
        """
        ctx = ssl.create_default_context()
        # Name the local copy after the final component of the URL path and
        # place it inside output_dir rather than the current directory.
        file_name = urllib.parse.urlparse(file_path).path.split("/")[-1]
        save_path = os.path.join(output_dir, file_name)
        # Only the URL is passed positionally: urlopen's second positional
        # parameter is the request body (data), not a destination.
        with urllib.request.urlopen(file_path, context=ctx) as response:
            with open(save_path, "wb") as fp:
                # Read in 1 MiB chunks so data doesn't all have to be in memory.
                while data := response.read(1024 * 1024):
                    fp.write(data)


def _get_needed_environment_variables() -> dict:
"""Load the needed variables from the environment."""
# Python 3.10 and older don't fully support ISO 8601 datetime formats.
Expand Down
14 changes: 14 additions & 0 deletions tests/workflow_utils/test_fetch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Tests for fetch_data workflow utility."""

import datetime
import hashlib
from pathlib import Path

import pytest
Expand Down Expand Up @@ -186,3 +187,16 @@ def test_FilesystemFileRetriever_copy_error(caplog):
log_record = caplog.records[0]
assert log_record.levelname == "WARNING"
assert log_record.message.startswith("Failed to copy")


def test_HTTPFileRetriever(tmp_path):
    """Download a known file over HTTP and check it arrives intact."""
    source_url = "https://github.com/MetOffice/CSET/raw/48dc1d29846604aacb8d370b82bca31405931c87/tests/test_data/exeter_em01.nc"
    # SHA-256 of the reference file; a mismatch indicates a corrupt download.
    known_sha256 = "67899970eeca75b9378f0275ce86db3d1d613f2bc7a178540912848dc8a69ca7"
    with fetch_data.HTTPFileRetriever() as retriever:
        retriever.get_file(source_url, str(tmp_path))
    downloaded = tmp_path / "exeter_em01.nc"
    assert downloaded.is_file()
    digest = hashlib.sha256(downloaded.read_bytes()).hexdigest()
    assert digest == known_sha256

0 comments on commit 5f6e3f5

Please sign in to comment.