Skip to content

Commit

Permalink
Add tests for forcing logic
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanBryan51 committed Nov 23, 2023
1 parent 646aaae commit 4b8c821
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 21 deletions.
41 changes: 20 additions & 21 deletions payu/models/cable.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@
from payu.models.model import Model


def _get_forcing_path(variable, year, input_dir, offset=None, repeat=None):
"""Return the met forcing file path for a given variable and year."""
if offset:
year += offset[1] - offset[0]
if repeat:
year = repeat[0] + ((year - repeat[0]) % (repeat[1] - repeat[0] + 1))
pattern = os.path.join(input_dir, f"*{variable}*{year}*.nc")
for path in glob.glob(pattern):
return path
msg = f"Unable to infer met forcing path for variable {variable} for year {year}."
raise FileNotFoundError(msg)


class Cable(Model):

def __init__(self, expt, name, config):
Expand Down Expand Up @@ -51,26 +64,6 @@ def __init__(self, expt, name, config):
"Wind",
]

def _get_forcing_path(self, variable, year):
"""Return the met forcing file path for a given variable and year."""
pattern = os.path.join(self.work_input_path, f"*{variable}*{year}*.nc")
for path in glob.glob(pattern):
return path
msg = f"Unable to infer met forcing path for variable {variable} for year {year}."
raise FileNotFoundError(msg)

def _update_forcing(self, year, offset=None, repeat=None):
"""Update the CABLE namelist file to use the correct met forcing."""
if offset:
year += offset[1] - offset[0]
if repeat:
year = repeat[0] + ((year - repeat[0]) % (repeat[1] - repeat[0] + 1))
for var in self.met_forcing_vars:
path = self._get_forcing_path(var, year)
self.cable_nml["cable"]["gswpfile"][var] = (
os.path.relpath(path, start=self.work_path)
)

def set_model_pathnames(self):
super(Cable, self).set_model_pathnames()

Expand Down Expand Up @@ -109,7 +102,13 @@ def setup(self):
with open(forcing_year_config_path, 'r') as file:
conf = yaml.safe_load(file)
forcing_year_config = conf if conf else {}
self._update_forcing(year, **forcing_year_config)
for var in self.met_forcing_vars:
path = _get_forcing_path(
var, year, self.work_input_path, **forcing_year_config
)
self.cable_nml["cable"]["gswpfile"][var] = (
os.path.relpath(path, start=self.work_path)
)

# Write modified namelist file to work dir
self.cable_nml.write(
Expand Down
52 changes: 52 additions & 0 deletions test/models/test_cable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os
import shutil
import tempfile

import pytest

import payu.models.cable as cable

from test.common import tmpdir, make_random_file


class TestGetForcingPath:
"""Tests for `payu.model.cable._get_forcing_path`."""

@pytest.fixture()
def input_dir(self):
"""Create a temporary input directory and return its path."""
_input_dir = tempfile.mkdtemp(prefix="payu_test_get_forcing_path")
yield _input_dir
shutil.rmtree(_input_dir)

@pytest.fixture(autouse=True)
def _make_forcing_inputs(self, input_dir):
"""Create forcing inputs from 1900 to 1903."""
for year in [1900, 1901, 1903]:
make_random_file(os.path.join(input_dir, f"crujra_LWdown_{year}.nc"))

def test_get_forcing_path(self, input_dir):
"""Success case: test correct path can be inferred."""
assert cable._get_forcing_path("LWdown", 1900, input_dir) == os.path.join(
input_dir, "crujra_LWdown_1900.nc"
)

def test_year_offset(self, input_dir):
"""Success case: test correct path can be inferred with offset."""
assert cable._get_forcing_path(
"LWdown", 2000, input_dir, offset=[2000, 1900]
) == os.path.join(input_dir, "crujra_LWdown_1900.nc")

def test_year_repeat(self, input_dir):
"""Success case: test correct path can be inferred with repeat."""
assert cable._get_forcing_path(
"LWdown", 1904, input_dir, repeat=[1900, 1903]
) == os.path.join(input_dir, "crujra_LWdown_1900.nc")

def test_file_not_found_exception(self, input_dir):
"""Failure case: test exception is raised if path cannot be inferred."""
with pytest.raises(
FileNotFoundError,
match="Unable to infer met forcing path for variable LWdown for year 1904.",
):
_ = cable._get_forcing_path("LWdown", 1904, input_dir)

0 comments on commit 4b8c821

Please sign in to comment.