Skip to content

Commit

Permalink
Enable broadcasting/lazy evaluation in interpolate_na and laplace_int…
Browse files Browse the repository at this point in the history
…erpolate. Fixes #292

Also changes dims into a set, consistent with xarray (future) behavior.
  • Loading branch information
Huite committed Aug 23, 2024
1 parent 43d98c3 commit b55adaf
Show file tree
Hide file tree
Showing 13 changed files with 145 additions and 54 deletions.
4 changes: 3 additions & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ Added
- :meth:`xugrid.UgridDataArrayAccessor.interpolate_na` has been added to fill missing
data. Currently, the only supported method is ``"nearest"``.
- :attr:`xugrid.Ugrid1.dims` and :attr:`xugrid.Ugrid2.dims` have been added to
return a tuple of the UGRID dimensions.
return a set of the UGRID dimensions.
- :meth:`xugrid.UgridDataArrayAccessor.laplace_interpolate` now broadcasts
over non-UGRID dimensions and supports lazy evaluation.

Changed
~~~~~~~
Expand Down
14 changes: 6 additions & 8 deletions tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,17 @@ def test_laplace_interpolate():
data = np.array([1.0, np.nan, np.nan, np.nan, 5.0])
with pytest.raises(ValueError, match="connectivity is not a square matrix"):
con = sparse.coo_matrix(coo_content, shape=(4, 5)).tocsr()
interpolate.laplace_interpolate(con, data, use_weights=False)
interpolate.laplace_interpolate(data, con, use_weights=False)

expected = np.arange(1.0, 6.0)
con = sparse.coo_matrix(coo_content, shape=(5, 5)).tocsr()
actual = interpolate.laplace_interpolate(
con, data, use_weights=False, direct_solve=True
data, con, use_weights=False, direct_solve=True
)
assert np.allclose(actual, expected)

actual = interpolate.laplace_interpolate(
con, data, use_weights=False, direct_solve=False
data, con, use_weights=False, direct_solve=False
)
assert np.allclose(actual, expected)

Expand All @@ -57,13 +57,11 @@ def test_nearest_interpolate():
y = np.zeros_like(x)
coordinates = np.column_stack((x, y))
data = np.array([0.0, np.nan, np.nan, np.nan, 4.0])
actual = interpolate.nearest_interpolate(coordinates, data, np.inf)
actual = interpolate.nearest_interpolate(data, coordinates, np.inf)
assert np.allclose(actual, np.array([0.0, 0.0, 0.0, 4.0, 4.0]))

actual = interpolate.nearest_interpolate(coordinates, data, 1.1)
actual = interpolate.nearest_interpolate(data, coordinates, 1.1)
assert np.allclose(actual, np.array([0.0, 0.0, np.nan, 4.0, 4.0]), equal_nan=True)

with pytest.raises(ValueError, match="All values are NA."):
interpolate.nearest_interpolate(
coordinates, data=np.full_like(data, np.nan), max_distance=np.inf
)
interpolate.nearest_interpolate(np.full_like(data, np.nan), coordinates, np.inf)
2 changes: 1 addition & 1 deletion tests/test_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_labels_to_indices():

def test_single_ugrid_chunk():
grid = generate_mesh_2d(3, 3)
ugrid_dims = set(grid.dims)
ugrid_dims = grid.dims
da = xr.DataArray(np.ones(grid.n_face), dims=(grid.face_dimension,))
assert pt.single_ugrid_chunk(da, ugrid_dims) is da

