Prevent unsafe concurrent coordinate writes (#17)
When concurrently writing partitions of DataArrays in a Dataset, any coordinates carried along by those DataArrays are also written concurrently. These attached coordinates do not necessarily follow the same chunk structure as the DataArray itself; frequently they are completely unchunked, so concurrent jobs attempt to write the coordinates to the same files on disk, opening the possibility of data corruption. In practice, corruption of coordinates has been rare, but we recently encountered a situation where it occurred.
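To make the failure mode concrete, below is a minimal sketch of the situation (the variable and dimension names are illustrative, not taken from this repository): a DataArray chunked along one dimension carries an unchunked coordinate, which in a zarr store occupies a single chunk file that every concurrent partition write would touch.

import dask.array
import numpy as np
import xarray as xr

# A DataArray chunked along "x" that carries an *unchunked* coordinate "lat".
da = xr.DataArray(
    dask.array.zeros((8, 4), chunks=(2, 4)),
    dims=["x", "y"],
    coords={"lat": ("x", np.arange(8))},
    name="foo",
)

# "foo" is split into four blocks along "x", but "lat" is a plain NumPy array,
# so in a zarr store it would occupy a single chunk file. If the coordinate
# were written alongside each partition of "foo", every rank would rewrite that
# one file concurrently.
print(da.chunks)          # ((2, 2, 2, 2), (4,))
print(type(da.lat.data))  # <class 'numpy.ndarray'>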

This PR fixes the issue by dropping all coordinates when doing low-level partitioned writes. As expected, any unchunked coordinates or data variables are now written once (and only once) during the store initialization step. This PR also addresses the subtle issue of writing chunked coordinates by treating them as though they were independent data variables (we previously had no test coverage for this case, though it is admittedly rare to encounter in practice).
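For reference, here is a hedged sketch of the write pattern that the updated test below exercises; the store path, dimension names, and array sizes are illustrative, and it assumes that importing xpartition registers the .partition accessor used throughout the diff. Unchunked coordinates and data variables are written during initialize_store, and each rank then writes only its own partitions of the chunked variables.

import multiprocessing

import dask.array
import numpy as np
import xarray as xr

import xpartition  # noqa: F401  (assumed to register the .partition accessor)

if __name__ == "__main__":
    ranks = 2
    store = "example.zarr"  # illustrative path

    # A Dataset with a chunked data variable and an unchunked coordinate.
    ds = xr.Dataset(
        {"foo": (("x", "y"), dask.array.random.random((8, 4), chunks=(2, 4)))},
        coords={"x": np.arange(8)},
    )

    # Unchunked coordinates and data variables are written here, exactly once.
    ds.partition.initialize_store(store)

    # Each rank writes only its own partitions of the chunked variables; with
    # this change, attached coordinates are dropped from those region writes,
    # so no two ranks write to the same files.
    with multiprocessing.get_context("spawn").Pool(ranks) as pool:
        pool.map(ds.partition.mappable_write(store, ranks, ["x"]), range(ranks))

    xr.testing.assert_identical(xr.open_zarr(store), ds)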
spencerkclark authored May 10, 2023
1 parent 85ff3ae commit 1edae46
Showing 2 changed files with 73 additions and 13 deletions.
52 changes: 50 additions & 2 deletions test_xpartition.py
@@ -100,10 +100,15 @@ def da(request):
def _construct_dataarray(shape, chunks, name):
dims = list(string.ascii_lowercase[: len(shape)])
data = np.random.random(shape)
da = xr.DataArray(data, dims=dims, name=name)
coords = [range(length) for length in shape]
da = xr.DataArray(data, dims=dims, name=name, coords=coords)
if chunks is not None:
chunks = {dim: chunk for dim, chunk in zip(dims, chunks)}
da = da.chunk(chunks)

# Add coverage for chunked coordinates
chunked_coord_name = f"{da.name}_chunked_coord"
da = da.assign_coords({chunked_coord_name: da.chunk(chunks)})
return da


@@ -138,13 +143,47 @@ def ds():
return xr.merge(unchunked_dataarrays + chunked_dataarrays)


def get_files(directory):
names = os.listdir(directory)
files = []
for name in names:
path = os.path.join(directory, name)
if os.path.isfile(path):
files.append(path)
return files


def get_unchunked_variable_names(ds):
names = []
for name, variable in ds.variables.items():
if not isinstance(variable.data, dask.array.Array):
names.append(name)
return names


def checkpoint_modification_times(store, variables):
times = {}
for variable in variables:
directory = os.path.join(store, variable)
files = get_files(directory)
for file in files:
times[file] = os.path.getmtime(file)
return times


@pytest.mark.filterwarnings("ignore:Specified Dask chunks")
@pytest.mark.parametrize("ranks", [1, 2, 3, 5, 10, 11])
@pytest.mark.parametrize("collect_variable_writes", [False, True])
def test_dataset_mappable_write(tmpdir, ds, ranks, collect_variable_writes):
unchunked_variables = get_unchunked_variable_names(ds)

store = os.path.join(tmpdir, "test.zarr")
ds.partition.initialize_store(store)

# Checkpoint modification times of all files associated with unchunked
# variables. These should remain unchanged after initialization.
expected_times = checkpoint_modification_times(store, unchunked_variables)

with multiprocessing.get_context("spawn").Pool(ranks) as pool:
pool.map(
ds.partition.mappable_write(
@@ -154,8 +193,17 @@ def test_dataset_mappable_write(tmpdir, ds, ranks, collect_variable_writes):
)

result = xr.open_zarr(store)

# Check that dataset roundtrips identically.
xr.testing.assert_identical(result, ds)

# Checkpoint modification times of all files associated with unchunked
# variables after writing the chunked variables. The modification times of
# the unchunked variables should be the same as before writing the chunked
# variables.
resulting_times = checkpoint_modification_times(store, unchunked_variables)
assert expected_times == resulting_times


@pytest.mark.parametrize("has_coord", [True, False])
@pytest.mark.parametrize(
@@ -317,7 +365,7 @@ def __call__(self, dsk, keys, **kwargs):


@pytest.mark.parametrize(
("collect_variable_writes", "expected_computes"), [(False, 6), (True, 3)]
("collect_variable_writes", "expected_computes"), [(False, 9), (True, 3)]
)
def test_dataset_mappable_write_minimizes_compute_calls(
tmpdir, collect_variable_writes, expected_computes
34 changes: 23 additions & 11 deletions xpartition.py
@@ -184,7 +184,7 @@ def isel(self, **block_indexers) -> xr.DataArray:
def _write_partition_dataarray(
da: xr.DataArray, store: str, ranks: int, dims: Sequence[Hashable], rank: int
):
ds = da.to_dataset()
ds = da.drop_vars(da.coords).to_dataset()
partition = da.partition.indexers(ranks, rank, dims)
if partition is not None:
ds.isel(partition).to_zarr(store, region=partition)
@@ -214,21 +214,21 @@ def _collect_by_partition(
DataArrays that can be written out to those partitions.
"""
dataarrays = collections.defaultdict(list)
for da in ds.data_vars.values():
for da in {**ds.coords, **ds.data_vars}.values():
if isinstance(da.data, dask.array.Array):
partition_dims = [dim for dim in dims if dim in da.dims]
indexers = da.partition.indexers(ranks, rank, partition_dims)
dataarrays[freeze_indexers(indexers)].append(da)
dataarrays[freeze_indexers(indexers)].append(da.drop_vars(da.coords))
return [(unfreeze_indexers(k), xr.merge(v)) for k, v in dataarrays.items()]


def _write_partition_dataset_via_individual_variables(
ds: xr.Dataset, store: str, ranks: int, dims: Sequence[Hashable], rank: int
):
for da in ds.data_vars.values():
for da in {**ds.coords, **ds.data_vars}.values():
if isinstance(da.data, dask.array.Array):
partition_dims = [dim for dim in dims if dim in da.dims]
da.partition.write(store, ranks, partition_dims, rank)
_write_partition_dataarray(da, store, ranks, partition_dims, rank)


def _write_partition_dataset_via_collected_variables(
@@ -353,14 +353,27 @@ def indexers(self, ranks: int, rank: int, dims: Sequence[Hashable]) -> Region:
dask_indexers = meta_array.blocks.indexers(**meta_indexers)
return self._obj.blocks.indexers(**dask_indexers)

def write(self, store: str, ranks: int, dims: Sequence[Hashable], rank: int):
_write_partition_dataarray(self._obj, store, ranks, dims, rank)
def write(
self,
store: str,
ranks: int,
dims: Sequence[Hashable],
rank: int,
collect_variable_writes: bool = False,
):
self.to_dataset().partition.write(
store, ranks, dims, rank, collect_variable_writes
)

def mappable_write(
self, store: str, ranks: int, dims: Sequence[Hashable]
self,
store: str,
ranks: int,
dims: Sequence[Hashable],
collect_variable_writes: bool = False,
) -> Callable[[int], None]:
return functools.partial(
_write_partition_dataarray, self._obj, store, ranks, dims
return self._obj.to_dataset().partition.mappable_write(
store, ranks, dims, collect_variable_writes
)

@property
@@ -510,7 +523,6 @@ class _ValidWorkPlan:
"""

def __init__(self, partitioner, ranks: int, dims: Sequence[Hashable]):

self._partitioner = partitioner
self._ranks = ranks
self.dims = dims