Skip to content

Commit

Permalink
Simplify logic with reviewer's suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
JoerivanEngelen committed Feb 14, 2024
1 parent 4b262e4 commit 365ee19
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 41 deletions.
2 changes: 1 addition & 1 deletion tests/test_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_merge_partitions__errors(self):
partitions[1]["extra"] = (grid1.face_dimension, np.ones(grid1.n_face))
with pytest.raises(
ValueError,
match="'extra' does not occur not in all partitions with 'mesh2d'",
match="Missing variables: {'extra'} in partition",
):
pt.merge_partitions(partitions)

Expand Down
80 changes: 40 additions & 40 deletions xugrid/ugrid/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,19 +214,6 @@ def validate_partition_objects(
return None


def validate_vars_in_all_data_objects(
vars: list[str], data_objects: list[xr.Dataset], gridname: str
):
for var in vars:
var_in_objects = [
True if var in obj.variables else False for obj in data_objects
]
if not all(var_in_objects):
raise ValueError(
f"'{var}' does not occur not in all partitions with '{gridname}'"
)
return None


def separate_variables(objects_by_gridname: defaultdict[str, xr.Dataset], ugrid_dims: set[str]):
"""Separate into UGRID variables grouped by dimension, and other variables."""
Expand Down Expand Up @@ -276,22 +263,45 @@ def all_equal(iterator):
return grouped, other


def maybe_pad_connectivity_dims_to_max(selection, merged_grid):
nmax_dict = merged_grid.max_connectivity_sizes
nmax_dict = {
key: value for key, value in nmax_dict.items() if key in selection[0].dims
}
if not nmax_dict:
return selection

pad_width_ls = [
{dim: (0, nmax - obj.sizes[dim]) for dim, nmax in nmax_dict.items()}
for obj in selection
]

return [
obj.pad(pad_width=pad_width) for obj, pad_width in zip(selection, pad_width_ls)
]
def merge_data_along_dim(
data_objects: list[xr.Dataset],
vars: list[str],
merge_dim: str,
indexes: list[np.array],
merged_grid: UgridType,
) -> xr.Dataset:
""""
Select variables from the data objects.
Pad connectivity dims if needed.
Concatenate along dim.
"""
max_sizes = merged_grid.max_connectivity_sizes
ugrid_connectivity_dims = set(max_sizes)

to_merge = []
for obj, index in zip(data_objects, indexes):
# Check for presence of vars
missing_vars = set(vars).difference(set(obj.variables.keys()))
if missing_vars:
raise ValueError(f"Missing variables: {missing_vars} in partition {obj}")

selection = obj[vars].isel({merge_dim: index}, missing_dims="ignore")

# Pad the ugrid connectivity dims (e.g. n_max_face_node_connectivity) if
# needed.
present_dims = ugrid_connectivity_dims.intersection(selection.dims)
pad_width = {}
for dim in present_dims:
nmax = max_sizes[dim]
size = selection.sizes[dim]
if size != nmax:
pad_width[dim] = (0, nmax - size)
if pad_width:
selection = selection.pad(pad_width=pad_width)

to_merge.append(selection)

return xr.concat(to_merge, dim=merge_dim)


def merge_partitions(partitions):
Expand Down Expand Up @@ -353,21 +363,11 @@ def merge_partitions(partitions):
other_vars_obj = set(other_vars).intersection(set(obj.data_vars))
merged.update(obj[other_vars_obj])

# Now remove duplicates, then concatenate along the UGRID dimension.
for dim, dim_indexes in indexes.items():
vars = vars_by_dim[dim]
if len(vars) == 0:
continue
validate_vars_in_all_data_objects(vars, data_objects, gridname)
selection = [
obj[vars].isel({dim: index}, missing_dims="ignore")
for obj, index in zip(data_objects, dim_indexes)
]
selection_padded = maybe_pad_connectivity_dims_to_max(
selection, merged_grid
)

merged_selection = xr.concat(selection_padded, dim=dim)
merged_selection = merge_data_along_dim(data_objects, vars, dim, dim_indexes, merged_grid)
merged.update(merged_selection)

return UgridDataset(merged, merged_grids)

0 comments on commit 365ee19

Please sign in to comment.