Expand Down
2 changes: 1 addition & 1 deletion tests/test_ugrid1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def test_dimensions():
grid = grid1d()
assert grid.node_dimension == f"{NAME}_nNodes"
assert grid.edge_dimension == f"{NAME}_nEdges"
assert grid.dims == (f"{NAME}_nNodes", f"{NAME}_nEdges")
assert grid.dims == {f"{NAME}_nNodes", f"{NAME}_nEdges"}
assert grid.sizes == {
f"{NAME}_nNodes": 3,
f"{NAME}_nEdges": 2,
Expand Down
20 changes: 18 additions & 2 deletions tests/test_ugrid2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,22 @@ def check_attrs(ds):
check_attrs(ds)


def test_find_ugrid_dim():
    """Check find_ugrid_dim on one UGRID dimension and on two of them."""
    grid = grid2d()
    face_dim = grid.face_dimension

    # A DataArray carrying exactly one UGRID dimension resolves to that name.
    single = xr.DataArray(data=np.ones((grid.n_face,)), dims=[face_dim])
    assert grid.find_ugrid_dim(single) == face_dim

    # Two UGRID dimensions on the same array must be rejected.
    ambiguous = xr.DataArray(
        data=np.ones((grid.n_face, grid.n_node)),
        dims=[face_dim, grid.node_dimension],
    )
    message = "UgridDataArray should contain exactly one of the UGRID dimension"
    with pytest.raises(ValueError, match=message):
        grid.find_ugrid_dim(ambiguous)


def test_ugrid2d_set_node_coords():
grid = grid2d()
ds = xr.Dataset()
Expand Down Expand Up @@ -449,11 +465,11 @@ def test_dimensions():
assert grid.node_dimension == f"{NAME}_nNodes"
assert grid.edge_dimension == f"{NAME}_nEdges"
assert grid.face_dimension == f"{NAME}_nFaces"
assert grid.dims == (
assert grid.dims == {
f"{NAME}_nNodes",
f"{NAME}_nEdges",
f"{NAME}_nFaces",
)
}
assert grid.sizes == {
f"{NAME}_nNodes": 7,
f"{NAME}_nEdges": 10,
Expand Down
22 changes: 22 additions & 0 deletions tests/test_ugrid_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings

import dask
import geopandas as gpd
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -387,6 +388,27 @@ def test_laplace_interpolate(self):
assert isinstance(actual, xugrid.UgridDataArray)
assert np.allclose(actual, 1.0)

def test_broadcasted_laplace_interpolate(self):
    """Laplace interpolation broadcasts over extra dims, eagerly and lazily."""
    # Blank out everything except the last two values along the first axis.
    uda2 = self.uda.copy()
    uda2.obj[:-2] = np.nan

    # Multiplying by a (time, layer) array of ones adds two non-UGRID dims.
    multiplier = xr.DataArray(
        np.ones((3, 2)),
        coords={"time": [0, 1, 2], "layer": [1, 2]},
        dims=("time", "layer"),
    )

    # Eager evaluation: values and dimensions are both checked.
    eager_input = uda2 * multiplier
    eager = eager_input.ugrid.laplace_interpolate(direct_solve=True)
    assert isinstance(eager, xugrid.UgridDataArray)
    assert np.allclose(eager, 1.0)
    assert set(eager.dims) == set(eager_input.dims)

    # Test delayed evaluation too: the result must stay a dask array.
    lazy_input = uda2 * multiplier.chunk({"time": 1})
    lazy = lazy_input.ugrid.laplace_interpolate(direct_solve=True)
    assert isinstance(lazy, xugrid.UgridDataArray)
    assert set(lazy.dims) == set(lazy_input.dims)
    assert isinstance(lazy.data, dask.array.Array)

def test_to_dataset(self):
uda2 = self.uda.copy()
uda2.ugrid.obj.name = "test"
Expand Down
64 changes: 41 additions & 23 deletions xugrid/core/dataarray_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from xugrid.core.wrap import UgridDataArray, UgridDataset
from xugrid.plot.plot import _PlotMethods
from xugrid.ugrid import connectivity
from xugrid.ugrid.interpolate import laplace_interpolate, nearest_interpolate
from xugrid.ugrid.interpolate import (
interpolate_na_helper,
laplace_interpolate,
nearest_interpolate,
)
from xugrid.ugrid.ugrid1d import Ugrid1d
from xugrid.ugrid.ugrid2d import Ugrid2d
from xugrid.ugrid.ugridbase import UgridType
Expand Down Expand Up @@ -587,6 +591,9 @@ def interpolate_na(
"""
Fill in NaNs by interpolating.
This function automatically finds the UGRID dimension and broadcasts
over the other dimensions.
Parameters
----------
method: str, default is "nearest"
Expand All @@ -607,13 +614,17 @@ def interpolate_na(

grid = self.grid
da = self.obj

filled = nearest_interpolate(
coordinates=grid.get_coordinates(dim=da.dims[0]),
data=da.to_numpy(),
max_distance=max_distance,
ugrid_dim = grid.find_ugrid_dim(da)

da_filled = interpolate_na_helper(
da,
ugrid_dim=ugrid_dim,
func=nearest_interpolate,
kwargs={
"coordinates": grid.get_coordinates(ugrid_dim),
"max_distance": max_distance,
},
)
da_filled = da.copy(data=filled)
return UgridDataArray(da_filled, grid)

def laplace_interpolate(
Expand All @@ -629,6 +640,9 @@ def laplace_interpolate(
"""
Fill in NaNs by using Laplace interpolation.
This function automatically finds the UGRID dimension and broadcasts
over the other dimensions.
This solves Laplace's equation where there is no data, with data
values functioning as fixed potential boundary conditions.
Expand Down Expand Up @@ -669,25 +683,29 @@ def laplace_interpolate(
"""
grid = self.grid
da = self.obj
if len(da.dims) > 1:
# TODO: apply ufunc
raise NotImplementedError
if da.dims[0] == grid.edge_dimension:

grid = self.grid
da = self.obj
ugrid_dim = grid.find_ugrid_dim(da)
if ugrid_dim == grid.edge_dimension:
raise ValueError("Laplace interpolation along edges is not allowed.")

connectivity = grid.get_connectivity_matrix(da.dims[0], xy_weights=xy_weights)
filled = laplace_interpolate(
connectivity=connectivity,
data=da.to_numpy(),
use_weights=xy_weights,
direct_solve=direct_solve,
delta=delta,
relax=relax,
rtol=rtol,
atol=atol,
maxiter=maxiter,
connectivity = grid.get_connectivity_matrix(ugrid_dim, xy_weights=xy_weights)
da_filled = interpolate_na_helper(
da,
ugrid_dim,
func=laplace_interpolate,
kwargs={
"connectivity": connectivity,
"use_weights": xy_weights,
"direct_solve": direct_solve,
"delta": delta,
"relax": relax,
"rtol": rtol,
"atol": atol,
"maxiter": maxiter,
},
)
da_filled = da.copy(data=filled)
return UgridDataArray(da_filled, grid)

def to_dataset(self, optional_attributes: bool = False):
Expand Down
2 changes: 1 addition & 1 deletion xugrid/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ def __init__(self, obj):
darray = obj.obj
grid = obj.grid

invalid = set(darray.dims) - set(grid.dims)
invalid = set(darray.dims) - grid.dims
if invalid:
raise ValueError(
f"UgridDataArray contains non-topology dimensions: {invalid}.\n"
Expand Down
30 changes: 26 additions & 4 deletions xugrid/ugrid/interpolate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import warnings
from typing import NamedTuple, Tuple
from typing import Any, Callable, Dict, NamedTuple, Tuple

import numba as nb
import numpy as np
import xarray as xr
from scipy import sparse
from scipy.spatial import KDTree

Expand Down Expand Up @@ -197,8 +198,8 @@ def __repr__(self) -> str:


def laplace_interpolate(
connectivity: sparse.csr_matrix,
data: FloatArray,
connectivity: sparse.csr_matrix,
use_weights: bool,
direct_solve: bool = False,
delta=0.0,
Expand All @@ -218,9 +219,9 @@ def laplace_interpolate(
Parameters
----------
data: ndarray of floats with shape ``(n,)``
connectivity: scipy.sparse.csr_matrix with shape ``(n, n)``
Sparse connectivity matrix containing ``n_nonzero`` indices and weight values.
data: ndarray of floats with shape ``(n,)``
use_weights: bool, default False.
Whether to use the data attribute of the connectivity matrix as
coefficients. If ``False``, defaults to uniform coefficients of 1.
Expand Down Expand Up @@ -310,8 +311,8 @@ def laplace_interpolate(


def nearest_interpolate(
coordinates: FloatArray,
data: FloatArray,
coordinates: FloatArray,
max_distance: float,
) -> FloatArray:
isnull = np.isnan(data)
Expand All @@ -337,3 +338,24 @@ def nearest_interpolate(
out = data.copy()
out[i_target] = data[i_source[index]]
return out


def interpolate_na_helper(
    da: xr.DataArray,
    ugrid_dim: str,
    func: Callable,
    kwargs: Dict[str, Any],
):
    """
    Apply ``func`` along ``ugrid_dim`` and broadcast over all other dims.

    The work is delegated to ``xarray.apply_ufunc`` with ``ugrid_dim`` as
    the single input and output core dimension.

    Parameters
    ----------
    da: xr.DataArray
        Data to process.
    ugrid_dim: str
        Name of the dimension passed to ``func`` as the core dimension.
    func: Callable
        Function applied along ``ugrid_dim``. It receives the data as its
        first positional argument and ``kwargs`` as keyword arguments.
    kwargs: dict
        Keyword arguments forwarded unchanged to ``func``.

    Returns
    -------
    filled
        The result of ``xarray.apply_ufunc``. Attributes of ``da`` are kept
        and the output dtype is declared to be ``da.dtype``.
    """
    ufunc_options = {
        "input_core_dims": [[ugrid_dim]],
        "output_core_dims": [[ugrid_dim]],
        # Loop func over the non-core dimensions.
        "vectorize": True,
        "kwargs": kwargs,
        # Let xarray handle dask-backed input instead of raising.
        "dask": "parallelized",
        "keep_attrs": True,
        "output_dtypes": [da.dtype],
    }
    return xr.apply_ufunc(func, da, **ufunc_options)
2 changes: 1 addition & 1 deletion xugrid/ugrid/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def validate_partition_topology(grouped: defaultdict[str, UgridType]) -> None:
f"same type, received: {types}"
)

griddims = list({grid.dims for grid in grids})
griddims = list({tuple(grid.dims) for grid in grids})
if len(griddims) > 1:
raise ValueError(
f"Dimension names on UGRID topology {name} do not match "
Expand Down
9 changes: 5 additions & 4 deletions xugrid/ugrid/ugrid1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,8 @@ def core_dimension(self):

@property
def dims(self):
"""Tuple of UGRID dimension names: node dimension, edge dimension."""
# Tuple to preserve order, unlike set.
return (self.node_dimension, self.edge_dimension)
"""Set of UGRID dimension names: node dimension, edge dimension."""
return {self.node_dimension, self.edge_dimension}

@property
def sizes(self):
Expand Down Expand Up @@ -466,7 +465,9 @@ def isel(self, indexers=None, return_index=False, **indexers_kwargs):
)

indexers = {k: as_pandas_index(v, self.sizes[k]) for k, v in indexers.items()}
nodedim, edgedim = self.dims
nodedim = self.node_dimension
edgedim = self.edge_dimension

edge_index = {}
if nodedim in indexers:
node_index = indexers[nodedim]
Expand Down
12 changes: 7 additions & 5 deletions xugrid/ugrid/ugrid2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,13 +415,12 @@ def core_dimension(self):

@property
def dims(self):
"""Tuple of UGRID dimension names: node dimension, edge dimension, face_dimension."""
# Tuple to preserve order, unlike set.
return (
"""Set of UGRID dimension names: node dimension, edge dimension, face_dimension."""
return {
self.node_dimension,
self.edge_dimension,
self.face_dimension,
)
}

@property
def sizes(self):
Expand Down Expand Up @@ -1202,7 +1201,10 @@ def isel(self, indexers=None, return_index=False, **indexers_kwargs):
)

indexers = {k: as_pandas_index(v, self.sizes[k]) for k, v in indexers.items()}
nodedim, edgedim, facedim = self.dims
nodedim = self.node_dimension
edgedim = self.edge_dimension
facedim = self.face_dimension

face_index = {}
if nodedim in indexers:
node_index = indexers[nodedim]
Expand Down
Loading

0 comments on commit b55adaf

Please sign in to comment.