From 8dae6566cf5fa28eb87382aa29fe6d05f3ad5fd5 Mon Sep 17 00:00:00 2001 From: Sean Bryan Date: Thu, 23 Nov 2023 12:12:38 +1100 Subject: [PATCH] Add tests for forcing logic --- payu/models/cable.py | 41 +++++++++++++++--------------- test/models/test_cable.py | 52 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 21 deletions(-) create mode 100644 test/models/test_cable.py diff --git a/payu/models/cable.py b/payu/models/cable.py index b6cd9b23..1a554ba1 100644 --- a/payu/models/cable.py +++ b/payu/models/cable.py @@ -21,6 +21,19 @@ from payu.models.model import Model +def _get_forcing_path(variable, year, input_dir, offset=None, repeat=None): + """Return the met forcing file path for a given variable and year.""" + if offset: + year += offset[1] - offset[0] + if repeat: + year = repeat[0] + ((year - repeat[0]) % (repeat[1] - repeat[0] + 1)) + pattern = os.path.join(input_dir, f"*{variable}*{year}*.nc") + for path in glob.glob(pattern): + return path + msg = f"Unable to infer met forcing path for variable {variable} for year {year}." + raise FileNotFoundError(msg) + + class Cable(Model): def __init__(self, expt, name, config): @@ -51,26 +64,6 @@ def __init__(self, expt, name, config): "Wind", ] - def _get_forcing_path(self, variable, year): - """Return the met forcing file path for a given variable and year.""" - pattern = os.path.join(self.work_input_path, f"*{variable}*{year}*.nc") - for path in glob.glob(pattern): - return path - msg = f"Unable to infer met forcing path for variable {variable} for year {year}." - raise FileNotFoundError(msg) - - def _update_forcing(self, year, offset=None, repeat=None): - """Update the CABLE namelist file to use the correct met forcing.""" - if offset: - year += offset[1] - offset[0] - if repeat: - year = repeat[0] + ((year - repeat[0]) % (repeat[1] - repeat[0] + 1)) - for var in self.met_forcing_vars: - path = self._get_forcing_path(var, year) - self.cable_nml["cable"]["gswpfile"][var] = ( - os.path.relpath(path, start=self.work_path) - ) - def set_model_pathnames(self): super(Cable, self).set_model_pathnames() @@ -109,7 +102,13 @@ def setup(self): with open(forcing_year_config_path, 'r') as file: conf = yaml.safe_load(file) forcing_year_config = conf if conf else {} - self._update_forcing(year, **forcing_year_config) + for var in self.met_forcing_vars: + path = _get_forcing_path( + var, year, self.work_input_path, **forcing_year_config + ) + self.cable_nml["cable"]["gswpfile"][var] = ( + os.path.relpath(path, start=self.work_path) + ) # Write modified namelist file to work dir self.cable_nml.write( diff --git a/test/models/test_cable.py b/test/models/test_cable.py new file mode 100644 index 00000000..7e62b10d --- /dev/null +++ b/test/models/test_cable.py @@ -0,0 +1,52 @@ +import os +import shutil +import tempfile + +import pytest + +import payu.models.cable as cable + +from test.common import make_random_file + + +class TestGetForcingPath: + """Tests for `payu.models.cable._get_forcing_path()`.""" + + @pytest.fixture() + def input_dir(self): + """Create a temporary input directory and return its path.""" + _input_dir = tempfile.mkdtemp(prefix="payu_test_get_forcing_path") + yield _input_dir + shutil.rmtree(_input_dir) + + @pytest.fixture(autouse=True) + def _make_forcing_inputs(self, input_dir): + """Create forcing inputs from 1900 to 1903.""" + for year in [1900, 1901, 1903]: + make_random_file(os.path.join(input_dir, f"crujra_LWdown_{year}.nc")) + + def test_get_forcing_path(self, input_dir): + """Success case: test correct path can be inferred.""" + assert cable._get_forcing_path("LWdown", 1900, input_dir) == os.path.join( + input_dir, "crujra_LWdown_1900.nc" + ) + + def test_year_offset(self, input_dir): + """Success case: test correct path can be inferred with offset.""" + assert cable._get_forcing_path( + "LWdown", 2000, input_dir, offset=[2000, 1900] + ) == os.path.join(input_dir, "crujra_LWdown_1900.nc") + + def test_year_repeat(self, input_dir): + """Success case: test correct path can be inferred with repeat.""" + assert cable._get_forcing_path( + "LWdown", 1904, input_dir, repeat=[1900, 1903] + ) == os.path.join(input_dir, "crujra_LWdown_1900.nc") + + def test_file_not_found_exception(self, input_dir): + """Failure case: test exception is raised if path cannot be inferred.""" + with pytest.raises( + FileNotFoundError, + match="Unable to infer met forcing path for variable LWdown for year 1904.", + ): + _ = cable._get_forcing_path("LWdown", 1904, input_dir)