Skip to content

Commit

Permalink
199 nlmodgwfic slow for uniform starting head (#200)
Browse files Browse the repository at this point in the history
* fix for #199

* fix

* return da in some situations

---------

Co-authored-by: Ruben Caljé <[email protected]>
  • Loading branch information
OnnoEbbens and rubencalje authored Jul 7, 2023
1 parent 4fba381 commit 62eadc9
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 11 deletions.
12 changes: 6 additions & 6 deletions nlmod/gwf/gwf.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ def npf(
gwf,
pname=pname,
icelltype=icelltype,
k=k.data,
k33=k33.data,
k=k,
k33=k33,
save_flows=save_flows,
**kwargs,
)
Expand Down Expand Up @@ -381,7 +381,7 @@ def ghb(
bhead = f"{da_name}_peil"
cond = f"{da_name}_cond"

mask_arr = _get_value_from_ds_datavar(ds, "cond", cond)
mask_arr = _get_value_from_ds_datavar(ds, "cond", cond, return_da=True)
mask = mask_arr != 0

ghb_rec = grid.da_to_reclist(
Expand Down Expand Up @@ -466,7 +466,7 @@ def drn(
elev = f"{da_name}_peil"
cond = f"{da_name}_cond"

mask_arr = _get_value_from_ds_datavar(ds, "cond", cond)
mask_arr = _get_value_from_ds_datavar(ds, "cond", cond, return_da=True)
mask = mask_arr != 0

first_active_layer = layer is None
Expand Down Expand Up @@ -636,7 +636,7 @@ def chd(
)
mask = kwargs.pop("chd")

maskarr = _get_value_from_ds_datavar(ds, "mask", mask)
maskarr = _get_value_from_ds_datavar(ds, "mask", mask, return_da=True)
mask = maskarr != 0

# get the stress_period_data
Expand Down Expand Up @@ -693,7 +693,7 @@ def surface_drain_from_ds(ds, gwf, resistance, elev="ahn", pname="drn", **kwargs

ds.attrs["surface_drn_resistance"] = resistance

maskarr = _get_value_from_ds_datavar(ds, "elev", elev)
maskarr = _get_value_from_ds_datavar(ds, "elev", elev, return_da=True)
mask = maskarr.notnull()

drn_rec = grid.da_to_reclist(
Expand Down
10 changes: 9 additions & 1 deletion nlmod/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def _get_value_from_ds_attr(ds, varname, attr=None, value=None, warn=True):
return value


def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True):
def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True, return_da=False):
"""Internal function to get value from dataset data variables.
Parameters
Expand All @@ -551,6 +551,9 @@ def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True):
the same as varname. If not passed as string, it is treated as data
warn : bool, optional
log warning if value not found
return_da : bool, optional
if True a dataarray can be returned, if False a dataarray is always
converted to a numpy array before being returned. The default is False.
Returns
-------
Expand Down Expand Up @@ -597,4 +600,9 @@ def _get_value_from_ds_datavar(ds, varname, datavar=None, warn=True):
f"to function or check whether 'ds.{datavar}' was set correctly."
)
logger.warning(msg)

if not return_da:
if isinstance(value, xr.DataArray):
value = value.values

return value
10 changes: 6 additions & 4 deletions tests/test_003_mfpackages.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,22 @@ def get_value_from_ds_datavar():
ds["test_var"] = ("layer", "y", "x"), np.arange(np.product(shape)).reshape(shape)

# get value from ds
v0 = nlmod.util._get_value_from_ds_datavar(ds, "test_var", "test_var")
v0 = nlmod.util._get_value_from_ds_datavar(
ds, "test_var", "test_var", return_da=True
)
xr.testing.assert_equal(ds["test_var"], v0)

# get value from ds, variable and stored name are different
v1 = nlmod.util._get_value_from_ds_datavar(ds, "test", "test_var")
xr.testing.assert_equal(ds["test_var"], v1)
xr.testing.assert_equal(ds["test_var"].values, v1)

# do not get value from ds, value is Data Array, should log info msg
v2 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0)
v2 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0, return_da=True)
xr.testing.assert_equal(ds["test_var"], v2)

# do not get value from ds, value is Data Array, no msg
v0.name = "test2"
v3 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0)
v3 = nlmod.util._get_value_from_ds_datavar(ds, "test", v0, return_da=True)
assert (v0 == v3).all()

# return None, value is str but not in dataset, should log warning
Expand Down

0 comments on commit 62eadc9

Please sign in to comment.