
Sparse weights in conservative method #49

Merged (8 commits, Sep 24, 2024)
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -6,7 +6,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/).

## Unreleased

-# 0.3.0 (2024-09-05)
Added:
- If latitude/longitude coordinates are detected and the domain is global, automatic padding is applied at the boundaries, which gives behavior more consistent with common tools like ESMF and CDO ([#45](https://github.com/xarray-contrib/xarray-regrid/pull/45)).
- Conservative regridding weights are converted to sparse matrices if the optional [sparse](https://github.com/pydata/sparse) package is installed, which improves compute and memory performance in most cases ([#49](https://github.com/xarray-contrib/xarray-regrid/pull/49)).


## 0.3.0 (2024-09-05)

New contributors:
- [@slevang](https://github.com/slevang)
13 changes: 13 additions & 0 deletions README.md
@@ -24,10 +24,23 @@ Regridding is a common operation in earth science and other fields. While xarray

## Installation

For a minimal install:
```console
pip install xarray-regrid
```

To improve performance in certain cases:
```console
pip install xarray-regrid[accel]
```

which includes the following optional extras:
- `dask`: parallelization over chunked data
- `sparse`: conservative regridding with sparse weight matrices
- `opt-einsum`: optimized einsum routines used in conservative regridding

Benchmark results vary with hardware, but these extras often provide significant speedups.
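To check which of these accelerators are importable in a given environment, a quick sketch like the following works (note that `opt-einsum` installs under the module name `opt_einsum`):
```py
# Sketch: report which optional accelerators are importable.
from importlib.util import find_spec

for module in ("dask", "sparse", "opt_einsum"):
    print(f"{module}: {'installed' if find_spec(module) else 'missing'}")
```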

## Usage
The xarray-regrid routines are accessed using the "regrid" accessor on an xarray Dataset:
```py
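# Minimal sketch -- the dataset and grid names here are illustrative, but the
# call pattern matches the tests in this PR: ds.regrid.<method>(target_dataset).
import xarray as xr
import xarray_regrid  # noqa: F401  (importing registers the .regrid accessor)

ds = xr.tutorial.open_dataset("air_temperature")
target = ds.coarsen(lat=2, lon=2, boundary="trim").mean()  # any dataset on the target grid
ds_regridded = ds.regrid.linear(target)
```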
5 changes: 1 addition & 4 deletions environment.yml
@@ -7,7 +7,4 @@ dependencies:
  - xESMF
  - cdo
  - pip:
-      - "xarray-regrid"
-      - "cftime"
-      - "matplotlib"
-      - "dask[distributed]"
      - "xarray-regrid[accel,benchmarking,dev]"
13 changes: 9 additions & 4 deletions pyproject.toml
@@ -39,22 +39,27 @@ Issues = "https://github.com/EXCITED-CO2/xarray-regrid/issues"
Source = "https://github.com/EXCITED-CO2/xarray-regrid"

[project.optional-dependencies]
-benchmarking = [
accel = [
    "sparse",
    "opt-einsum",
    "dask[distributed]",
]
benchmarking = [
    "matplotlib",
    "zarr",
    "h5netcdf",
    "requests",
    "aiohttp",
    "pooch",
    "cftime",  # required for decode time of test netCDF files
]
dev = [
    "hatch",
    "ruff",
    "mypy",
    "pytest",
    "pytest-cov",
-    "cftime",  # required for decode time of test netCDF files
    "pandas-stubs",  # Adds typing for pandas.
    "cftime",
]
docs = [  # Required for ReadTheDocs
    "myst_parser",
@@ -69,7 +74,7 @@ docs = [  # Required for ReadTheDocs
path = "src/xarray_regrid/__init__.py"

[tool.hatch.envs.default]
features = ["dev", "benchmarking"]
features = ["accel", "dev", "benchmarking"]

[tool.hatch.envs.default.scripts]
lint = [
43 changes: 40 additions & 3 deletions src/xarray_regrid/methods/conservative.py
@@ -6,6 +6,11 @@
import numpy as np
import xarray as xr

# `sparse` is an optional dependency; fall back to dense weights when it
# is not installed.
try:
    import sparse  # type: ignore
except ImportError:
    sparse = None

from xarray_regrid import utils

EMPTY_DA_NAME = "FRAC_EMPTY"
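# (Sentinel name for a placeholder valid-fraction array with no accumulated data.)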
@@ -125,9 +130,14 @@ def conservative_regrid_dataset(

    for array in data_vars.keys():
        if coord in data_vars[array].dims:
            if sparse is not None:
                var_weights = sparsify_weights(weights, data_vars[array])
            else:
                var_weights = weights

            data_vars[array], valid_fracs[array] = apply_weights(
                da=data_vars[array],
-                weights=weights,
                weights=var_weights,
                coord=coord,
                valid_frac=valid_fracs[array],
                skipna=skipna,
@@ -171,7 +181,9 @@ def apply_weights(
    # Renormalize the weights along this dim by the accumulated valid_frac
    # along previous dimensions
    if valid_frac.name != EMPTY_DA_NAME:
-        weights_norm = weights * valid_frac / valid_frac.mean(dim=[coord])
        weights_norm = weights * (valid_frac / valid_frac.mean(dim=[coord])).fillna(
            0
        )
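        # The fillna(0) guards the 0/0 case: where the mean valid fraction is
        # zero the ratio is NaN, and filling with 0 keeps NaN out of the dot
        # product below.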

    da_reduced: xr.DataArray = xr.dot(
        da.fillna(0), weights_norm, dim=[coord], optimize=True
@@ -180,7 +192,7 @@

    if skipna:
        weights_valid_sum: xr.DataArray = xr.dot(
-            weights_norm, notnull, dim=[coord], optimize=True
            notnull, weights_norm, dim=[coord], optimize=True
        )
        weights_valid_sum = weights_valid_sum.rename(coord_map)
        da_reduced /= weights_valid_sum.clip(1e-6, None)
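        # clip(1e-6, None) keeps the denominator strictly positive where
        # (almost) no valid source points contribute.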
@@ -195,6 +207,17 @@
    valid_frac = valid_frac.rename(coord_map)
    valid_frac = valid_frac.clip(0, 1)

    # In some cases, dot product of dask data and sparse weights fails
    # to automatically densify, which prevents future conversion to numpy
    if (
        sparse is not None
        and da_reduced.chunks
        and isinstance(da_reduced.data._meta, sparse.COO)
    ):
        da_reduced.data = da_reduced.data.map_blocks(
            lambda x: x.todense(), dtype=da_reduced.dtype
        )

    return da_reduced, valid_frac


@@ -248,3 +271,17 @@ def lat_weight(latitude: np.ndarray, latitude_res: float) -> np.ndarray:
    lat = np.radians(latitude)
    h = np.sin(lat + dlat / 2) - np.sin(lat - dlat / 2)
    return h * dlat / (np.pi * 4)  # type: ignore


def sparsify_weights(weights: xr.DataArray, da: xr.DataArray) -> xr.DataArray:
    """Create a sparse version of the weights that matches the dtype and chunks
    of the array to be regridded. Even though the weights can be constructed as
    dense arrays, contraction is more efficient with sparse operations."""
    new_weights = weights.copy().astype(da.dtype)
    if da.chunks:
        chunks = {k: v for k, v in da.chunksizes.items() if k in weights.dims}
        new_weights.data = new_weights.chunk(chunks).data.map_blocks(sparse.COO)
    else:
        new_weights.data = sparse.COO(weights.data)

    return new_weights
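For intuition, the conversion `sparsify_weights` performs looks like the sketch below. The toy weights are hypothetical, but conservative weight matrices are mostly zeros (each target cell overlaps only a few source cells), which is why the sparse contraction pays off:

```py
import numpy as np
import sparse
import xarray as xr

# Toy weight matrix: 4 source cells mapping onto 2 target cells.
dense = np.zeros((4, 2))
dense[0, 0] = dense[1, 0] = 0.5  # only overlapping source/target cells are nonzero
dense[2, 1] = dense[3, 1] = 0.5
weights = xr.DataArray(dense, dims=["latitude", "target_latitude"])

sparse_weights = weights.copy()
sparse_weights.data = sparse.COO(weights.data)  # stores only the 4 nonzero entries
print(sparse_weights.data.nnz)  # -> 4
```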
23 changes: 16 additions & 7 deletions tests/test_regrid.py
@@ -20,6 +20,8 @@
"conservative": DATA_PATH / "cdo_conservative_64b.nc",
}

CHUNK_SCHEMES = [{}, {"time": 1}, {"longitude": 100, "latitude": 100}]
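# The schemes above cover unchunked input, chunking over time, and chunking over space.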


@pytest.fixture(scope="session")
def sample_input_data() -> xr.Dataset:
@@ -71,31 +73,38 @@ def conservative_sample_grid():


@pytest.mark.parametrize("method", ["linear", "nearest"])
@pytest.mark.parametrize("chunks", CHUNK_SCHEMES)
def test_basic_regridders_ds(
-    sample_input_data, sample_grid_ds, cdo_comparison_data, method
    sample_input_data, sample_grid_ds, cdo_comparison_data, method, chunks
):
    """Test the dataset regridders (except conservative)."""
-    regridder = getattr(sample_input_data.regrid, method)
    regridder = getattr(sample_input_data.chunk(chunks).regrid, method)
    ds_regrid = regridder(sample_grid_ds)
    ds_cdo = cdo_comparison_data[method]
    xr.testing.assert_allclose(ds_regrid, ds_cdo, rtol=0.002, atol=2e-5)


@pytest.mark.parametrize("method", ["linear", "nearest"])
@pytest.mark.parametrize("chunks", CHUNK_SCHEMES)
def test_basic_regridders_da(
-    sample_input_data, sample_grid_ds, cdo_comparison_data, method
    sample_input_data, sample_grid_ds, cdo_comparison_data, method, chunks
):
    """Test the dataarray regridders (except conservative)."""
-    regridder = getattr(sample_input_data["d2m"].regrid, method)
    regridder = getattr(sample_input_data["d2m"].chunk(chunks).regrid, method)
    da_regrid = regridder(sample_grid_ds)
    da_cdo = cdo_comparison_data[method]["d2m"]
    xr.testing.assert_allclose(da_regrid, da_cdo, rtol=0.002, atol=2e-5)


@pytest.mark.parametrize("chunks", CHUNK_SCHEMES)
def test_conservative_regridder(
-    conservative_input_data, conservative_sample_grid, cdo_comparison_data
    conservative_input_data,
    conservative_sample_grid,
    cdo_comparison_data,
    chunks,
):
-    ds_regrid = conservative_input_data.regrid.conservative(
    input_data = conservative_input_data.chunk(chunks)
    ds_regrid = input_data.regrid.conservative(
        conservative_sample_grid, latitude_coord="latitude"
    )
    ds_cdo = cdo_comparison_data["conservative"]
@@ -201,7 +210,7 @@ def test_conservative_nan_thresholds_against_coarsen(nan_threshold):

@pytest.mark.skipif(xesmf is None, reason="xesmf required")
def test_conservative_nan_thresholds_against_xesmf():
-    ds = xr.tutorial.open_dataset("ersstv5").sst.compute().isel(time=[0])
    ds = xr.tutorial.open_dataset("ersstv5").sst.isel(time=[0]).compute()
    ds = ds.rename(lon="longitude", lat="latitude")
    new_grid = xarray_regrid.Grid(
        north=90,