Skip to content

Commit

Permalink
Merge pull request #240 from jmccreight/feat_control_improvements
Browse files Browse the repository at this point in the history
Feat control improvements
  • Loading branch information
jmccreight authored Oct 28, 2023
2 parents 3038a5b + a938fa4 commit 5456603
Show file tree
Hide file tree
Showing 36 changed files with 899 additions and 340 deletions.
49 changes: 48 additions & 1 deletion autotest/test_control.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import copy, deepcopy
from datetime import datetime

import numpy as np
Expand Down Expand Up @@ -152,5 +153,51 @@ def test_control_advance(control_simple, params_simple):


def test_init_load(domain):
control = Control.load(domain["control_file"])
with pytest.warns(RuntimeWarning):
_ = Control.load_prms(domain["control_file"])
return None


def test_deepcopy(domain):
ctl = Control.load_prms(domain["control_file"], warn_unused_options=False)
ctl_sh = copy(ctl)
ctl_dp = deepcopy(ctl)

opt_restart_orig = ctl.options["verbosity"]
opt_restart_new = "something_else"
ctl.options["verbosity"] = opt_restart_new
assert ctl_sh.options["verbosity"] == opt_restart_new
assert ctl_dp.options["verbosity"] == opt_restart_orig

return None


def test_setitem_setattr(domain):
ctl = Control.load_prms(domain["control_file"], warn_unused_options=False)

# __setitem__ on OptsDict
ctl.options["verbosity"] = 12
with pytest.raises(NameError):
ctl.options["foobar"] = 12

# __setattr__ on Control
ctl.options = {"verbosity": 45}
with pytest.raises(NameError):
ctl.options = {"foobar": 12}

# __setitem__ on Control
ctl["options"] = {"verbosity": 45}
with pytest.raises(NameError):
ctl["options"] = {"foobar": 12}

# The value for options must be a dictionary
with pytest.raises(ValueError):
ctl.options = None


def test_yaml_roundtrip(domain, tmp_path):
ctl = Control.load_prms(domain["control_file"], warn_unused_options=False)
yml_file = tmp_path / "control.yaml"
ctl.to_yaml(yml_file)
ctl_2 = Control.from_yaml(yml_file)
np.testing.assert_equal(ctl.to_dict(), ctl_2.to_dict())
26 changes: 17 additions & 9 deletions autotest/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,21 @@
}


invoke_style = ("prms", "model_dict", "model_dict_from_yml")
invoke_style = ("prms", "model_dict", "model_dict_from_yaml")


@pytest.fixture(scope="function")
def control(domain):
control = Control.load(domain["control_file"])
control.options["verbose"] = 10
control = Control.load_prms(
domain["control_file"], warn_unused_options=False
)
control.options["verbosity"] = 10
control.options["budget_type"] = None
if fortran_avail:
control.options["calc_method"] = "fortran"
else:
control.options["calc_method"] = "numba"
control.options["load_n_time_batches"] = 1
del control.options["netcdf_output_var_names"]
return control


Expand Down Expand Up @@ -110,9 +112,9 @@ def model_args(domain, control, discretization, request):
"parameters": None,
}

elif invoke_style == "model_dict_from_yml":
yml_file = domain["dir"] / "nhm_model.yml"
model_dict = Model.model_dict_from_yml(yml_file)
elif invoke_style == "model_dict_from_yaml":
yaml_file = domain["dir"] / "nhm_model.yml"
model_dict = Model.model_dict_from_yaml(yaml_file)

args = {
"process_list_or_model_dict": model_dict,
Expand Down Expand Up @@ -147,12 +149,18 @@ def test_model(domain, model_args, tmp_path):
control = model_args["control"]

control.options["input_dir"] = input_dir
model_out_dir = tmp_path / "output"
control.options["netcdf_output_dir"] = model_out_dir

if control.options["calc_method"] == "fortran":
with pytest.warns(UserWarning):
model = Model(**model_args)
model = Model(**model_args, write_control=model_out_dir)
else:
model = Model(**model_args)
model = Model(**model_args, write_control=model_out_dir)

# check that control yaml file was written
control_yaml_file = sorted(model_out_dir.glob("*model_control.yaml"))
assert len(control_yaml_file) == 1

# Test passing of control calc_method option
if fortran_avail:
Expand Down
33 changes: 20 additions & 13 deletions autotest/test_netcdf_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ def params(domain):

@pytest.fixture(scope="function")
def control(domain):
control = Control.load(domain["control_file"])
control = Control.load_prms(
domain["control_file"], warn_unused_options=False
)
control.edit_n_time_steps(n_time_steps)
control.options["budget_type"] = "error"
del control.options["netcdf_output_var_names"]
return control


Expand Down Expand Up @@ -96,19 +99,20 @@ def test_process_budgets(domain, control, params, tmp_path, budget_sum_param):
]
output_vars = None

