NaN threshold for conservative method #39

Closed
149 changes: 102 additions & 47 deletions benchmarks/benchmarking_conservative.ipynb

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions pyproject.toml
@@ -50,6 +50,7 @@ dev = [
     "pytest",
     "pytest-cov",
     "pandas-stubs", # Adds typing for pandas.
+    "cftime",
 ]

 [tool.hatch.version]
@@ -84,7 +85,6 @@ select = [
     "E",
     "EM",
     "F",
-    "FBT",
     "I",
     "ICN",
     "ISC",
@@ -105,8 +105,6 @@ select = [
 ignore = [
     # Allow non-abstract empty methods in abstract base classes
     "B027",
-    # Allow boolean positional values in function calls, like `dict.get(... True)`
-    "FBT003",
     # Ignore checks for possible passwords
     "S105", "S106", "S107",
     # Ignore complexity
267 changes: 119 additions & 148 deletions src/xarray_regrid/methods/conservative.py
@@ -2,7 +2,6 @@
 from collections.abc import Hashable
 from typing import overload

-import dask.array
 import numpy as np
 import xarray as xr
@@ -14,6 +13,8 @@ def conservative_regrid(
     data: xr.DataArray,
     target_ds: xr.Dataset,
     latitude_coord: str | None,
+    skipna: bool = True,
+    nan_threshold: float = 1.0,
 ) -> xr.DataArray:
     ...

@@ -23,6 +24,8 @@ def conservative_regrid(
     data: xr.Dataset,
     target_ds: xr.Dataset,
     latitude_coord: str | None,
+    skipna: bool = True,
+    nan_threshold: float = 1.0,
 ) -> xr.Dataset:
     ...

@@ -31,6 +34,8 @@ def conservative_regrid(
     data: xr.DataArray | xr.Dataset,
     target_ds: xr.Dataset,
     latitude_coord: str | None,
+    skipna: bool = True,
+    nan_threshold: float = 1.0,
 ) -> xr.DataArray | xr.Dataset:
     """Refine a dataset using conservative regridding.

@@ -44,32 +49,40 @@
     Args:
         data: Input dataset.
         target_ds: Dataset which coordinates the input dataset should be regrid to.
+        latitude_coord: Name of the latitude coordinate. If not provided, attempt to
+            infer it from the first coordinate containing the string 'lat'.
+        skipna: If True, enable handling for NaN values. This adds some overhead,
+            so should be disabled for optimal performance on data without NaNs.
+        nan_threshold: Threshold value that will retain any output points containing
+            at least this many non-null input points. The default value is 1.0,
+            which will keep output points containing any non-null inputs. The threshold
+            is applied sequentially to each dimension, and may produce different results
+            than a threshold applied concurrently to all regridding dimensions.

     Returns:
         Regridded input dataset
     """
-    if latitude_coord is not None:
-        if latitude_coord not in data.coords:
-            msg = "Latitude coord not in input data!"
-            raise ValueError(msg)
-    else:
-        latitude_coord = ""
-
-    dim_order = list(target_ds.dims)
-
-    coord_names = set(target_ds.coords).intersection(set(data.coords))
-    target_ds_sorted = target_ds.sortby(list(coord_names))
-    coords = {name: target_ds_sorted[name] for name in coord_names}
-
-    if isinstance(data, xr.Dataset):
-        regridded_data = conservative_regrid_dataset(
-            data, coords, latitude_coord
-        ).transpose(*dim_order, ...)
-    else:
-        regridded_data = conservative_regrid_dataarray(  # type: ignore
-            data, coords, latitude_coord
-        ).transpose(*dim_order, ...)
+    # Attempt to infer the latitude coordinate
+    if latitude_coord is None:
+        for coord in data.coords:
+            if coord.lower().startswith("lat"):
+                latitude_coord = coord
+                break
+
+    # Make sure the regridding coordinates are sorted
+    coord_names = [coord for coord in target_ds.coords if coord in data.coords]
+    target_ds_sorted = target_ds.sortby(coord_names)
+    data = data.sortby(list(coord_names))
+    coords = {name: target_ds_sorted[name] for name in coord_names}
+
+    regridded_data = utils.call_on_dataset(
+        conservative_regrid_dataset,
+        data,
+        coords,
+        latitude_coord,
+        skipna,
+        nan_threshold,
+    )

     regridded_data = regridded_data.reindex_like(target_ds, copy=False)
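For reference, a minimal usage sketch of the two keywords this diff adds, on synthetic data. The grids, variable name, and values here are illustrative only, not taken from the PR:

    import numpy as np
    import xarray as xr

    from xarray_regrid.methods.conservative import conservative_regrid

    # Synthetic 1-degree source grid with a block of NaNs.
    source = xr.Dataset(
        {"tas": (("latitude", "longitude"), np.random.rand(180, 360))},
        coords={
            "latitude": np.arange(-89.5, 90, 1.0),
            "longitude": np.arange(-179.5, 180, 1.0),
        },
    )
    source["tas"][0:10, 0:10] = np.nan

    # 2-degree target grid.
    target = xr.Dataset(
        coords={
            "latitude": np.arange(-89, 90, 2.0),
            "longitude": np.arange(-179, 180, 2.0),
        },
    )

    # nan_threshold=1.0 (the default) keeps any output cell with at least one
    # valid input; nan_threshold=0.0 keeps only cells with fully valid inputs.
    result = conservative_regrid(
        source, target, "latitude", skipna=True, nan_threshold=0.5
    )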

@@ -80,20 +93,19 @@ def conservative_regrid_dataset(
     data: xr.Dataset,
     coords: dict[Hashable, xr.DataArray],
     latitude_coord: str,
+    skipna: bool,
+    nan_threshold: float,
 ) -> xr.Dataset:
     """Dataset implementation of the conservative regridding method."""
-    data_vars = list(data.data_vars)
-    data_coords = list(data.coords)
-    dataarrays = [data[var] for var in data_vars]
+    data_vars = dict(data.data_vars)
+    data_coords = dict(data.coords)
+    valid_fracs = {v: None for v in data_vars}
+    data_attrs = {v: data_vars[v].attrs for v in data_vars}
+    coord_attrs = {c: data_coords[c].attrs for c in data_coords}
+    ds_attrs = data.attrs

-    attrs = data.attrs
-    da_attrs = [da.attrs for da in dataarrays]
-    coord_attrs = [data[coord].attrs for coord in data_coords]
-
-    # track which target coordinate values are not covered by the source grid
-    uncovered_target_grid = {}
     for coord in coords:
-        uncovered_target_grid[coord] = (coords[coord] <= data[coord].max()) & (
+        covered_grid = (coords[coord] <= data[coord].max()) & (
             coords[coord] >= data[coord].min()
         )

@@ -102,132 +114,96 @@
         weights = get_weights(source_coords, target_coords)

         # Modify weights to correct for latitude distortion
+        weights = utils.create_dot_dataarray(
+            weights, str(coord), target_coords, source_coords
+        )
         if str(coord) == latitude_coord:
-            dot_array = utils.create_dot_dataarray(
-                weights, str(coord), target_coords, source_coords
-            )
-            dot_array = apply_spherical_correction(dot_array, latitude_coord)
-            weights = dot_array.to_numpy()
-
-        for i in range(len(dataarrays)):
-            if coord in dataarrays[i].coords:
-                da = dataarrays[i].transpose(coord, ...)
-                dataarrays[i] = apply_weights(da, weights, coord, target_coords)
-
-    for da, attr in zip(dataarrays, da_attrs, strict=True):
-        da.attrs = attr
-    regridded = xr.merge(dataarrays)
-
-    # Replace zeros outside of original data grid with NaNs
-    for coord in coords:
-        regridded = regridded.where(uncovered_target_grid[coord])
-
-    regridded.attrs = attrs
-
-    new_coords = [regridded[coord] for coord in data_coords]
-    for coord, attr in zip(new_coords, coord_attrs, strict=True):
-        coord.attrs = attr
-
-    return regridded  # TODO: add other coordinates/data variables back in.
-
-
-def conservative_regrid_dataarray(
-    data: xr.DataArray,
-    coords: dict[Hashable, xr.DataArray],
-    latitude_coord: str,
-) -> xr.DataArray:
-    """DataArray implementation of the conservative regridding method."""
-    data_coords = list(data.coords)
-
-    attrs = data.attrs
-    coord_attrs = [data[coord].attrs for coord in data_coords]
-
-    for coord in coords:
-        uncovered_target_grid = (coords[coord] <= data[coord].max()) & (
-            coords[coord] >= data[coord].min()
-        )
-
-        if coord in data.coords:
-            target_coords = coords[coord].to_numpy()
-            source_coords = data[coord].to_numpy()
-
-            weights = get_weights(source_coords, target_coords)
-
-            # Modify weights to correct for latitude distortion
-            if str(coord) == latitude_coord:
-                dot_array = utils.create_dot_dataarray(
-                    weights, str(coord), target_coords, source_coords
-                )
-                dot_array = apply_spherical_correction(dot_array, latitude_coord)
-                weights = dot_array.to_numpy()
-
-            data = data.transpose(coord, ...)
-            data = apply_weights(data, weights, coord, target_coords)
-
-        # Replace zeros outside of original data grid with NaNs
-        data = data.where(uncovered_target_grid)
-
-    new_coords = [data[coord] for coord in data_coords]
-    for coord, attr in zip(new_coords, coord_attrs, strict=True):
-        coord.attrs = attr
-    data.attrs = attrs
-
-    return data
-
-
-def apply_weights(
-    da: xr.DataArray, weights: np.ndarray, coord_name: Hashable, new_coords: np.ndarray
-) -> xr.DataArray:
-    """Apply the weights to convert data to the new coordinates."""
-    new_data: np.ndarray | dask.array.Array
-    if da.chunks is not None:
-        # Dask routine
-        new_data = compute_einsum_dask(da, weights)
-    else:
-        # numpy routine
-        new_data = compute_einsum_numpy(da, weights)
-
-    coord_mapping = {coord_name: new_coords}
-    coords = list(da.dims)
-    coords.remove(coord_name)
-    for coord in coords:
-        coord_mapping[coord] = da[coord].to_numpy()
-
-    return xr.DataArray(
-        data=new_data,
-        coords=coord_mapping,
-        name=da.name,
-    )
-
-
-def compute_einsum_dask(da: xr.DataArray, weights: np.ndarray) -> dask.array.Array:
-    """Compute the einsum between dask data and weights, and mask NaNs if needed."""
-    new_data: dask.array.Array
-    if np.any(np.isnan(da.data)):
-        new_data = dask.array.einsum(
-            "i...,ij->j...", da.fillna(0).data, weights, optimize="greedy"
-        )
-        isnan = dask.array.einsum(
-            "i...,ij->j...", np.isnan(da.data), weights, optimize="greedy"
-        )
-        new_data[isnan > 0] = np.nan
-    else:
-        new_data = dask.array.einsum(
-            "i...,ij->j...", da.data, weights, optimize="greedy"
-        )
-    return new_data
-
-
-def compute_einsum_numpy(da: xr.DataArray, weights: np.ndarray) -> np.ndarray:
-    """Compute the einsum between numpy data and weights, and mask NaNs if needed."""
-    new_data: np.ndarray
-    if np.any(np.isnan(da.data)):
-        new_data = np.einsum("i...,ij->j...", da.fillna(0).data, weights)
-        isnan = np.einsum("i...,ij->j...", np.isnan(da.data), weights)
-        new_data[isnan > 0] = np.nan
-    else:
-        new_data = np.einsum("i...,ij->j...", da.data, weights)
-    return new_data
+            weights = apply_spherical_correction(weights, latitude_coord)
+
+        for array in data_vars.keys():
+            non_grid_dims = [d for d in data_vars[array].dims if d not in coords]
+            if coord in data_vars[array].dims:
+                data_vars[array], valid_fracs[array] = apply_weights(
+                    data_vars[array],
+                    weights,
+                    coord,
+                    valid_fracs[array],
+                    skipna,
+                    non_grid_dims,
+                )
+                # Mask out any regridded points outside the original domain
+                data_vars[array] = data_vars[array].where(covered_grid)
+
+    if skipna:
+        # Mask out any points that don't meet the nan threshold
+        valid_threshold = get_valid_threshold(nan_threshold)
+        for array, da in data_vars.items():
+            data_vars[array] = da.where(valid_fracs[array] >= valid_threshold)
+
+    for array, attrs in data_attrs.items():
+        data_vars[array].attrs = attrs
+
+    ds_regridded = xr.Dataset(data_vars=data_vars, attrs=ds_attrs)
+
+    for coord, attrs in coord_attrs.items():
+        if coord not in ds_regridded.coords:
+            # Add back any additional coordinates from the original dataset
+            ds_regridded[coord] = data_coords[coord]
+        ds_regridded[coord].attrs = attrs
+
+    return ds_regridded
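The dedicated DataArray code path disappears above because the new entry point routes everything through utils.call_on_dataset, which promotes a DataArray to a single-variable Dataset, applies the Dataset implementation, and unwraps the result. That helper is not part of this diff; a minimal sketch of the assumed shape:

    from collections.abc import Callable

    import xarray as xr

    def call_on_dataset(
        func: Callable[..., xr.Dataset],
        obj: xr.DataArray | xr.Dataset,
        *args,
    ) -> xr.DataArray | xr.Dataset:
        """Run a Dataset-only implementation on a DataArray or a Dataset."""
        if isinstance(obj, xr.DataArray):
            name = obj.name if obj.name is not None else "data"
            # Promote to a single-variable Dataset, apply, then unwrap again.
            return func(obj.to_dataset(name=name), *args)[name]
        return func(obj, *args)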
+def apply_weights(
+    da: xr.DataArray,
+    weights: np.ndarray,
+    coord_name: Hashable,
+    valid_frac: xr.DataArray | None,
+    skipna: bool,
+    non_grid_dims: list[Hashable],
+) -> tuple[xr.DataArray, xr.DataArray]:
+    """Apply the weights to convert data to the new coordinates."""
+    coord_map = {f"target_{coord_name}": coord_name}
+    weights_norm = weights.copy()
+
+    if skipna:
+        notnull = da.notnull()
+        if non_grid_dims:
+            notnull = notnull.any(non_grid_dims)
+        # Renormalize the weights along this dim by the accumulated valid_frac
+        # along previous dimensions
+        if valid_frac is not None:
+            weights_norm = weights * valid_frac / valid_frac.mean(coord_name)
+
+    da_reduced = xr.dot(da.fillna(0), weights_norm, dim=coord_name, optimize=True)
+    da_reduced = da_reduced.rename(coord_map).transpose(*da.dims)
+
+    if skipna:
+        weights_valid_sum = xr.dot(
+            weights_norm, notnull, dim=coord_name, optimize=True
+        ).rename(coord_map)
+        da_reduced /= weights_valid_sum.clip(1e-6, None)
+
+        if valid_frac is None:
+            # Begin tracking the valid fraction
+            valid_frac = weights_valid_sum
+
+        else:
+            # Update the valid points on this dimension
+            valid_frac = xr.dot(
+                valid_frac, weights, dim=coord_name, optimize=True
+            ).rename(coord_map)
+            valid_frac = valid_frac.clip(0, 1)
+
+    return da_reduced, valid_frac
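The skipna path above is a weighted mean that ignores NaNs: missing values are zero-filled before the dot product, and the result is divided by the weight mass that actually came from valid points, which is also the valid fraction carried forward between dimensions. The same arithmetic on bare numpy arrays, as a worked example:

    import numpy as np

    values = np.array([1.0, np.nan, 3.0])
    weights = np.array([0.25, 0.5, 0.25])  # regridding weights for one output cell

    valid = ~np.isnan(values)
    weighted_sum = np.nansum(values * weights)  # 0.25*1 + 0.25*3 = 1.0
    valid_weight = (weights * valid).sum()      # 0.25 + 0.25 = 0.5
    result = weighted_sum / valid_weight        # 2.0, the NaN-aware weighted mean
    valid_frac = valid_weight                   # later compared to the threshold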


+def get_valid_threshold(nan_threshold: float) -> float:
+    """Invert the nan_threshold and coerce it to just above zero and below
+    one to handle numerical precision limitations in the weight sum."""
+    # This matches xesmf where na_thresh=0 keeps points with any valid data
+    valid_threshold = 1 - np.clip(nan_threshold, 1e-6, 1.0 - 1e-6)
+    return valid_threshold
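Concretely, this inversion maps the user-facing nan_threshold onto the internal valid-fraction cutoff:

    import numpy as np

    for nan_threshold in (1.0, 0.5, 0.0):
        valid_threshold = 1 - np.clip(nan_threshold, 1e-6, 1.0 - 1e-6)
        print(nan_threshold, valid_threshold)
    # 1.0 -> ~1e-06    (keep output points with any valid input)
    # 0.5 -> 0.5       (require half of the input weight to be valid)
    # 0.0 -> ~0.999999 (require essentially fully valid inputs)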


@@ -240,14 +216,9 @@ def get_weights(source_coords: np.ndarray, target_coords: np.ndarray) -> np.ndarray:
     Returns:
         Weights, which can be used with a dot product to apply the conservative regrid.
     """
-    # TODO: better resolution/IntervalIndex inference
-    target_intervals = utils.to_intervalindex(
-        target_coords, resolution=target_coords[1] - target_coords[0]
-    )
+    target_intervals = utils.to_intervalindex(target_coords)
+    source_intervals = utils.to_intervalindex(source_coords)

-    source_intervals = utils.to_intervalindex(
-        source_coords, resolution=source_coords[1] - source_coords[0]
-    )
     overlap = utils.overlap(source_intervals, target_intervals)
     return utils.normalize_overlap(overlap)
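The weights come from pairwise overlaps of source and target cell intervals, normalized so each target cell's weights sum to one. A self-contained sketch of that computation; the helper name and the midpoint-based interval construction are assumptions here, since the real logic lives in xarray_regrid.utils:

    import numpy as np
    import pandas as pd

    def interval_weights(source: np.ndarray, target: np.ndarray) -> np.ndarray:
        """Normalized overlap matrix between source and target cell intervals."""

        def to_intervals(coords: np.ndarray) -> pd.IntervalIndex:
            step = coords[1] - coords[0]
            breaks = np.append(coords - step / 2, coords[-1] + step / 2)
            return pd.IntervalIndex.from_breaks(breaks)

        src, tgt = to_intervals(source), to_intervals(target)
        overlap = np.zeros((len(src), len(tgt)))
        for i, s in enumerate(src):
            for j, t in enumerate(tgt):
                overlap[i, j] = max(0.0, min(s.right, t.right) - max(s.left, t.left))
        return overlap / overlap.sum(axis=0)  # each target column sums to 1

    # Three 1-wide source cells onto two 1.5-wide target cells:
    w = interval_weights(np.array([0.5, 1.5, 2.5]), np.array([0.75, 2.25]))
    # w ≈ [[2/3, 0], [1/3, 1/3], [0, 2/3]]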
