Spherical padding and faster tests #45

Merged · 8 commits · Sep 20, 2024
Changes from 1 commit
7 changes: 1 addition & 6 deletions src/xarray_regrid/methods/conservative.py

@@ -66,7 +66,7 @@ def conservative_regrid(
     # Attempt to infer the latitude coordinate
     if latitude_coord is None:
         for coord in data.coords:
-            if str(coord).lower() in ["lat", "latitude"]:
+            if str(coord).lower().startswith("lat"):
                 latitude_coord = coord
                 break
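A quick sketch of what the relaxed check matches (the coordinate names below are hypothetical):

```python
def old_match(name: str) -> bool:
    return name.lower() in ["lat", "latitude"]

def new_match(name: str) -> bool:
    return name.lower().startswith("lat")

for name in ["lat", "latitude", "lats", "Latitude"]:
    print(name, old_match(name), new_match(name))
# "lats" matches only the new check; "Latitude" matches both after lowering.
```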

@@ -122,15 +122,13 @@ def conservative_regrid_dataset(
         weights = apply_spherical_correction(weights, latitude_coord)

         for array in data_vars.keys():
-            non_grid_dims = [d for d in data_vars[array].dims if d not in coords]
             if coord in data_vars[array].dims:
                 data_vars[array], valid_fracs[array] = apply_weights(
                     da=data_vars[array],
                     weights=weights,
                     coord=coord,
                     valid_frac=valid_fracs[array],
                     skipna=skipna,
-                    non_grid_dims=non_grid_dims,
                 )
                 # Mask out any regridded points outside the original domain
                 data_vars[array] = data_vars[array].where(covered_grid)
@@ -161,16 +159,13 @@ def apply_weights(
     coord: Hashable,
     valid_frac: xr.DataArray,
     skipna: bool,
-    non_grid_dims: list[Hashable],
 ) -> tuple[xr.DataArray, xr.DataArray]:
     """Apply the weights to convert data to the new coordinates."""
     coord_map = {f"target_{coord}": coord}
     weights_norm = weights.copy()

     if skipna:
         notnull = da.notnull()
-        if non_grid_dims:
-            notnull = notnull.any(non_grid_dims)
         # Renormalize the weights along this dim by the accumulated valid_frac
         # along previous dimensions
         if valid_frac.name != EMPTY_DA_NAME:
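For context, a minimal toy sketch of the skipna renormalization idea this hunk touches (not the package's internals): weights on NaN cells are zeroed, and the remainder is renormalized by the fraction of each target cell covered by valid data.

```python
import numpy as np

# Toy 1D conservative regrid: weights maps 4 source cells onto 2 target cells.
weights = np.array([
    [0.5, 0.5, 0.0, 0.0],
    [0.0, 0.0, 0.5, 0.5],
])
data = np.array([1.0, np.nan, 3.0, 5.0])

valid = ~np.isnan(data)
valid_frac = weights @ valid  # fraction of each target cell with valid data
# Zero out weights on NaNs, then renormalize by the valid fraction (skipna=True)
w = weights * valid
result = (w @ np.nan_to_num(data)) / np.where(valid_frac > 0, valid_frac, np.nan)
print(result)  # [1. 4.]
```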
21 changes: 15 additions & 6 deletions src/xarray_regrid/regrid.py

@@ -1,6 +1,7 @@
 import xarray as xr

 from xarray_regrid.methods import conservative, interp, most_common
+from xarray_regrid.utils import format_for_regrid


 @xr.register_dataarray_accessor("regrid")
@@ -34,7 +35,8 @@ def linear(
             Data regridded to the target dataset coordinates.
         """
         ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
-        return interp.interp_regrid(self._obj, ds_target_grid, "linear")
+        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
+        return interp.interp_regrid(ds_formatted, ds_target_grid, "linear")

     def nearest(
         self,
@@ -51,14 +53,14 @@ def nearest(
             Data regridded to the target dataset coordinates.
         """
         ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
-        return interp.interp_regrid(self._obj, ds_target_grid, "nearest")
+        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
+        return interp.interp_regrid(ds_formatted, ds_target_grid, "nearest")

     def cubic(
         self,
         ds_target_grid: xr.Dataset,
         time_dim: str = "time",
     ) -> xr.DataArray | xr.Dataset:
-        ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
         """Regrid to the coords of the target dataset with cubic interpolation.

         Args:
@@ -68,7 +70,9 @@
         Returns:
             Data regridded to the target dataset coordinates.
         """
+        ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
+        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
+        return interp.interp_regrid(ds_formatted, ds_target_grid, "cubic")

     def conservative(
         self,
@@ -88,6 +92,9 @@
             time_dim: The name of the time dimension/coordinate.
             skipna: If True, enable handling for NaN values. This adds some overhead,
                 so can be disabled for optimal performance on data without any NaNs.
+                With `skipna=True`, chunking is recommended in the non-grid dimensions;
+                otherwise the intermediate arrays that track the fraction of valid data
+                can become very large and consume excessive memory.
                 Warning: with `skipna=False`, isolated NaNs will propagate throughout
                 the dataset due to the sequential regridding scheme over each dimension.
             nan_threshold: Threshold value that will retain any output points
@@ -104,8 +111,9 @@
             raise ValueError(msg)

         ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
+        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
         return conservative.conservative_regrid(
-            self._obj, ds_target_grid, latitude_coord, skipna, nan_threshold
+            ds_formatted, ds_target_grid, latitude_coord, skipna, nan_threshold
         )
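A hedged sketch of the chunking recommendation added to the docstring above. The dataset, sizes, and chunk size are hypothetical, and dask must be installed for `.chunk` to take effect.

```python
import numpy as np
import xarray as xr
import xarray_regrid  # noqa: F401, registers the .regrid accessor

ds = xr.Dataset(
    {"tas": (("time", "lat", "lon"), np.random.rand(365, 90, 180))},
    coords={
        "time": np.arange(365),
        "lat": np.linspace(-89, 89, 90),
        "lon": np.linspace(0, 358, 180),
    },
)
target = xr.Dataset(
    coords={"lat": np.linspace(-88, 88, 45), "lon": np.linspace(0, 356, 90)}
)
# Chunk the non-grid (time) dimension so the intermediate valid-fraction
# arrays stay small, per the docstring recommendation.
result = ds.chunk({"time": 50}).regrid.conservative(target, skipna=True)
```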

     def most_common(
@@ -134,8 +142,9 @@ def most_common(
             Regridded data.
         """
         ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
+        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
         return most_common.most_common_wrapper(
-            self._obj, ds_target_grid, time_dim, max_mem
+            ds_formatted, ds_target_grid, time_dim, max_mem
        )
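Since every accessor method now routes through `format_for_regrid`, a global source grid on 0..358 can be regridded to a target spanning the antimeridian without manual shifting. A hypothetical example:

```python
import numpy as np
import xarray as xr
import xarray_regrid  # noqa: F401, registers the .regrid accessor

ds = xr.Dataset(
    {"tas": (("lat", "lon"), np.random.rand(90, 180))},
    coords={"lat": np.linspace(-89, 89, 90), "lon": np.linspace(0, 358, 180)},
)
target = xr.Dataset(
    coords={"lat": np.linspace(-88, 88, 45), "lon": np.linspace(-180, 178, 90)}
)
out = ds.regrid.linear(target)  # longitudes are shifted/padded internally
```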
128 changes: 126 additions & 2 deletions src/xarray_regrid/utils.py

@@ -1,4 +1,4 @@
-from collections.abc import Callable
+from collections.abc import Callable, Hashable
 from dataclasses import dataclass
 from typing import Any, overload

@@ -75,7 +75,7 @@ def create_lat_lon_coords(grid: Grid) -> tuple[np.ndarray, np.ndarray]:
         grid.south, grid.north + grid.resolution_lat, grid.resolution_lat
     )

-    if np.remainder((grid.north - grid.south), grid.resolution_lat) > 0:
+    if np.remainder((grid.east - grid.west), grid.resolution_lat) > 0:
         lon_coords = np.arange(grid.west, grid.east, grid.resolution_lon)
     else:
         lon_coords = np.arange(
@@ -235,3 +235,127 @@ def call_on_dataset(
         return next(iter(result.data_vars.values())).rename(obj.name)

     return result
+
+
+def format_for_regrid(
+    obj: xr.DataArray | xr.Dataset, target: xr.Dataset
+) -> xr.DataArray | xr.Dataset:
+    """Apply any pre-formatting to the input dataset to prepare for regridding.
+
+    Currently handles padding of spherical geometry, if coordinate names
+    starting with 'lat' and 'lon' can be inferred.
+    """
+    lat_coord = None
+    lon_coord = None
+
+    for coord in obj.coords.keys():
+        if str(coord).lower().startswith("lat"):
+            lat_coord = coord
+        elif str(coord).lower().startswith("lon"):
+            lon_coord = coord
+
+    if lon_coord is not None or lat_coord is not None:
+        obj = format_spherical(obj, target, lat_coord, lon_coord)
+
+    return obj
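A minimal usage sketch of this helper; the grids are hypothetical. For a global source whose edge rows sit equatorward of the poles, pole rows are added at exactly -90 and 90.

```python
import numpy as np
import xarray as xr
from xarray_regrid.utils import format_for_regrid

ds = xr.Dataset(
    {"tas": (("lat", "lon"), np.random.rand(90, 180))},
    coords={"lat": np.linspace(-89, 89, 90), "lon": np.linspace(0, 358, 180)},
)
target = xr.Dataset(
    coords={"lat": np.linspace(-90, 90, 91), "lon": np.linspace(-180, 178, 180)}
)
padded = format_for_regrid(ds, target)
print(padded.lat.values[[0, -1]])  # [-90.  90.], pole rows were added
```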


+def format_spherical(
+    obj: xr.DataArray | xr.Dataset,
+    target: xr.Dataset,
+    lat_coord: Hashable,
+    lon_coord: Hashable,
+) -> xr.DataArray | xr.Dataset:
+    """Infer whether a lat/lon source grid represents a global domain and,
+    if so, automatically apply spherical padding to reduce edge effects.
+
+    For longitude, shift the coordinate to line up with the target values, then
+    add a single wraparound padding column if the domain is global and the east
+    or west edges of the target lie outside the source grid centers.
+
+    For latitude, add a single value at each pole, computed as the mean of the
+    last row, for global source grids whose first or last point lies equatorward
+    of 90 degrees.
+    """

+    orig_chunksizes = obj.chunksizes
+
+    # If the source coord fully covers the target, don't modify it
+    if lat_coord and not coord_is_covered(obj, target, lat_coord):
+        obj = obj.sortby(lat_coord)
+        target = target.sortby(lat_coord)
+
+        # Only pad if the domain is global but lacks edge values directly at the poles
+        polar_lat = 90
+        dy = obj[lat_coord].diff(lat_coord).max().values
+
+        # South pole
+        if dy - polar_lat >= obj[lat_coord][0] > -polar_lat:
+            south_pole = obj.isel({lat_coord: 0})
+            # This should match the Pole="all" option of ESMF
+            if lon_coord is not None:
+                south_pole = south_pole.mean(lon_coord)
+            obj = xr.concat([south_pole, obj], dim=lat_coord)
+            obj[lat_coord].values[0] = -polar_lat
+
+        # North pole
+        if polar_lat - dy <= obj[lat_coord][-1] < polar_lat:
+            north_pole = obj.isel({lat_coord: -1})
+            if lon_coord is not None:
+                north_pole = north_pole.mean(lon_coord)
+            obj = xr.concat([obj, north_pole], dim=lat_coord)
+            obj[lat_coord].values[-1] = polar_lat

+        # Coerce back to a single chunk if that's what was passed
+        if len(orig_chunksizes.get(lat_coord, [])) == 1:
+            obj = obj.chunk({lat_coord: -1})

+    if lon_coord and not coord_is_covered(obj, target, lon_coord):
+        obj = obj.sortby(lon_coord)
+        target = target.sortby(lon_coord)
+
+        target_lon = target[lon_coord].values
+        # Find a wrap point outside of the left and right bounds of the target.
+        # This ensures we have coverage on the target and handles global > regional.
+        wrap_point = (target_lon[-1] + target_lon[0] + 360) / 2
+        lon = obj[lon_coord].values
+        lon = np.where(lon < wrap_point - 360, lon + 360, lon)
+        lon = np.where(lon > wrap_point, lon - 360, lon)
+        obj[lon_coord].values[:] = lon
+
+        # Shift operations can produce duplicates; the simplest solution is to
+        # drop them and add them back when padding
+        obj = obj.sortby(lon_coord).drop_duplicates(lon_coord)
+
+        # Only pad if the domain is global in lon
+        dx_s = obj[lon_coord].diff(lon_coord).max().values
+        dx_t = target[lon_coord].diff(lon_coord).max().values
+        is_global_lon = obj[lon_coord].max() - obj[lon_coord].min() >= 360 - dx_s
+
+        if is_global_lon:
+            left_pad = (obj[lon_coord][0] - target[lon_coord][0] + dx_t / 2) / dx_s
+            right_pad = (target[lon_coord][-1] - obj[lon_coord][-1] + dx_t / 2) / dx_s
+            left_pad = int(np.ceil(np.max([left_pad, 0])))
+            right_pad = int(np.ceil(np.max([right_pad, 0])))
+            lon = obj[lon_coord].values
+            obj = obj.pad(
+                {lon_coord: (left_pad, right_pad)}, mode="wrap", keep_attrs=True
+            )
+            if left_pad:
+                obj[lon_coord].values[:left_pad] = lon[-left_pad:] - 360
+            if right_pad:
+                obj[lon_coord].values[-right_pad:] = lon[:right_pad] + 360
+
+        # Coerce back to a single chunk if that's what was passed
+        if len(orig_chunksizes.get(lon_coord, [])) == 1:
+            obj = obj.chunk({lon_coord: -1})
+
+    return obj
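A small numeric sketch of the wrap-point arithmetic above, with hypothetical grids: a source on 0..358 is remapped into the window of a target on -180..178.

```python
import numpy as np

target_lon = np.linspace(-180, 178, 180)  # hypothetical target longitudes
lon = np.arange(0.0, 360.0, 2.0)          # hypothetical source on 0..358

wrap_point = (target_lon[-1] + target_lon[0] + 360) / 2  # 179.0
lon = np.where(lon < wrap_point - 360, lon + 360, lon)
lon = np.where(lon > wrap_point, lon - 360, lon)
print(lon.min(), lon.max())  # -180.0 178.0, aligned with the target window
```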


+def coord_is_covered(
+    obj: xr.DataArray | xr.Dataset, target: xr.Dataset, coord: Hashable
+) -> bool:
+    """Check if the source coord fully covers the target coord."""
+    pad = target[coord].diff(coord).max().values
+    left_covered = obj[coord].min() <= target[coord].min() - pad
+    right_covered = obj[coord].max() >= target[coord].max() + pad
+    return bool(left_covered.item() and right_covered.item())
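A quick check of the coverage predicate with hypothetical coordinates:

```python
import numpy as np
import xarray as xr
from xarray_regrid.utils import coord_is_covered

src = xr.Dataset(coords={"lon": np.arange(-185.0, 186.0, 1.0)})
tgt = xr.Dataset(coords={"lon": np.arange(-180.0, 181.0, 2.0)})
# The source extends one target step (2 degrees) beyond both target edges,
# so no shifting or padding is needed.
print(coord_is_covered(src, tgt, "lon"))  # True
```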