Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
JoerivanEngelen committed Feb 9, 2024
1 parent d62f736 commit dbe49f1
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions tests/test_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,9 @@ def generate_mesh_2d(nx, ny, name="mesh2d"):


def generate_mesh_1d(n, name="mesh1d"):
points = [
(p,p) for p in np.linspace(0, n, n+1)
]
connectivity = [
[it, it+1] for it in range(n)
]

points = [(p, p) for p in np.linspace(0, n, n + 1)]
connectivity = [[it, it + 1] for it in range(n)]

return xu.Ugrid1d(*np.array(points).T, -1, np.array(connectivity), name=name)


Expand Down Expand Up @@ -232,12 +228,13 @@ def test_merge_partitions__errors(self):
):
pt.merge_partitions([self.datasets[0], dataset3])


class TestMergeDataset1D:
@pytest.fixture(autouse=True)
def setup(self):
grid = generate_mesh_1d(6, "mesh1d")
# TODO: If partitioning implemented for 1D grids, replace with that.
i_edges = [[0,1,2], [3,4,5]]
i_edges = [[0, 1, 2], [3, 4, 5]]
parts = [grid.isel(mesh1d_nEdges=np.array(ls)) for ls in i_edges]

values_parts = [np.arange(part.n_edge) for part in parts]
Expand Down Expand Up @@ -270,21 +267,24 @@ def test_merge_partitions(self):
assert self.dataset_expected["a"].equals(merged["a"])
assert self.dataset_expected.equals(merged)


class TestMultiTopology1D2DMergePartitions:
@pytest.fixture(autouse=True)
def setup(self):
grid_a = generate_mesh_2d(2, 3, "mesh2d")
grid_b = generate_mesh_1d(6, "mesh1d")
parts_a = grid_a.partition(n_part=2)
# TODO: If partitioning implemented for 1D grids, replace with that.
i_edges = [[0,1,2], [3,4,5]]
i_edges = [[0, 1, 2], [3, 4, 5]]
parts_b = [grid_b.isel(mesh1d_nEdges=np.array(ls)) for ls in i_edges]

values_parts_a = [np.arange(part.n_face) for part in parts_a]
values_parts_a = [np.arange(part.n_face) for part in parts_a]
values_parts_b = [np.arange(part.n_edge) for part in parts_b]

datasets_parts = []
for i, (part_a, part_b, values_a, values_b) in enumerate(zip(parts_a, parts_b, values_parts_a, values_parts_b)):
for i, (part_a, part_b, values_a, values_b) in enumerate(
zip(parts_a, parts_b, values_parts_a, values_parts_b)
):
ds = xu.UgridDataset(grids=[part_a, part_b])
ds["a"] = ((part_a.face_dimension), values_a)
ds["b"] = ((part_b.edge_dimension), values_b)
Expand All @@ -296,7 +296,10 @@ def setup(self):
ds_expected["b"] = ((grid_b.edge_dimension), np.concatenate(values_parts_b))
ds_expected["c"] = 0
# Assign coordinates also added during merge_partitions
coords = {grid_a.face_dimension: np.arange(grid_a.n_face), grid_b.edge_dimension: np.arange(grid_b.n_edge)}
coords = {
grid_a.face_dimension: np.arange(grid_a.n_face),
grid_b.edge_dimension: np.arange(grid_b.n_edge),
}
ds_expected = ds_expected.assign_coords(**coords)

self.datasets_parts = datasets_parts
Expand All @@ -309,4 +312,4 @@ def test_merge_partitions(self):
# In case of non-UGRID data, it should default to the first partition:
assert merged["c"] == 0

assert self.dataset_expected.equals(merged)
assert self.dataset_expected.equals(merged)

0 comments on commit dbe49f1

Please sign in to comment.