Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spherical padding and faster tests #45

Merged
merged 8 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/xarray_regrid/methods/conservative.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,13 @@ def conservative_regrid_dataset(
weights = apply_spherical_correction(weights, latitude_coord)

for array in data_vars.keys():
non_grid_dims = [d for d in data_vars[array].dims if d not in coords]
if coord in data_vars[array].dims:
data_vars[array], valid_fracs[array] = apply_weights(
da=data_vars[array],
weights=weights,
coord=coord,
valid_frac=valid_fracs[array],
skipna=skipna,
non_grid_dims=non_grid_dims,
)
# Mask out any regridded points outside the original domain
data_vars[array] = data_vars[array].where(covered_grid)
Expand Down Expand Up @@ -161,16 +159,13 @@ def apply_weights(
coord: Hashable,
valid_frac: xr.DataArray,
skipna: bool,
non_grid_dims: list[Hashable],
) -> tuple[xr.DataArray, xr.DataArray]:
"""Apply the weights to convert data to the new coordinates."""
coord_map = {f"target_{coord}": coord}
weights_norm = weights.copy()

if skipna:
notnull = da.notnull()
if non_grid_dims:
slevang marked this conversation as resolved.
Show resolved Hide resolved
notnull = notnull.any(non_grid_dims)
# Renormalize the weights along this dim by the accumulated valid_frac
# along previous dimensions
if valid_frac.name != EMPTY_DA_NAME:
Expand Down
21 changes: 15 additions & 6 deletions src/xarray_regrid/regrid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import xarray as xr

from xarray_regrid.methods import conservative, interp, most_common
from xarray_regrid.utils import format_for_regrid


@xr.register_dataarray_accessor("regrid")
Expand Down Expand Up @@ -34,7 +35,8 @@ def linear(
Data regridded to the target dataset coordinates.
"""
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
return interp.interp_regrid(self._obj, ds_target_grid, "linear")
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return interp.interp_regrid(ds_formatted, ds_target_grid, "linear")

def nearest(
self,
Expand All @@ -51,14 +53,14 @@ def nearest(
Data regridded to the target dataset coordinates.
"""
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
return interp.interp_regrid(self._obj, ds_target_grid, "nearest")
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return interp.interp_regrid(ds_formatted, ds_target_grid, "nearest")

def cubic(
self,
ds_target_grid: xr.Dataset,
time_dim: str = "time",
) -> xr.DataArray | xr.Dataset:
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
"""Regrid to the coords of the target dataset with cubic interpolation.

Args:
Expand All @@ -68,7 +70,9 @@ def cubic(
Returns:
Data regridded to the target dataset coordinates.
"""
return interp.interp_regrid(self._obj, ds_target_grid, "cubic")
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return interp.interp_regrid(ds_formatted, ds_target_grid, "cubic")

def conservative(
self,
Expand All @@ -88,6 +92,9 @@ def conservative(
time_dim: The name of the time dimension/coordinate.
skipna: If True, enable handling for NaN values. This adds some overhead,
so can be disabled for optimal performance on data without any NaNs.
With `skipna=True`, chunking is recommended in the non-grid dimensions,
otherwise the intermediate arrays that track the fraction of valid data
can become very large and consume excessive memory.
Warning: with `skipna=False`, isolated NaNs will propagate throughout
the dataset due to the sequential regridding scheme over each dimension.
nan_threshold: Threshold value that will retain any output points
Expand All @@ -104,8 +111,9 @@ def conservative(
raise ValueError(msg)

ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return conservative.conservative_regrid(
self._obj, ds_target_grid, latitude_coord, skipna, nan_threshold
ds_formatted, ds_target_grid, latitude_coord, skipna, nan_threshold
)

def most_common(
Expand Down Expand Up @@ -134,8 +142,9 @@ def most_common(
Regridded data.
"""
ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
ds_formatted = format_for_regrid(self._obj, ds_target_grid)
return most_common.most_common_wrapper(
self._obj, ds_target_grid, time_dim, max_mem
ds_formatted, ds_target_grid, time_dim, max_mem
)


Expand Down
159 changes: 156 additions & 3 deletions src/xarray_regrid/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Callable
from collections.abc import Callable, Hashable
from dataclasses import dataclass
from typing import Any, overload
from typing import Any, TypedDict, overload

import numpy as np
import pandas as pd
Expand All @@ -10,6 +10,11 @@
class InvalidBoundsError(Exception): ...


class CoordHandler(TypedDict):
    """Accepted coordinate names and the formatting routine for one coordinate type."""

    # Recognized (lowercased) coordinate names, e.g. ["lat", "latitude"].
    names: list[str]
    # Formatting function called as func(obj, target, formatted_coords) -> obj.
    func: Callable


@dataclass
class Grid:
"""Object storing grid information."""
Expand Down Expand Up @@ -75,7 +80,7 @@ def create_lat_lon_coords(grid: Grid) -> tuple[np.ndarray, np.ndarray]:
grid.south, grid.north + grid.resolution_lat, grid.resolution_lat
)

if np.remainder((grid.north - grid.south), grid.resolution_lat) > 0:
if np.remainder((grid.east - grid.west), grid.resolution_lat) > 0:
lon_coords = np.arange(grid.west, grid.east, grid.resolution_lon)
slevang marked this conversation as resolved.
Show resolved Hide resolved
else:
lon_coords = np.arange(
Expand Down Expand Up @@ -235,3 +240,151 @@ def call_on_dataset(
return next(iter(result.data_vars.values())).rename(obj.name)

return result


def format_for_regrid(
    obj: xr.DataArray | xr.Dataset, target: xr.Dataset
) -> xr.DataArray | xr.Dataset:
    """Pre-format the input object so it is ready for regridding.

    Currently handles padding of spherical geometry if lat/lon coordinates can
    be inferred and the domain size requires boundary padding.
    """
    original_chunks = obj.chunksizes

    # Special-cased coordinate types with their accepted names and formatters.
    handlers: dict[str, CoordHandler] = {
        "lat": {"names": ["lat", "latitude"], "func": format_lat},
        "lon": {"names": ["lon", "longitude"], "func": format_lon},
    }

    # Map each recognized coordinate type to the matching coord name on obj.
    formatted_coords = {
        coord_type: str(coord)
        for coord_type, handler in handlers.items()
        for coord in obj.coords.keys()
        if str(coord).lower() in handler["names"]
    }

    for coord_type, coord in formatted_coords.items():
        # The formatting routines assume sorted coords on source and target.
        obj = obj.sortby(coord)
        target = target.sortby(coord)
        obj = handlers[coord_type]["func"](obj, target, formatted_coords)
        # If this dim arrived as a single chunk, coerce it back to one chunk.
        if len(original_chunks.get(coord, [])) == 1:
            obj = obj.chunk({coord: -1})

    return obj


def format_lat(
    obj: xr.DataArray | xr.Dataset,
    target: xr.Dataset,  # noqa ARG001
    formatted_coords: dict[str, str],
) -> xr.DataArray | xr.Dataset:
    """Pad a global latitude coordinate with synthetic pole values.

    If the latitude coordinate is inferred to be global, defined as having
    a value within one grid spacing of the poles, and the grid does not natively
    have values at -90 and 90, add a single value at each pole computed as the
    mean of the first and last latitude bands. This should be roughly equivalent
    to the `Pole="all"` option in `ESMF`.

    For example, with a grid spacing of 1 degree, and a source grid ranging from
    -89.5 to 89.5, the poles would be padded with values at -90 and 90. A grid ranging
    from -88 to 88 would not be padded because coverage does not extend all the way
    to the poles. A grid ranging from -90 to 90 would also not be padded because the
    poles will already be covered in the regridding weights.

    Args:
        obj: Source data whose latitude coordinate may be padded.
        target: Target grid; unused here, kept for a uniform handler signature.
        formatted_coords: Mapping of coord type ("lat"/"lon") to coord name.

    Returns:
        The input object, possibly with one extra latitude value at each pole.
    """
    lat_coord = formatted_coords["lat"]
    lon_coord = formatted_coords.get("lon")

    # Concat a padded value representing the mean of the first/last lat bands
    # This should match the Pole="all" option of ESMF
    # TODO: with cos(90) = 0 weighting, these weights might be 0?

    polar_lat = 90
    # Largest spacing between adjacent latitude values, in degrees.
    dy = obj.coords[lat_coord].diff(lat_coord).max().values.item()

    # Only pad if global but don't have edge values directly at poles
    # NOTE: could use xr.pad here instead of xr.concat, but none of the
    # modes are an exact fit for this scheme
    # South pole: first value lies in (-90, -90 + dy].
    if dy - polar_lat >= obj.coords[lat_coord].values[0] > -polar_lat:
        south_pole = obj.isel({lat_coord: 0})
        # Use the zonal mean of the first band as the single pole value.
        if lon_coord is not None:
            south_pole = south_pole.mean(lon_coord)
        obj = xr.concat([south_pole, obj], dim=lat_coord)  # type: ignore
        # NOTE(review): in-place edit assumes concat produced fresh coord
        # arrays not shared with the caller's object — confirm.
        obj.coords[lat_coord].values[0] = -polar_lat

    # North pole: last value lies in [90 - dy, 90).
    if polar_lat - dy <= obj.coords[lat_coord].values[-1] < polar_lat:
        north_pole = obj.isel({lat_coord: -1})
        if lon_coord is not None:
            north_pole = north_pole.mean(lon_coord)
        obj = xr.concat([obj, north_pole], dim=lat_coord)  # type: ignore
        obj.coords[lat_coord].values[-1] = polar_lat

    return obj


def format_lon(
    obj: xr.DataArray | xr.Dataset, target: xr.Dataset, formatted_coords: dict[str, str]
) -> xr.DataArray | xr.Dataset:
    """Align the source longitude coordinate with the target and pad wraparound.

    Format the longitude coordinate by shifting the source grid to line up with
    the target anywhere in the range of -360 to 360, and then add a single wraparound
    padding column if the domain is inferred to be global and the east or west edges
    of the target lie outside the source grid centers.

    For example, with a source grid ranging from 0.5 to 359.5 and a target grid ranging
    from -180 to 180, the source grid would be shifted to -179.5 to 179.5 and then
    padded on both the left and right with wraparound values at -180.5 and 180.5 to
    provide full coverage for the target edge cells at -180 and 180.

    Args:
        obj: Source data whose longitude coordinate is shifted/padded.
        target: Target grid providing the reference longitude range.
        formatted_coords: Mapping of coord type ("lat"/"lon") to coord name.

    Returns:
        The input object with longitudes shifted onto the target's range and,
        for global domains, wraparound columns added at the edges.
    """
    lon_coord = formatted_coords["lon"]

    # Find a wrap point outside of the left and right bounds of the target
    # This ensures we have coverage on the target and handles global > regional
    source_vals = obj.coords[lon_coord].values
    target_vals = target.coords[lon_coord].values
    # Midpoint of the target range plus 180: values are rotated into
    # (wrap_point - 360, wrap_point].
    wrap_point = (target_vals[-1] + target_vals[0] + 360) / 2
    source_vals = np.where(
        source_vals < wrap_point - 360, source_vals + 360, source_vals
    )
    source_vals = np.where(source_vals > wrap_point, source_vals - 360, source_vals)
    # NOTE(review): writes through to the existing coord array in place, which
    # may mutate the caller's object — confirm this is intended.
    obj.coords[lon_coord].values[:] = source_vals

    # Shift operations can produce duplicates
    # Simplest solution is to drop them and add back when padding
    obj = obj.sortby(lon_coord).drop_duplicates(lon_coord)

    # Only pad if domain is global in lon
    source_lon = obj.coords[lon_coord]
    target_lon = target.coords[lon_coord]
    # Largest grid spacings of source and target, in degrees.
    dx_s = source_lon.diff(lon_coord).max().values.item()
    dx_t = target_lon.diff(lon_coord).max().values.item()
    # Global means the source span is within one source grid cell of 360.
    is_global_lon = source_lon.max().values - source_lon.min().values >= 360 - dx_s

    if is_global_lon:
        # Number of columns needed so source centers extend half a target cell
        # beyond the target's first/last centers on each side (clamped at 0).
        left_pad = (source_lon.values[0] - target_lon.values[0] + dx_t / 2) / dx_s
        right_pad = (target_lon.values[-1] - source_lon.values[-1] + dx_t / 2) / dx_s
        left_pad = int(np.ceil(np.max([left_pad, 0])))
        right_pad = int(np.ceil(np.max([right_pad, 0])))
        obj = obj.pad({lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True)
        # Wrap-padded columns carry data from the opposite edge; relabel their
        # coordinates by +/-360 so the axis stays monotonic.
        if left_pad:
            obj.coords[lon_coord].values[:left_pad] = (
                source_lon.values[-left_pad:] - 360
            )
        if right_pad:
            obj.coords[lon_coord].values[-right_pad:] = (
                source_lon.values[:right_pad] + 360
            )

    return obj


def coord_is_covered(
    obj: xr.DataArray | xr.Dataset, target: xr.Dataset, coord: Hashable
) -> bool:
    """Return True when the source coord fully covers the target coord.

    Coverage requires the source range to extend at least one target grid
    spacing beyond both the minimum and maximum of the target coordinate.
    """
    spacing = target[coord].diff(coord).max().values
    covers_left = (obj[coord].min() <= target[coord].min() - spacing).item()
    covers_right = (obj[coord].max() >= target[coord].max() + spacing).item()
    return bool(covers_left and covers_right)
Loading
Loading