Skip to content

Commit

Permalink
Add .ugrid.interpolate_na()
Browse files Browse the repository at this point in the history
Added some small utilities
  • Loading branch information
Huite committed Aug 23, 2024
1 parent 77393b9 commit d30d2a7
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ docs/api
docs/examples
docs/examples-dev
docs/sample_data
docs/sg_execution_times.rst

# PyBuilder
target/
Expand Down
7 changes: 7 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ UgridDataArray or UgridDataset.
UgridDataArrayAccessor.binary_erosion
UgridDataArrayAccessor.connected_components
UgridDataArrayAccessor.reverse_cuthill_mckee
UgridDataArrayAccessor.interpolate_na
UgridDataArrayAccessor.laplace_interpolate
UgridDataArrayAccessor.to_dataset
UgridDataArrayAccessor.to_netcdf
Expand Down Expand Up @@ -192,13 +193,15 @@ UGRID1D Topology
Ugrid1d.topology_dimension
Ugrid1d.dimensions
Ugrid1d.attrs
Ugrid1d.coords

Ugrid1d.n_node
Ugrid1d.node_dimension
Ugrid1d.node_coordinates
Ugrid1d.set_node_coords
Ugrid1d.assign_node_coords
Ugrid1d.assign_edge_coords
Ugrid1d.get_coordinates

Ugrid1d.n_edge
Ugrid1d.edge_dimension
Expand All @@ -211,6 +214,7 @@ UGRID1D Topology

Ugrid1d.node_edge_connectivity
Ugrid1d.node_node_connectivity
Ugrid1d.get_connectivity_matrix

Ugrid1d.copy
Ugrid1d.rename
Expand Down Expand Up @@ -251,6 +255,7 @@ UGRID2D Topology
Ugrid2d.topology_dimension
Ugrid2d.dimensions
Ugrid2d.attrs
Ugrid2d.coords

Ugrid2d.n_node
Ugrid2d.node_dimension
Expand All @@ -259,6 +264,7 @@ UGRID2D Topology
Ugrid2d.assign_node_coords
Ugrid2d.assign_edge_coords
Ugrid2d.assign_face_coords
Ugrid2d.get_coordinates

Ugrid2d.n_edge
Ugrid2d.edge_dimension
Expand Down Expand Up @@ -287,6 +293,7 @@ UGRID2D Topology
Ugrid2d.edge_node_connectivity
Ugrid2d.face_edge_connectivity
Ugrid2d.face_face_connectivity
Ugrid2d.get_connectivity_matrix

Ugrid2d.validate_edge_node_connectivity

Expand Down
8 changes: 7 additions & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@ Fixed
linear weights within the full bounds of the source grid, rather than only
within the centroids of the source grid. Previously, it would give no results
beyond the centroids for structured to structured regridding, and it would
give nearest results (equal to :class:`CentroidLocatorRegridder`) otherwise.
give nearest results (equal to :class:`xugrid.CentroidLocatorRegridder`) otherwise.

Added
~~~~~

- :meth:`UgridDataArrayAccessor.interpolate_na` has been added to fill missing
data. Currently, the only supported method is ``"nearest"``.

Changed
~~~~~~~
Expand Down
17 changes: 17 additions & 0 deletions tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,20 @@ def test_laplace_interpolate():
con, data, use_weights=False, direct_solve=False
)
assert np.allclose(actual, expected)


def test_nearest_interpolate():
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = np.zeros_like(x)
coordinates = np.column_stack((x, y))
data = np.array([0.0, np.nan, np.nan, np.nan, 4.0])
actual = interpolate.nearest_interpolate(coordinates, data, np.inf)
assert np.allclose(actual, np.array([0.0, 0.0, 0.0, 4.0, 4.0]))

actual = interpolate.nearest_interpolate(coordinates, data, 1.1)
assert np.allclose(actual, np.array([0.0, 0.0, np.nan, 4.0, 4.0]), equal_nan=True)

with pytest.raises(ValueError, match="All values are NA."):
interpolate.nearest_interpolate(
coordinates, data=np.full_like(data, np.nan), max_distance=np.inf
)
32 changes: 32 additions & 0 deletions tests/test_ugrid_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def ugrid1d_ds():
)
ds = grid.to_dataset()
ds["a1d"] = xr.DataArray([1.0, 2.0, 3.0], dims=[grid.node_dimension])
ds["b1d"] = xr.DataArray([1.0, 2.0], dims=[grid.edge_dimension])
return xugrid.UgridDataset(ds)


Expand Down Expand Up @@ -1262,6 +1263,11 @@ def test_laplace_interpolate_facets():
with pytest.raises(ValueError, match=msg):
edge_uda.ugrid.laplace_interpolate(direct_solve=True)

for uda in (node_uda, edge_uda, face_uda):
actual = uda.ugrid.interpolate_na()
assert isinstance(actual, xugrid.UgridDataArray)
assert np.allclose(actual, 1.0)


def test_laplace_interpolate_1d():
uda = ugrid1d_ds()["a1d"]
Expand All @@ -1272,6 +1278,32 @@ def test_laplace_interpolate_1d():
assert np.allclose(actual, 1.0)


