Skip to content

Commit

Permalink
Merge pull request #97 from Deltares/richer-weights
Browse files Browse the repository at this point in the history
Richer weights
  • Loading branch information
Huite committed Jun 23, 2023
2 parents 4243aa9 + c792381 commit 5d7bb9c
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 178 deletions.
13 changes: 9 additions & 4 deletions tests/fixtures/fixture_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,20 @@ def grid_data_d():
@pytest.fixture(scope="function")
def grid_data_e():
return xr.DataArray(
data=np.arange(12).reshape((4, 3)),
dims=["y", "x"],
data=np.zeros((4, 3, 2)),
dims=["y", "x", "nbounds"],
coords={
"y": np.array([175, 125, 75, 25]),
"x": np.array([30, 67.5, 105]),
"dx": 25,
"dy": -50.0,
"xbounds_left": ("x", np.array([17.5, 42.5, 92.5])),
"xbounds_right": ("x", np.array([42.5, 92.5, 117.5])),
"xbounds": (
["x", "nbounds"],
np.column_stack(
(np.array([17.5, 42.5, 92.5]), np.array([42.5, 92.5, 117.5]))
),
),
"nbounds": np.arange(2),
},
)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_regrid/test_regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_overlap_regridder(disk, quads_1):
assert broadcasted.shape == (5, 100)


def test_lineair_interpolator_structured(
def test_linear_interpolator_structured(
grid_data_a, grid_data_a_layered, grid_data_b, expected_results_linear
):
regridder = BarycentricInterpolator(source=grid_data_a, target=grid_data_b)
Expand Down Expand Up @@ -172,4 +172,4 @@ def test_regridder_from_dataset(cls, disk, quads_1):
dataset = regridder.to_dataset()
new_regridder = cls.from_dataset(dataset)
new_result = new_regridder.regrid(disk)
assert new_result.equals(result)
assert np.array_equal(new_result.values, result.values, equal_nan=True)
1 change: 1 addition & 0 deletions tests/test_ugrid1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def test_ugrid1d_rename():
"node_x": "__renamed_node_x",
"node_y": "__renamed_node_y",
}
assert renamed.name == "__renamed"


def test_ugrid1d_rename_with_dataset():
Expand Down
1 change: 1 addition & 0 deletions tests/test_ugrid2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,6 +984,7 @@ def test_ugrid2d_rename():
"node_x": "__renamed_node_x",
"node_y": "__renamed_node_y",
}
assert renamed.name == "__renamed"


def test_ugrid2d_rename_with_dataset():
Expand Down
61 changes: 37 additions & 24 deletions xugrid/regrid/regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,13 @@ def _regrid(source: FloatArray, A: WeightMatrixCSR, size: int):
return numba.njit(_regrid, parallel=True, cache=True)


def setup_grid(obj):
def setup_grid(obj, **kwargs):
if isinstance(obj, (xu.Ugrid2d, xu.UgridDataArray, xu.UgridDataset)):
return UnstructuredGrid2d(obj)
elif isinstance(obj, (xr.DataArray, xr.Dataset)):
return StructuredGrid2d(obj, name_y="y", name_x="x")
return StructuredGrid2d(
obj, name_y=kwargs.get("name_y", "y"), name_x=kwargs.get("name_x", "x")
)
else:
raise TypeError()

Expand Down Expand Up @@ -120,16 +122,14 @@ def _setup_regrid(self, func) -> Callable:
return

def _regrid_array(self, source):
if hasattr(self, "_source"):
source_grid = self._source
else:
source_grid = source
source_grid = self._source
first_dims_shape = source.shape[: -source_grid.ndim]

