diff --git a/tests/test_partitioning.py b/tests/test_partitioning.py index 3be173e91..04596cf4a 100644 --- a/tests/test_partitioning.py +++ b/tests/test_partitioning.py @@ -25,13 +25,9 @@ def generate_mesh_2d(nx, ny, name="mesh2d"): def generate_mesh_1d(n, name="mesh1d"): - points = [ - (p,p) for p in np.linspace(0, n, n+1) - ] - connectivity = [ - [it, it+1] for it in range(n) - ] - + points = [(p, p) for p in np.linspace(0, n, n + 1)] + connectivity = [[it, it + 1] for it in range(n)] + return xu.Ugrid1d(*np.array(points).T, -1, np.array(connectivity), name=name) @@ -232,12 +228,13 @@ def test_merge_partitions__errors(self): ): pt.merge_partitions([self.datasets[0], dataset3]) + class TestMergeDataset1D: @pytest.fixture(autouse=True) def setup(self): grid = generate_mesh_1d(6, "mesh1d") # TODO: If partitioning implemented for 1D grids, replace with that. - i_edges = [[0,1,2], [3,4,5]] + i_edges = [[0, 1, 2], [3, 4, 5]] parts = [grid.isel(mesh1d_nEdges=np.array(ls)) for ls in i_edges] values_parts = [np.arange(part.n_edge) for part in parts] @@ -270,6 +267,7 @@ def test_merge_partitions(self): assert self.dataset_expected["a"].equals(merged["a"]) assert self.dataset_expected.equals(merged) + class TestMultiTopology1D2DMergePartitions: @pytest.fixture(autouse=True) def setup(self): @@ -277,14 +275,16 @@ def setup(self): grid_b = generate_mesh_1d(6, "mesh1d") parts_a = grid_a.partition(n_part=2) # TODO: If partitioning implemented for 1D grids, replace with that. - i_edges = [[0,1,2], [3,4,5]] + i_edges = [[0, 1, 2], [3, 4, 5]] parts_b = [grid_b.isel(mesh1d_nEdges=np.array(ls)) for ls in i_edges] - values_parts_a = [np.arange(part.n_face) for part in parts_a] + values_parts_a = [np.arange(part.n_face) for part in parts_a] values_parts_b = [np.arange(part.n_edge) for part in parts_b] datasets_parts = [] - for i, (part_a, part_b, values_a, values_b) in enumerate(zip(parts_a, parts_b, values_parts_a, values_parts_b)): + for i, (part_a, part_b, values_a, values_b) in enumerate( + zip(parts_a, parts_b, values_parts_a, values_parts_b) + ): ds = xu.UgridDataset(grids=[part_a, part_b]) ds["a"] = ((part_a.face_dimension), values_a) ds["b"] = ((part_b.edge_dimension), values_b) @@ -296,7 +296,10 @@ def setup(self): ds_expected["b"] = ((grid_b.edge_dimension), np.concatenate(values_parts_b)) ds_expected["c"] = 0 # Assign coordinates also added during merge_partitions - coords = {grid_a.face_dimension: np.arange(grid_a.n_face), grid_b.edge_dimension: np.arange(grid_b.n_edge)} + coords = { + grid_a.face_dimension: np.arange(grid_a.n_face), + grid_b.edge_dimension: np.arange(grid_b.n_edge), + } ds_expected = ds_expected.assign_coords(**coords) self.datasets_parts = datasets_parts @@ -309,4 +312,4 @@ def test_merge_partitions(self): # In case of non-UGRID data, it should default to the first partition: assert merged["c"] == 0 - assert self.dataset_expected.equals(merged) \ No newline at end of file + assert self.dataset_expected.equals(merged)