def test_interpolate_na_1d():
uda = ugrid1d_ds()["a1d"]
with pytest.raises(ValueError, match='"abc" is not a valid interpolator.'):
uda.ugrid.interpolate_na(method="abc")

# Node data
uda = ugrid1d_ds()["a1d"]
uda[:] = 1.0
uda[1] = np.nan
actual = uda.ugrid.interpolate_na()
assert isinstance(actual, xugrid.UgridDataArray)
assert np.allclose(actual, 1.0)

# Edge data
uda = ugrid1d_ds()["b1d"]
uda[:] = 1.0
uda[1] = np.nan
actual = uda.ugrid.interpolate_na()
assert isinstance(actual, xugrid.UgridDataArray)
assert np.allclose(actual, 1.0)

# Check max_distance
actual = uda.ugrid.interpolate_na(max_distance=0.5)
assert np.isnan(actual[1])


def test_ugriddataset_wrap_twice(tmp_path):
"""
in issue https://github.com/Deltares/xugrid/issues/208 wrapping a ds
Expand Down
24 changes: 22 additions & 2 deletions xugrid/core/dataarray_accessor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Sequence, Tuple, Union
from typing import Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import scipy.sparse
Expand Down Expand Up @@ -582,16 +582,36 @@ def reverse_cuthill_mckee(self):
def interpolate_na(
self,
method: str = "nearest",
max_distance: Optional[float] = None,
):
"""
Fill in NaNs by interpolating.
Parameters
----------
method: str, default is "nearest"
Currently the only supported method.
max_distance: nonnegative float, optional.
Use ``None`` for no maximum distance.
Returns
-------
filled: UgridDataArray of floats
"""

if method != "nearest":
raise ValueError(f'"{method}" is not a valid interpolator.')

if max_distance is None:
max_distance = np.inf

grid = self.grid
da = self.obj

filled = nearest_interpolate(
coordinates=grid.get_coordinates(dim=da.dims[0]),
data=da.to_numpy(),
max_distance=max_distance,
)
da_filled = da.copy(data=filled)
return UgridDataArray(da_filled, grid)
Expand All @@ -607,7 +627,7 @@ def laplace_interpolate(
maxiter: int = 500,
):
"""
Fill gaps in ``data`` (``np.nan`` values) using Laplace interpolation.
Fill in NaNs by using Laplace interpolation.
This solves Laplace's equation where where there is no data, with data
values functioning as fixed potential boundary conditions.
Expand Down
16 changes: 12 additions & 4 deletions xugrid/ugrid/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,20 +312,28 @@ def laplace_interpolate(
def nearest_interpolate(
coordinates: FloatArray,
data: FloatArray,
max_distance: float,
) -> FloatArray:
isnull = np.isnan(data)
if isnull.all():
raise ValueError("All values are NA.")

i_source = np.flatnonzero(~isnull)
i_target = np.flatnonzero(isnull)
source_coordinates = coordinates[i_source]
target_coordinates = coordinates[isnull]
target_coordinates = coordinates[i_target]
# Locate the nearest notnull for each null value.
tree = KDTree(source_coordinates)
_, index = tree.query(target_coordinates, workers=-1)
# index contains an intex of the target coordinates to the source
_, index = tree.query(
target_coordinates, distance_upper_bound=max_distance, workers=-1
)
# Remove entries beyond max distance, returned by .query as self.n.
keep = index < len(source_coordinates)
index = index[keep]
i_target = i_target[keep]
# index contains an index of the target coordinates to the source
# coordinates, not the direct index into the data, so we need an additional
# indexing step.
out = data.copy()
out[isnull] = data[i_source[index]]
out[i_target] = data[i_source[index]]
return out
2 changes: 2 additions & 0 deletions xugrid/ugrid/ugrid1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def coords(self):
}

def get_coordinates(self, dim: str) -> FloatArray:
"""Return the coordinates for the specified UGRID dimension."""
if dim == self.node_dimension:
return self.node_coordinates
elif dim == self.edge_dimension:
Expand All @@ -284,6 +285,7 @@ def get_coordinates(self, dim: str) -> FloatArray:
)

def get_connectivity_matrix(self, dim: str, xy_weights: bool):
"""Return the connectivity matrix for the specified UGRID dimension."""
if dim == self.node_dimension:
connectivity = self.node_node_connectivity.copy()
coordinates = self.node_coordinates
Expand Down
2 changes: 2 additions & 0 deletions xugrid/ugrid/ugrid2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,7 @@ def coords(self):
}

def get_coordinates(self, dim: str) -> FloatArray:
"""Return the coordinates for the specified UGRID dimension."""
if dim == self.node_dimension:
return self.node_coordinates
elif dim == self.edge_dimension:
Expand All @@ -700,6 +701,7 @@ def get_coordinates(self, dim: str) -> FloatArray:
)

def get_connectivity_matrix(self, dim: str, xy_weights: bool):
"""Return the connectivity matrix for the specified UGRID dimension."""
if dim == self.node_dimension:
connectivity = self.node_node_connectivity.copy()
coordinates = self.node_coordinates
Expand Down

0 comments on commit d30d2a7

Please sign in to comment.