# The regridding can be mapped over additional dimensions (e.g. for every time slice).
# This is the `extra_index` iteration in _regrid().
# But it should work consistently even if no additional present: in that case we create
# a 1-sized additional dimension in front, so the `extra_index` iteration always applies.
# The regridding can be mapped over additional dimensions, e.g. for
# every time slice. This is the `extra_index` iteration in _regrid().
# But it should work consistently even if no additional present: in
# that case we create a 1-sized additional dimension in front, so the
# `extra_index` iteration always applies.
if source.ndim == source_grid.ndim:
source = source[np.newaxis]

Expand Down Expand Up @@ -222,11 +222,12 @@ def to_dataset(self) -> xr.Dataset:
"""
Store the computed weights and target in a dataset for re-use.
"""
ds = xr.Dataset(
weights_ds = xr.Dataset(
{f"__regrid_{k}": v for k, v in zip(self._weights._fields, self._weights)}
)
ugrid_ds = self._target.ugrid_topology.to_dataset()
return xr.merge((ds, ugrid_ds))
source_ds = self._source.to_dataset("__source")
target_ds = self._target.to_dataset("__target")
return xr.merge((weights_ds, source_ds, target_ds))

@staticmethod
def _csr_from_dataset(dataset: xr.Dataset) -> WeightMatrixCSR:
Expand Down Expand Up @@ -256,10 +257,19 @@ def _weights_from_dataset(
"""

@classmethod
def from_weights(cls, weights, target: "xugrid.Ugrid2d"):
def from_weights(
cls, weights, target: Union["xugrid.Ugrid2d", xr.DataArray, xr.Dataset]
):
instance = cls.__new__(cls)
instance._weights = weights
instance._target = UnstructuredGrid2d(target)
instance._weights = cls._weights_from_dataset(weights)
instance._target = setup_grid(target)
unstructured = weights["__source_type"].attrs["type"] == "UnstructuredGrid2d"
if unstructured:
instance._source = setup_grid(xu.Ugrid2d.from_dataset(weights, "__source"))
else:
instance._source = setup_grid(
weights, name_x="__source_x", name_y="__source_y"
)
return instance

@classmethod
Expand All @@ -268,9 +278,12 @@ def from_dataset(cls, dataset: xr.Dataset):
Reconstruct the regridder from a dataset with source, target indices
and weights.
"""
target = xu.Ugrid2d.from_dataset(dataset)
weights = cls._weights_from_dataset(dataset)
return cls.from_weights(weights, target)
unstructured = dataset["__target_type"].attrs["type"] == "UnstructuredGrid2d"
if unstructured:
target = xu.Ugrid2d.from_dataset(dataset, "__target")

# weights = cls._weights_from_dataset(dataset)
return cls.from_weights(dataset, target)


class CentroidLocatorRegridder(BaseRegridder):
Expand Down Expand Up @@ -307,7 +320,7 @@ def _regrid(source: FloatArray, A: WeightMatrixCOO, size: int):

@property
def weights(self):
return self._weights
return self.to_dataset()

@weights.setter
def weights(self, weights: WeightMatrixCOO, target: "xugrid.Ugrid2d"):
Expand All @@ -334,7 +347,7 @@ def _compute_weights(self, source, target, relative: bool) -> None:

@property
def weights(self):
return self._weights
return self.to_dataset()

@weights.setter
def weights(self, weights: WeightMatrixCSR):
Expand Down Expand Up @@ -398,8 +411,8 @@ def _compute_weights(self, source, target) -> None:
@classmethod
def from_weights(
cls,
weights: WeightMatrixCSR,
target: "xugrid.Ugrid2d",
weights: xr.Dataset,
target: Union["xugrid.Ugrid2d", xr.DataArray, xr.Dataset],
method: Union[str, Callable] = "mean",
):
instance = super().from_weights(weights, target)
Expand Down Expand Up @@ -497,7 +510,7 @@ def _compute_weights(self, source, target):

@property
def weights(self):
return self._weights
return self.to_dataset()

@weights.setter
def weights(self, weights: WeightMatrixCSR):
Expand Down
Loading

0 comments on commit 5d7bb9c

Please sign in to comment.