Skip to content

Commit

Permalink
implement auto region and transpose
Browse files Browse the repository at this point in the history
  • Loading branch information
slevang committed Nov 9, 2023
1 parent feba698 commit 5b62b7c
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 10 deletions.
46 changes: 43 additions & 3 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
_normalize_path,
)
from xarray.backends.locks import _get_scheduler
from xarray.backends.zarr import open_zarr
from xarray.core import indexing
from xarray.core.combine import (
_infer_concat_order_from_positions,
Expand Down Expand Up @@ -1446,6 +1447,44 @@ def save_mfdataset(
)


def _auto_detect_region(ds_new, ds_orig, dim):
    """Find the positional slice along ``dim`` that ``ds_new`` covers in ``ds_orig``.

    Parameters
    ----------
    ds_new
        Data about to be written; ``ds_new[dim]`` supplies the coordinate
        labels to locate.
    ds_orig
        Data already in the store; ``ds_orig[dim]`` supplies the full
        coordinate the labels are looked up in.
    dim
        Name of the coordinate to resolve.

    Returns
    -------
    slice
        ``slice(start, stop)`` of integer positions in ``ds_orig[dim]``
        spanning the labels of ``ds_new[dim]``.

    Raises
    ------
    KeyError
        If some labels of ``ds_new[dim]`` are not present in ``ds_orig[dim]``.
    ValueError
        If ``ds_new[dim]`` is empty, or if the matched positions do not
        increase by exactly one from each label to the next.
    """
    # Create a mapping array of coordinates to indices on the original array
    coord = ds_orig[dim]
    da_map = DataArray(np.arange(coord.size), coords={dim: coord})

    try:
        da_idxs = da_map.sel({dim: ds_new[dim]})
    except KeyError as e:
        if "not all values found" in str(e):
            # Chain explicitly so the underlying lookup failure stays visible
            # as the direct cause of the friendlier message.
            raise KeyError(
                f"Not all values of coordinate '{dim}' in the new array were"
                " found in the original store. Writing to a zarr region slice"
                " requires that no dimensions or metadata are changed by the write."
            ) from e
        raise

    # Without this guard an empty selection would pass the contiguity check
    # below and then fail with an opaque IndexError when reading the bounds.
    if da_idxs.size == 0:
        raise ValueError(
            f"The coordinate '{dim}' of the new array is empty, so no region"
            " could be auto-detected for it."
        )

    if (da_idxs.diff(dim) != 1).any():
        raise ValueError(
            f"The auto-detected region of coordinate '{dim}' for writing new data"
            " to the original store had non-contiguous indices. Writing to a zarr"
            " region slice requires that the new data constitute a contiguous subset"
            " of the original store."
        )

    positions = da_idxs.values
    # np.arange yields NumPy integers; convert so the slice holds plain ints.
    return slice(int(positions[0]), int(positions[-1]) + 1)


def _auto_detect_regions(ds, region, store):
ds_original = open_zarr(store)
for key, val in region.items():
if val == "auto":
region[key] = _auto_detect_region(ds, ds_original, key)
return region


def _validate_region(ds, region):
if not isinstance(region, dict):
raise TypeError(f"``region`` must be a dict, got {type(region)}")
Expand Down Expand Up @@ -1532,7 +1571,7 @@ def to_zarr(
compute: Literal[True] = True,
consolidated: bool | None = None,
append_dim: Hashable | None = None,
region: Mapping[str, slice] | None = None,
region: Mapping[str, slice | Literal["auto"]] | None = None,
safe_chunks: bool = True,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
Expand All @@ -1556,7 +1595,7 @@ def to_zarr(
compute: Literal[False],
consolidated: bool | None = None,
append_dim: Hashable | None = None,
region: Mapping[str, slice] | None = None,
region: Mapping[str, slice | Literal["auto"]] | None = None,
safe_chunks: bool = True,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
Expand All @@ -1578,7 +1617,7 @@ def to_zarr(
compute: bool = True,
consolidated: bool | None = None,
append_dim: Hashable | None = None,
region: Mapping[str, slice] | None = None,
region: Mapping[str, slice | Literal["auto"]] | None = None,
safe_chunks: bool = True,
storage_options: dict[str, str] | None = None,
zarr_version: int | None = None,
Expand Down Expand Up @@ -1643,6 +1682,7 @@ def to_zarr(
_validate_dataset_names(dataset)

if region is not None:
region = _auto_detect_regions(dataset, region, store)
_validate_region(dataset, region)
if append_dim is not None and append_dim in region:
raise ValueError(
Expand Down
19 changes: 12 additions & 7 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,15 @@ def encode_zarr_variable(var, needs_copy=True, name=None):

def _validate_existing_dims(var_name, new_var, existing_var, region, append_dim):
if new_var.dims != existing_var.dims:
raise ValueError(
f"variable {var_name!r} already exists with different "
f"dimension names {existing_var.dims} != "
f"{new_var.dims}, but changing variable "
f"dimensions is not supported by to_zarr()."
)
if set(existing_var.dims) == set(new_var.dims):
new_var = new_var.transpose(*existing_var.dims)
else:
raise ValueError(
f"variable {var_name!r} already exists with different "
f"dimension names {existing_var.dims} != "
f"{new_var.dims}, but changing variable "
f"dimensions is not supported by to_zarr()."
)

existing_sizes = {}
for dim, size in existing_var.sizes.items():
Expand All @@ -347,6 +350,8 @@ def _validate_existing_dims(var_name, new_var, existing_var, region, append_dim)
f"explicitly appending, but append_dim={append_dim!r}."
)

return new_var


def _put_attrs(zarr_obj, attrs):
"""Raise a more informative error message for invalid attrs."""
Expand Down Expand Up @@ -616,7 +621,7 @@ def store(
for var_name in existing_variable_names:
new_var = variables_encoded[var_name]
existing_var = existing_vars[var_name]
_validate_existing_dims(
new_var = _validate_existing_dims(
var_name,
new_var,
existing_var,
Expand Down
99 changes: 99 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5431,3 +5431,102 @@ def test_raise_writing_to_nczarr(self, mode) -> None:
def test_pickle_open_mfdataset_dataset():
ds = open_example_mfdataset(["bears.nc"])
assert_identical(ds, pickle.loads(pickle.dumps(ds)))


@requires_zarr
class TestZarrRegionAuto:
    """Exercise ``to_zarr`` writes where the region is given as ``"auto"``."""

    @staticmethod
    def _grid_dataset(x_values):
        """Build a 5x10 dataset of ones named "test" on the given x labels."""
        y_values = np.arange(0, 20, 2)
        ones = np.ones((5, 10))
        variable = xr.DataArray(
            ones,
            dims=("x", "y"),
            coords={"x": x_values, "y": y_values},
        )
        return xr.Dataset({"test": variable})

    def test_zarr_region_auto_success(self, tmp_path):
        """A contiguous sub-block is written to the matching store positions."""
        store_path = tmp_path / "test.zarr"
        base = self._grid_dataset(np.arange(0, 50, 10))
        base.to_zarr(store_path)

        patch = 1 + base.isel(x=slice(2, 4), y=slice(6, 8))
        patch.to_zarr(store_path, region={"x": "auto", "y": "auto"})

        reopened = xr.open_zarr(store_path)

        expected = base.copy()
        expected["test"][2:4, 6:8] += 1
        assert_identical(reopened, expected)

    def test_zarr_region_auto_noncontiguous(self, tmp_path):
        """Selecting x positions 0, 2, 3 is rejected with ValueError."""
        store_path = tmp_path / "test.zarr"
        base = self._grid_dataset(np.arange(0, 50, 10))
        base.to_zarr(store_path)

        patch = 1 + base.isel(x=[0, 2, 3], y=[5, 6])
        with pytest.raises(ValueError):
            patch.to_zarr(store_path, region={"x": "auto", "y": "auto"})

    def test_zarr_region_auto_new_coord_vals(self, tmp_path):
        """x labels absent from the store are rejected with KeyError."""
        store_path = tmp_path / "test.zarr"
        self._grid_dataset(np.arange(0, 50, 10)).to_zarr(store_path)

        # Same grid, but x is shifted by 5 so none of its labels exist on disk.
        shifted = self._grid_dataset(np.arange(5, 55, 10))

        patch = 1 + shifted.isel(x=slice(2, 4), y=slice(6, 8))
        with pytest.raises(KeyError):
            patch.to_zarr(store_path, region={"x": "auto", "y": "auto"})


@requires_zarr
def test_zarr_region_transpose(tmp_path):
    """Writing a region whose dims are transposed relative to the store must not raise."""
    x = np.arange(0, 50, 10)
    y = np.arange(0, 20, 2)
    data = np.ones((5, 10))
    ds = xr.Dataset(
        {
            "test": xr.DataArray(
                data,
                dims=("x", "y"),
                coords={"x": x, "y": y},
            )
        }
    )
    ds.to_zarr(tmp_path / "test.zarr")

    # The store was written with dims ("x", "y"); transpose() reverses them.
    ds_region = 1 + ds.isel(x=[0], y=[0]).transpose()
    ds_region.to_zarr(
        tmp_path / "test.zarr", region={"x": slice(0, 1), "y": slice(0, 1)}
    )

0 comments on commit 5b62b7c

Please sign in to comment.