Skip to content

Commit

Permalink
Merge pull request #280 from Deltares/277-catch-error-when-passing-em…
Browse files Browse the repository at this point in the history
…pty-list-to-xumerge_partitions

caught empty list of partitions in xu.merge_partitions()
  • Loading branch information
Huite authored Aug 14, 2024
2 parents 7bb51d9 + 8305136 commit dc5f113
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
12 changes: 11 additions & 1 deletion tests/test_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,15 @@ def test_partition_roundtrip(self):
reordered = back.isel(mesh2d_nFaces=order)
assert reordered["face_z"].equals(self.uds["face_z"])

def test_merge_partition_single(self):
partitions = [self.uds]
back = pt.merge_partitions(partitions)
assert back == self.uds

def test_merge_partitions__errors(self):
partitions = self.uds.ugrid.partition(n_part=2)
with pytest.raises(TypeError, match="Expected UgridDataArray or UgridDataset"):
pt.merge_partitions(p.ugrid.obj for p in partitions)
pt.merge_partitions([p.ugrid.obj for p in partitions])

grid1 = partitions[1].ugrid.grid
partitions[1]["extra"] = (grid1.face_dimension, np.ones(grid1.n_face))
Expand All @@ -162,6 +167,11 @@ def test_merge_partitions__errors(self):
):
pt.merge_partitions(partitions)

with pytest.raises(
ValueError, match="Cannot merge partitions: zero partitions provided."
):
xu.merge_partitions([])

def test_merge_partitions_no_duplicates(self):
part1 = self.uds.isel(mesh2d_nFaces=[0, 1, 2, 3])
part2 = self.uds.isel(mesh2d_nFaces=[2, 3, 4, 5])
Expand Down
6 changes: 6 additions & 0 deletions xugrid/ugrid/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ def merge_partitions(partitions, merge_ugrid_chunks: bool = True):
-------
merged : UgridDataset
"""
if len(partitions) == 0:
raise ValueError("Cannot merge partitions: zero partitions provided.")
types = {type(obj) for obj in partitions}
msg = "Expected UgridDataArray or UgridDataset, received: {}"
if len(types) > 1:
Expand All @@ -337,6 +339,10 @@ def merge_partitions(partitions, merge_ugrid_chunks: bool = True):
if obj_type not in (UgridDataArray, UgridDataset):
raise TypeError(msg.format(obj_type.__name__))

# return first partition if single partition is provided
if len(partitions) == 1:
return next(iter(partitions))

# Collect grids
grids = [grid for p in partitions for grid in p.grids]
ugrid_dims = {dim for grid in grids for dim in grid.dimensions}
Expand Down

0 comments on commit dc5f113

Please sign in to comment.