diff --git a/tests/test_partitioning.py b/tests/test_partitioning.py index c72b13f5b..19b3a3b4c 100644 --- a/tests/test_partitioning.py +++ b/tests/test_partitioning.py @@ -141,7 +141,7 @@ def test_merge_partitions__errors(self): partitions[1]["extra"] = (grid1.face_dimension, np.ones(grid1.n_face)) with pytest.raises( ValueError, - match="'extra' does not occur not in all partitions with 'mesh2d'", + match="Missing variables: {'extra'} in partition", ): pt.merge_partitions(partitions) diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index 839fbed5e..2edc350f3 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -214,19 +214,6 @@ def validate_partition_objects( return None -def validate_vars_in_all_data_objects( - vars: list[str], data_objects: list[xr.Dataset], gridname: str -): - for var in vars: - var_in_objects = [ - True if var in obj.variables else False for obj in data_objects - ] - if not all(var_in_objects): - raise ValueError( - f"'{var}' does not occur not in all partitions with '{gridname}'" - ) - return None - def separate_variables(objects_by_gridname: defaultdict[str, xr.Dataset], ugrid_dims: set[str]): """Separate into UGRID variables grouped by dimension, and other variables.""" @@ -276,22 +263,45 @@ def all_equal(iterator): return grouped, other -def maybe_pad_connectivity_dims_to_max(selection, merged_grid): - nmax_dict = merged_grid.max_connectivity_sizes - nmax_dict = { - key: value for key, value in nmax_dict.items() if key in selection[0].dims - } - if not nmax_dict: - return selection - - pad_width_ls = [ - {dim: (0, nmax - obj.sizes[dim]) for dim, nmax in nmax_dict.items()} - for obj in selection - ] - - return [ - obj.pad(pad_width=pad_width) for obj, pad_width in zip(selection, pad_width_ls) - ] +def merge_data_along_dim( + data_objects: list[xr.Dataset], + vars: list[str], + merge_dim: str, + indexes: list[np.array], + merged_grid: UgridType, +) -> xr.Dataset: + """" + Select variables from the data objects. + Pad connectivity dims if needed. + Concatenate along dim. + """ + max_sizes = merged_grid.max_connectivity_sizes + ugrid_connectivity_dims = set(max_sizes) + + to_merge = [] + for obj, index in zip(data_objects, indexes): + # Check for presence of vars + missing_vars = set(vars).difference(set(obj.variables.keys())) + if missing_vars: + raise ValueError(f"Missing variables: {missing_vars} in partition {obj}") + + selection = obj[vars].isel({merge_dim: index}, missing_dims="ignore") + + # Pad the ugrid connectivity dims (e.g. n_max_face_node_connectivity) if + # needed. + present_dims = ugrid_connectivity_dims.intersection(selection.dims) + pad_width = {} + for dim in present_dims: + nmax = max_sizes[dim] + size = selection.sizes[dim] + if size != nmax: + pad_width[dim] = (0, nmax - size) + if pad_width: + selection = selection.pad(pad_width=pad_width) + + to_merge.append(selection) + + return xr.concat(to_merge, dim=merge_dim) def merge_partitions(partitions): @@ -353,21 +363,11 @@ def merge_partitions(partitions): other_vars_obj = set(other_vars).intersection(set(obj.data_vars)) merged.update(obj[other_vars_obj]) - # Now remove duplicates, then concatenate along the UGRID dimension. for dim, dim_indexes in indexes.items(): vars = vars_by_dim[dim] if len(vars) == 0: continue - validate_vars_in_all_data_objects(vars, data_objects, gridname) - selection = [ - obj[vars].isel({dim: index}, missing_dims="ignore") - for obj, index in zip(data_objects, dim_indexes) - ] - selection_padded = maybe_pad_connectivity_dims_to_max( - selection, merged_grid - ) - - merged_selection = xr.concat(selection_padded, dim=dim) + merged_selection = merge_data_along_dim(data_objects, vars, dim, dim_indexes, merged_grid) merged.update(merged_selection) return UgridDataset(merged, merged_grids)