Skip to content

Commit

Permalink
Implement per communicator universal identifiers for Firedrake objects
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Jul 19, 2024
1 parent 480df47 commit fc78d12
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 12 deletions.
4 changes: 2 additions & 2 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,7 +1070,7 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
coord_element = self._load_ufl_element(path, PREFIX + "_coordinate_element")
coord_name = self.get_attr(path, PREFIX + "_coordinates")
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
mesh = make_mesh_from_coordinates(coordinates, name, comm=self.comm)
if self.has_attr(path, PREFIX + "_radial_coordinates"):
radial_coord_element = self._load_ufl_element(path, PREFIX + "_radial_coordinate_element")
radial_coord_name = self.get_attr(path, PREFIX + "_radial_coordinates")
Expand Down Expand Up @@ -1102,7 +1102,7 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
coord_element = self._load_ufl_element(path, PREFIX + "_coordinate_element")
coord_name = self.get_attr(path, PREFIX + "_coordinates")
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
mesh = make_mesh_from_coordinates(coordinates, name, comm=self.comm)
# Load plex coordinates for a complete representation of plex.
tmesh.topology_dm.coordinatesLoad(self.viewer, tmesh.sfXC)
# Load cell_orientations for immersed meshes.
Expand Down
11 changes: 7 additions & 4 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2442,7 +2442,7 @@ def mark_entities(self, f, label_value, label_name=None):


@PETSc.Log.EventDecorator()
def make_mesh_from_coordinates(coordinates, name, tolerance=0.5):
def make_mesh_from_coordinates(coordinates, name, tolerance=0.5, comm=None):
"""Given a coordinate field build a new mesh, using said coordinate field.
Parameters
Expand All @@ -2462,6 +2462,9 @@ def make_mesh_from_coordinates(coordinates, name, tolerance=0.5):
The mesh.
"""
if comm is None:
raise ValueError("A comm must be provided when creating a mesh from coordinates")

if hasattr(coordinates, '_as_mesh_geometry'):
mesh = coordinates._as_mesh_geometry()
if mesh is not None:
Expand Down Expand Up @@ -2668,7 +2671,7 @@ def Mesh(meshfile, **kwargs):
else:
coordinates = None
if coordinates is not None:
return make_mesh_from_coordinates(coordinates, name)
return make_mesh_from_coordinates(coordinates, name, comm=user_comm)

tolerance = kwargs.get("tolerance", 0.5)

Expand Down Expand Up @@ -2716,7 +2719,7 @@ def Mesh(meshfile, **kwargs):
distribution_name=kwargs.get("distribution_name"),
permutation_name=kwargs.get("permutation_name"),
comm=user_comm)
mesh = make_mesh_from_mesh_topology(topology, name)
mesh = make_mesh_from_mesh_topology(topology, name, user_comm)
if netgen and isinstance(meshfile, netgen.libngpy._meshing.Mesh):
netgen_firedrake_mesh.createFromTopology(topology, name=plex.getName(), comm=user_comm)
mesh = netgen_firedrake_mesh.firedrakeMesh
Expand Down Expand Up @@ -2863,7 +2866,7 @@ def ExtrudedMesh(mesh, layers, layer_height=None, extrusion_type='uniform', peri
eutils.make_extruded_coords(topology, mesh._coordinates, coordinates,
layer_height, extrusion_type=extrusion_type, kernel=kernel)

self = make_mesh_from_coordinates(coordinates, name)
self = make_mesh_from_coordinates(coordinates, name, comm=mesh.comm)
self._base_mesh = mesh

if extrusion_type == "radial_hedgehog":
Expand Down
6 changes: 3 additions & 3 deletions firedrake/utility_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2785,7 +2785,7 @@ def AnnulusMesh(
x, y = ufl.SpatialCoordinate(bar)
V = bar.coordinates.function_space()
coord = Function(V).interpolate(ufl.as_vector([x * ufl.cos(y), x * ufl.sin(y)]))
annulus = mesh.make_mesh_from_coordinates(coord.topological, name)
annulus = mesh.make_mesh_from_coordinates(coord.topological, name, comm=comm)
annulus.topology.name = mesh._generate_default_mesh_topology_name(name)
annulus._base_mesh = base
return annulus
Expand Down Expand Up @@ -2834,14 +2834,14 @@ def SolidTorusMesh(
x, y = ufl.SpatialCoordinate(unit)
V = unit.coordinates.function_space()
coord = Function(V).interpolate(ufl.as_vector([r * x + R, r * y]))
disk = mesh.make_mesh_from_coordinates(coord.topological, base_name)
disk = mesh.make_mesh_from_coordinates(coord.topological, base_name, comm=comm)
disk.topology.name = mesh._generate_default_mesh_topology_name(base_name)
disk.topology.topology_dm.setName(disk.topology.name)
bar = mesh.ExtrudedMesh(disk, layers=nR, layer_height=2 * np.pi / nR, extrusion_type="uniform", periodic=True)
x, y, z = ufl.SpatialCoordinate(bar)
V = bar.coordinates.function_space()
coord = Function(V).interpolate(ufl.as_vector([x * ufl.cos(z), x * ufl.sin(z), -y]))
torus = mesh.make_mesh_from_coordinates(coord.topological, name)
torus = mesh.make_mesh_from_coordinates(coord.topological, name, comm=comm)
torus.topology.name = mesh._generate_default_mesh_topology_name(name)
torus._base_mesh = disk
return torus
Expand Down
2 changes: 1 addition & 1 deletion tests/extrusion/test_extruded_periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_extruded_periodic_annulus():
coordV1 = FunctionSpace(mesh1, elem1)
x1, y1 = SpatialCoordinate(mesh1)
coord1 = Function(coordV1).interpolate(as_vector([x1 * cos(y1), x1 * sin(y1)]))
mesh1 = make_mesh_from_coordinates(coord1.topological, "annulus")
mesh1 = make_mesh_from_coordinates(coord1.topological, "annulus", comm=COMM_WORLD)
mesh1._base_mesh = mesh
# Check volume
x0, y0 = SpatialCoordinate(mesh0)
Expand Down
2 changes: 1 addition & 1 deletion tests/output/test_io_backward_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _get_mesh_and_V(params):
coordV = FunctionSpace(mesh, elem)
x, y = SpatialCoordinate(mesh)
coord = Function(coordV).interpolate(as_vector([x * cos(y), x * sin(y)]))
mesh = make_mesh_from_coordinates(coord.topological, name=mesh_name)
mesh = make_mesh_from_coordinates(coord.topological, name=mesh_name, comm=COMM_WORLD)
mesh._base_mesh = base
V = FunctionSpace(mesh, "RTCF", 3)
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/output/test_io_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ def test_io_function_extrusion_periodic(tmpdir):
coordV = FunctionSpace(extm, elem)
x, y = SpatialCoordinate(extm)
coord = Function(coordV).interpolate(as_vector([x * cos(y), x * sin(y)]))
extm = make_mesh_from_coordinates(coord.topological, name=extruded_mesh_name)
extm = make_mesh_from_coordinates(coord.topological, name=extruded_mesh_name, comm=COMM_WORLD)
extm._base_mesh = mesh
V = FunctionSpace(extm, "RTCF", 3)
method = get_embedding_method_for_checkpointing(V.ufl_element())
Expand Down

0 comments on commit fc78d12

Please sign in to comment.