model.initialize_netcdf(
tmp_dir,
budget_args=budget_args,
output_vars=output_vars,
)

with pytest.warns(UserWarning):
model.initialize_netcdf(
tmp_dir,
budget_args=budget_args,
output_vars=output_vars,
)

with pytest.raises(RuntimeError):
model.initialize_netcdf(
tmp_dir,
budget_args=budget_args,
output_vars=output_vars,
)

for tt in range(n_time_steps):
model.advance()
model.calculate()
Expand Down Expand Up @@ -197,19 +201,22 @@ def test_separate_together_var_list(
control.options["input_dir"] = input_dir
control.options["netcdf_output_var_names"] = output_vars
control.options["netcdf_output_separate_files"] = separate
del control.options["netcdf_output_dir"]

# Could limit this to just the variables in model_procs
for ff in domain_output_dir.resolve().glob("*.nc"):
shutil.copy(ff, input_dir / ff.name)
for ff in domain_output_dir.parent.resolve().glob("*.nc"):
shutil.copy(ff, input_dir / ff.name)

with pytest.raises(RuntimeError):
model = Model(
model_procs,
control=control,
parameters=params,
)
model = Model(
model_procs,
control=control,
parameters=params,
)
with pytest.raises(ValueError):
# passing no output_dir arg and none in opts throws an error
model.initialize_netcdf()

control.options["netcdf_output_dir"] = test_output_dir
model = Model(
Expand Down
17 changes: 13 additions & 4 deletions autotest/test_nhm_self_drive.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pathlib as pl

import pytest
import xarray as xr

import pywatershed as pws
Expand Down Expand Up @@ -33,18 +34,23 @@ def test_drive_indiv_process(domain, tmp_path):
nhm_output_dir = pl.Path(tmp_path) / "nhm_output"

params = pws.parameters.PrmsParameters.load(domain["param_file"])
control = pws.Control.load(domain["control_file"])
control = pws.Control.load_prms(
domain["control_file"], warn_unused_options=False
)
control.edit_n_time_steps(n_time_steps)
control.options["budget_type"] = "warn"
control.options["calc_method"] = "numba"
control.options["input_dir"] = domain["prms_run_dir"]
del control.options["netcdf_output_var_names"]

nhm = pws.Model(
nhm_processes,
control=control,
parameters=params,
)
nhm.initialize_netcdf(output_dir=nhm_output_dir)
with pytest.warns(UserWarning):
nhm.initialize_netcdf(output_dir=nhm_output_dir)

nhm.run(finalize=True)
del nhm, params, control

Expand All @@ -60,18 +66,21 @@ def test_drive_indiv_process(domain, tmp_path):
proc_model_output_dir.mkdir()

params = pws.parameters.PrmsParameters.load(domain["param_file"])
control = pws.Control.load(domain["control_file"])
control = pws.Control.load_prms(
domain["control_file"], warn_unused_options=False
)
control.edit_n_time_steps(n_time_steps)
control.options["budget_type"] = "warn"
control.options["calc_method"] = "numba"
control.options["input_dir"] = nhm_output_dir
control.options["netcdf_output_dir"] = proc_model_output_dir

proc_model = pws.Model(
[proc],
control=control,
parameters=params,
)
proc_model.initialize_netcdf(output_dir=proc_model_output_dir)
proc_model.initialize_netcdf()
proc_model.run(finalize=True)
del proc_model, params, control

Expand Down
2 changes: 1 addition & 1 deletion autotest/test_prms_atmosphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

@pytest.fixture(scope="function")
def control(domain):
return Control.load(domain["control_file"])
return Control.load_prms(domain["control_file"], warn_unused_options=False)


@pytest.fixture(scope="function")
Expand Down
2 changes: 1 addition & 1 deletion autotest/test_prms_canopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

@pytest.fixture(scope="function")
def control(domain):
return Control.load(domain["control_file"])
return Control.load_prms(domain["control_file"], warn_unused_options=False)


@pytest.fixture(scope="function")
Expand Down
2 changes: 1 addition & 1 deletion autotest/test_prms_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

@pytest.fixture(scope="function")
def control(domain):
return Control.load(domain["control_file"])
return Control.load_prms(domain["control_file"], warn_unused_options=False)


@pytest.fixture(scope="function")
Expand Down
2 changes: 1 addition & 1 deletion autotest/test_prms_et.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def params(domain):

@pytest.fixture(scope="function")
def control(domain):
return Control.load(domain["control_file"])
return Control.load_prms(domain["control_file"], warn_unused_options=False)


class TestPRMSEt:
Expand Down
2 changes: 1 addition & 1 deletion autotest/test_prms_et_can_runoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def params(domain):

@pytest.fixture(scope="function")
def control(domain):
return Control.load(domain["control_file"])
return Control.load_prms(domain["control_file"], warn_unused_options=False)


def test_et_can_runoff(domain, control, params, tmp_path):
Expand Down
2 changes: 1 addition & 1 deletion autotest/test_prms_et_canopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def params(domain):

@pytest.fixture(scope="function")
def control(domain):
return Control.load(domain["control_file"])
return Control.load_prms(domain["control_file"], warn_unused_options=False)


def test_et(domain, control, params, tmp_path):
Expand Down
2 changes: 1 addition & 1 deletion autotest/test_prms_groundwater.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

@pytest.fixture(scope="function")
def control(domain):
return Control.load(domain["control_file"])
return Control.load_prms(domain["control_file"], warn_unused_options=False)


@pytest.fixture(scope="function")
Expand Down
2 changes: 1 addition & 1 deletion autotest/test_prms_runoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

@pytest.fixture(scope="function")
def control(domain):
return Control.load(domain["control_file"])
return Control.load_prms(domain["control_file"], warn_unused_options=False)


@pytest.fixture(scope="function")
Expand Down
2 changes: 1 addition & 1 deletion autotest/test_prms_snow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

@pytest.fixture(scope="function")
def control(domain):
return Control.load(domain["control_file"])
return Control.load_prms(domain["control_file"], warn_unused_options=False)


@pytest.fixture(scope="function")
Expand Down
2 changes: 1 addition & 1 deletion autotest/test_prms_soilzone.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

@pytest.fixture(scope="function")
def control(domain):
return Control.load(domain["control_file"])
return Control.load_prms(domain["control_file"], warn_unused_options=False)


@pytest.fixture(scope="function")
Expand Down
2 changes: 1 addition & 1 deletion autotest/test_prms_solar_geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

@pytest.fixture(scope="function")
def control(domain):
return Control.load(domain["control_file"])
return Control.load_prms(domain["control_file"], warn_unused_options=False)


@pytest.fixture(scope="function")
Expand Down
21 changes: 17 additions & 4 deletions autotest/test_prms_to_mf6.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ def test_mmr_to_mf6(domain, tmp_path, bc_binary_files, bc_flows_combine):
.to("meter ** 3 / s")
.magnitude
)
comp = abs(result - ans_tt)
assert ((comp < 1e-5) | ((comp / ans_tt) < 1e-5)).all()

else:
sim = flopy.mf6.MFSimulation.load(
Expand All @@ -98,7 +96,22 @@ def test_mmr_to_mf6(domain, tmp_path, bc_binary_files, bc_flows_combine):
.to("meter ** 3 / s")
.magnitude
)
comp = abs(result - ans_tt)
assert ((comp < 1e-5) | ((comp / ans_tt) < 1e-5)).all()

# <<
# Compare
abs_diff = abs(result - ans_tt)
with np.errstate(divide="ignore", invalid="ignore"):
rel_diff = abs_diff / ans_tt

abs_tol = 1.0e-5
rel_tol = 1.0e-5

abs_close = abs_diff < abs_tol
rel_close = rel_diff < rel_tol
rel_close = np.where(np.isnan(rel_close), False, rel_close)

close = abs_close | rel_close

assert close.all()

return
Loading

0 comments on commit 5456603

Please sign in to comment.