diff --git a/tests/test_partitioning.py b/tests/test_partitioning.py index 4a3a7654f..274bb87ca 100644 --- a/tests/test_partitioning.py +++ b/tests/test_partitioning.py @@ -132,6 +132,11 @@ def test_partition_roundtrip(self): reordered = back.isel(mesh2d_nFaces=order) assert reordered["face_z"].equals(self.uds["face_z"]) + def test_merge_partition_single(self): + partitions = [self.uds] + back = xu.merge_partitions(partitions) + assert back == self.uds + def test_merge_partitions__errors(self): partitions = self.uds.ugrid.partition(n_part=2) with pytest.raises(TypeError, match="Expected UgridDataArray or UgridDataset"): diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index 519cf74d8..e68d30a98 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -305,6 +305,7 @@ def merge_data_along_dim( return xr.concat(to_merge, dim=merge_dim) +xu.merge_partitions([1]) def merge_partitions(partitions, merge_ugrid_chunks: bool = True): """ @@ -339,6 +340,10 @@ def merge_partitions(partitions, merge_ugrid_chunks: bool = True): if obj_type not in (UgridDataArray, UgridDataset): raise TypeError(msg.format(obj_type.__name__)) + # return first partition if single partition is provided + if len(partitions) == 1: + return next(iter(partitions)) + # Collect grids grids = [grid for p in partitions for grid in p.grids] ugrid_dims = {dim for grid in grids for dim in grid.dimensions}