From 5b62b7c42724a85d038d2eea64a9782e984abb83 Mon Sep 17 00:00:00 2001
From: Sam Levang
Date: Wed, 8 Nov 2023 23:41:43 -0500
Subject: [PATCH] implement auto region and transpose

---
 xarray/backends/api.py        | 46 ++++++++++++++--
 xarray/backends/zarr.py       | 19 ++++---
 xarray/tests/test_backends.py | 99 +++++++++++++++++++++++++++++++++++
 3 files changed, 154 insertions(+), 10 deletions(-)

diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 27e155872de..4307c807cff 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -27,6 +27,7 @@
     _normalize_path,
 )
 from xarray.backends.locks import _get_scheduler
+from xarray.backends.zarr import open_zarr
 from xarray.core import indexing
 from xarray.core.combine import (
     _infer_concat_order_from_positions,
@@ -1446,6 +1447,44 @@ def save_mfdataset(
     )
 
 
+def _auto_detect_region(ds_new, ds_orig, dim):
+    # Create a mapping array of coordinates to indices on the original array
+    coord = ds_orig[dim]
+    da_map = DataArray(np.arange(coord.size), coords={dim: coord})
+
+    try:
+        da_idxs = da_map.sel({dim: ds_new[dim]})
+    except KeyError as e:
+        if "not all values found" in str(e):
+            raise KeyError(
+                f"Not all values of coordinate '{dim}' in the new array were"
+                " found in the original store. Writing to a zarr region slice"
+                " requires that no dimensions or metadata are changed by the write."
+            )
+        else:
+            raise e
+
+    if (da_idxs.diff(dim) != 1).any():
+        raise ValueError(
+            f"The auto-detected region of coordinate '{dim}' for writing new data"
+            " to the original store had non-contiguous indices. Writing to a zarr"
+            " region slice requires that the new data constitute a contiguous subset"
+            " of the original store."
+        )
+
+    dim_slice = slice(da_idxs.values[0], da_idxs.values[-1] + 1)
+
+    return dim_slice
+
+
+def _auto_detect_regions(ds, region, store):
+    ds_original = open_zarr(store)
+    for key, val in region.items():
+        if val == "auto":
+            region[key] = _auto_detect_region(ds, ds_original, key)
+    return region
+
+
 def _validate_region(ds, region):
     if not isinstance(region, dict):
         raise TypeError(f"``region`` must be a dict, got {type(region)}")
@@ -1532,7 +1571,7 @@ def to_zarr(
     compute: Literal[True] = True,
     consolidated: bool | None = None,
     append_dim: Hashable | None = None,
-    region: Mapping[str, slice] | None = None,
+    region: Mapping[str, slice | Literal["auto"]] | None = None,
     safe_chunks: bool = True,
     storage_options: dict[str, str] | None = None,
     zarr_version: int | None = None,
@@ -1556,7 +1595,7 @@ def to_zarr(
     compute: Literal[False],
     consolidated: bool | None = None,
     append_dim: Hashable | None = None,
-    region: Mapping[str, slice] | None = None,
+    region: Mapping[str, slice | Literal["auto"]] | None = None,
     safe_chunks: bool = True,
     storage_options: dict[str, str] | None = None,
     zarr_version: int | None = None,
@@ -1578,7 +1617,7 @@ def to_zarr(
     compute: bool = True,
     consolidated: bool | None = None,
     append_dim: Hashable | None = None,
-    region: Mapping[str, slice] | None = None,
+    region: Mapping[str, slice | Literal["auto"]] | None = None,
     safe_chunks: bool = True,
     storage_options: dict[str, str] | None = None,
     zarr_version: int | None = None,
@@ -1643,6 +1682,7 @@ def to_zarr(
     _validate_dataset_names(dataset)
 
     if region is not None:
+        region = _auto_detect_regions(dataset, region, store)
         _validate_region(dataset, region)
         if append_dim is not None and append_dim in region:
             raise ValueError(
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 2b41fa5224e..2dd620e0244 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -322,12 +322,15 @@ def encode_zarr_variable(var, needs_copy=True, name=None):
 
 def _validate_existing_dims(var_name, new_var, existing_var, region, append_dim):
     if new_var.dims != existing_var.dims:
-        raise ValueError(
-            f"variable {var_name!r} already exists with different "
-            f"dimension names {existing_var.dims} != "
-            f"{new_var.dims}, but changing variable "
-            f"dimensions is not supported by to_zarr()."
-        )
+        if set(existing_var.dims) == set(new_var.dims):
+            new_var = new_var.transpose(*existing_var.dims)
+        else:
+            raise ValueError(
+                f"variable {var_name!r} already exists with different "
+                f"dimension names {existing_var.dims} != "
+                f"{new_var.dims}, but changing variable "
+                f"dimensions is not supported by to_zarr()."
+            )
 
     existing_sizes = {}
     for dim, size in existing_var.sizes.items():
@@ -347,6 +350,8 @@ def _validate_existing_dims(var_name, new_var, existing_var, region, append_dim)
             f"explicitly appending, but append_dim={append_dim!r}."
         )
 
+    return new_var
+
 
 def _put_attrs(zarr_obj, attrs):
     """Raise a more informative error message for invalid attrs."""
@@ -616,7 +621,7 @@ def store(
             for var_name in existing_variable_names:
                 new_var = variables_encoded[var_name]
                 existing_var = existing_vars[var_name]
-                _validate_existing_dims(
+                new_var = _validate_existing_dims(
                     var_name,
                     new_var,
                     existing_var,
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 73352c3f7e1..4d30c7b638b 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -5431,3 +5431,102 @@ def test_raise_writing_to_nczarr(self, mode) -> None:
 def test_pickle_open_mfdataset_dataset():
     ds = open_example_mfdataset(["bears.nc"])
     assert_identical(ds, pickle.loads(pickle.dumps(ds)))
+
+
+@requires_zarr
+class TestZarrRegionAuto:
+    def test_zarr_region_auto_success(self, tmp_path):
+        x = np.arange(0, 50, 10)
+        y = np.arange(0, 20, 2)
+        data = np.ones((5, 10))
+        ds = xr.Dataset(
+            {
+                "test": xr.DataArray(
+                    data,
+                    dims=("x", "y"),
+                    coords={"x": x, "y": y},
+                )
+            }
+        )
+        ds.to_zarr(tmp_path / "test.zarr")
+
+        ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8))
+        ds_region.to_zarr(tmp_path / "test.zarr", region={"x": "auto", "y": "auto"})
+
+        ds_updated = xr.open_zarr(tmp_path / "test.zarr")
+
+        expected = ds.copy()
+        expected["test"][2:4, 6:8] += 1
+        assert_identical(ds_updated, expected)
+
+    def test_zarr_region_auto_noncontiguous(self, tmp_path):
+        x = np.arange(0, 50, 10)
+        y = np.arange(0, 20, 2)
+        data = np.ones((5, 10))
+        ds = xr.Dataset(
+            {
+                "test": xr.DataArray(
+                    data,
+                    dims=("x", "y"),
+                    coords={"x": x, "y": y},
+                )
+            }
+        )
+        ds.to_zarr(tmp_path / "test.zarr")
+
+        ds_region = 1 + ds.isel(x=[0, 2, 3], y=[5, 6])
+        with pytest.raises(ValueError):
+            ds_region.to_zarr(tmp_path / "test.zarr", region={"x": "auto", "y": "auto"})
+
+    def test_zarr_region_auto_new_coord_vals(self, tmp_path):
+        x = np.arange(0, 50, 10)
+        y = np.arange(0, 20, 2)
+        data = np.ones((5, 10))
+        ds = xr.Dataset(
+            {
+                "test": xr.DataArray(
+                    data,
+                    dims=("x", "y"),
+                    coords={"x": x, "y": y},
+                )
+            }
+        )
+        ds.to_zarr(tmp_path / "test.zarr")
+
+        x = np.arange(5, 55, 10)
+        y = np.arange(0, 20, 2)
+        data = np.ones((5, 10))
+        ds = xr.Dataset(
+            {
+                "test": xr.DataArray(
+                    data,
+                    dims=("x", "y"),
+                    coords={"x": x, "y": y},
+                )
+            }
+        )
+
+        ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8))
+        with pytest.raises(KeyError):
+            ds_region.to_zarr(tmp_path / "test.zarr", region={"x": "auto", "y": "auto"})
+
+
+def test_zarr_region_transpose(tmp_path):
+    x = np.arange(0, 50, 10)
+    y = np.arange(0, 20, 2)
+    data = np.ones((5, 10))
+    ds = xr.Dataset(
+        {
+            "test": xr.DataArray(
+                data,
+                dims=("x", "y"),
+                coords={"x": x, "y": y},
+            )
+        }
+    )
+    ds.to_zarr(tmp_path / "test.zarr")
+
+    ds_region = 1 + ds.isel(x=[0], y=[0]).transpose()
+    ds_region.to_zarr(
+        tmp_path / "test.zarr", region={"x": slice(0, 1), "y": slice(0, 1)}
+    )
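
A minimal usage sketch of what this patch enables, mirroring the new tests; it is not part of the diff, and the local store path "example.zarr" and the sample values are illustrative only:

```python
import numpy as np
import xarray as xr

# Write an initial store with known coordinates (mode="w" overwrites any
# previous run of this sketch).
ds = xr.Dataset(
    {"test": (("x", "y"), np.ones((5, 10)))},
    coords={"x": np.arange(0, 50, 10), "y": np.arange(0, 20, 2)},
)
ds.to_zarr("example.zarr", mode="w")

# With region="auto", the subset's coordinates are matched against the
# existing store, so the region slices no longer have to be computed by hand.
ds_region = 1 + ds.isel(x=slice(2, 4), y=slice(6, 8))
ds_region.to_zarr("example.zarr", region={"x": "auto", "y": "auto"})

# Variables whose dimensions are merely reordered are now transposed back to
# the on-disk order before writing instead of raising a ValueError.
ds_t = (1 + ds.isel(x=[0], y=[0])).transpose("y", "x")
ds_t.to_zarr("example.zarr", region={"x": slice(0, 1), "y": slice(0, 1)})
```

Per the new helpers above, auto-detection requires the written coordinates to map onto a contiguous block of the existing store: non-contiguous indices raise ValueError, and coordinate values absent from the store raise KeyError.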