Skip to content

Commit

Permalink
Merge pull request #1327 from firedrakeproject/multiple-transferopera…
Browse files Browse the repository at this point in the history
…tors

* multiple-transferoperators:
  mg: Add standalone test of multiple transfer operators
  dmhook: Hopefully fix functionspace weakrefs
  Transferring transfer operators. Who transfers the transfer transferrers?
  mg: Tests of multiple transfers on mixed spaces
  mg: Add test of multiple custom transfer operators
  mg: Support multiple transfer operators in monolithic MG
  solving: Allow setting multiple transfer_operators managers
  • Loading branch information
wence- committed Nov 14, 2018
2 parents a2bd424 + 185df22 commit 8c39d71
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 82 deletions.
76 changes: 38 additions & 38 deletions firedrake/dmhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,17 @@ def get_function_space(dm):
:arg dm: The DM to get the function space from.
:raises RuntimeError: if no function space was found.
"""
V = dm.getAttr("__fs_info__")()
if V is None:
raise RuntimeError("FunctionSpace not found on DM")
info = dm.getAttr("__fs_info__")
meshref, element, indices, (name, names) = info
mesh = meshref()
if mesh is None:
raise RuntimeError("Somehow your mesh was collected, this should never happen")
V = firedrake.FunctionSpace(mesh, element, name=name)
if len(V) > 1:
for V_, name in zip(V, names):
V_.topological.name = name
for index in indices:
V = V.sub(index)
return V


Expand All @@ -64,11 +72,27 @@ def set_function_space(dm, V):
.. note::
This stores a weakref to the function space in the DM, so you
should hold a strong reference somewhere else.
This stores the information necessary to make a function space given a DM.
"""
dm.setAttr("__fs_info__", weakref.ref(V))
mesh = V.mesh()

indices = []
names = []
while V.parent is not None:
if V.index is not None:
assert V.component is None
indices.append(V.index)
if V.component is not None:
assert V.index is None
indices.append(V.component)
V = V.parent
if len(V) > 1:
names = tuple(V_.name for V_ in V)
element = V.ufl_element()

info = (weakref.ref(mesh), element, tuple(reversed(indices)), (V.name, names))
dm.setAttr("__fs_info__", info)


def push_appctx(dm, ctx):
Expand Down Expand Up @@ -339,11 +363,6 @@ def create_subdm(dm, fields, *args, **kwargs):
:arg DM: The DM.
:arg fields: The fields in the new sub-DM.
.. note::
This should, but currently does not, transfer appropriately
split application contexts onto the sub-DMs.
"""
W = get_function_space(dm)
ctx = get_appctx(dm)
Expand All @@ -359,21 +378,12 @@ def create_subdm(dm, fields, *args, **kwargs):
push_ctx_coarsener(subdm, coarsen)
return iset, subdm
else:
try:
# Look up the subspace in the cache
iset, subspace = W._subspaces[tuple(fields)]
except KeyError:
# Need to build an MFS for the subspace
subspace = firedrake.MixedFunctionSpace([W[f] for f in fields])
# Index set mapping from W into subspace.
iset = PETSc.IS().createGeneral(numpy.concatenate([W._ises[f].indices
for f in fields]),
comm=W.comm)
# Keep hold of strong reference to created subspace (given we
# only hold a weakref in the shell DM), and so we can
# reuse it later.
W._subspaces[tuple(fields)] = iset, subspace

# Need to build an MFS for the subspace
subspace = firedrake.MixedFunctionSpace([W[f] for f in fields])
# Index set mapping from W into subspace.
iset = PETSc.IS().createGeneral(numpy.concatenate([W._ises[f].indices
for f in fields]),
comm=W.comm)
if ctx is not None:
ctx, = ctx.split([fields])
push_appctx(subspace.dm, ctx)
Expand All @@ -392,28 +402,18 @@ def coarsen(dm, comm):
"""
from firedrake.mg.utils import get_level
V = get_function_space(dm)
if V is None:
raise RuntimeError("No functionspace found on DM")
hierarchy, level = get_level(V.mesh())
if level < 1:
raise RuntimeError("Cannot coarsen coarsest DM")
if hasattr(V, "_coarse"):
cdm = V._coarse.dm
else:
coarsen = get_ctx_coarsener(dm)
V._coarse = coarsen(V, coarsen)
cdm = V._coarse.dm

transfer = get_transfer_operators(dm)
push_transfer_operators(cdm, *transfer)
coarsen = get_ctx_coarsener(dm)
Vc = coarsen(V, coarsen)
cdm = Vc.dm
push_ctx_coarsener(cdm, coarsen)
ctx = get_appctx(dm)
if ctx is not None:
push_appctx(cdm, coarsen(ctx, coarsen))
# Necessary for MG inside a fieldsplit in a SNES.
cdm.setKSPComputeOperators(firedrake.solving_utils._SNESContext.compute_operators)
V._coarse._fine = V
return cdm


Expand Down
15 changes: 7 additions & 8 deletions firedrake/linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,13 @@ def solve(self, x, b):
if self.A.has_bcs:
b = self._lifted(b)

with self.inserted_options():
with b.dat.vec_ro as rhs:
if self.ksp.getInitialGuessNonzero():
acc = x.dat.vec
else:
acc = x.dat.vec_wo
with acc as solution:
self.ksp.solve(rhs, solution)
if self.ksp.getInitialGuessNonzero():
acc = x.dat.vec
else:
acc = x.dat.vec_wo

with self.inserted_options(), b.dat.vec_ro as rhs, acc as solution:
self.ksp.solve(rhs, solution)

r = self.ksp.getConvergedReason()
if r < 0:
Expand Down
9 changes: 6 additions & 3 deletions firedrake/mg/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def prolong(coarse, fine):
if len(Vc) != len(Vf):
raise ValueError("Mixed spaces have different lengths")
for in_, out in zip(coarse.split(), fine.split()):
prolong(in_, out)
myprolong, _, _ = firedrake.dmhooks.get_transfer_operators(in_.function_space().dm)
myprolong(in_, out)
return

if Vc.ufl_element().family() == "Real" or Vf.ufl_element().family() == "Real":
Expand Down Expand Up @@ -91,7 +92,8 @@ def restrict(fine_dual, coarse_dual):
if len(Vc) != len(Vf):
raise ValueError("Mixed spaces have different lengths")
for in_, out in zip(fine_dual.split(), coarse_dual.split()):
restrict(in_, out)
_, myrestrict, _ = firedrake.dmhooks.get_transfer_operators(in_.function_space().dm)
myrestrict(in_, out)
return

if Vc.ufl_element().family() == "Real" or Vf.ufl_element().family() == "Real":
Expand Down Expand Up @@ -152,7 +154,8 @@ def inject(fine, coarse):
if len(Vc) != len(Vf):
raise ValueError("Mixed spaces have different lengths")
for in_, out in zip(fine.split(), coarse.split()):
inject(in_, out)
_, _, myinject = firedrake.dmhooks.get_transfer_operators(in_.function_space().dm)
myinject(in_, out)
return

if Vc.ufl_element().family() == "Real" or Vf.ufl_element().family() == "Real":
Expand Down
11 changes: 11 additions & 0 deletions firedrake/mg/ufl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,22 @@ def coarsen_function_space(V, self, coefficient_mapping=None):

mesh = self(V.mesh(), self)

Vf = V
V = firedrake.FunctionSpace(mesh, V.ufl_element())

from firedrake.dmhooks import get_transfer_operators, push_transfer_operators
transfer = get_transfer_operators(Vf.dm)
push_transfer_operators(V.dm, *transfer)
if len(V) > 1:
for V_, Vc_ in zip(Vf, V):
transfer = get_transfer_operators(V_.dm)
push_transfer_operators(Vc_.dm, *transfer)

for i in reversed(indices):
V = V.sub(i)
V._fine = fine
fine._coarse = V

return V


Expand Down
33 changes: 17 additions & 16 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import ufl
from itertools import chain
from contextlib import ExitStack

from firedrake import dmhooks
from firedrake import slate
Expand Down Expand Up @@ -191,18 +193,18 @@ def update_diffusivity(current_solution):
self.set_from_options(self.snes)

# Used for custom grid transfer.
self._transfer_operators = None
self._transfer_operators = ()
self._setup = False

def set_transfer_operators(self, contextmanager):
r"""Set a context manager which manages which grid transfer operators should be used.
def set_transfer_operators(self, *contextmanagers):
r"""Set context managers which manages which grid transfer operators should be used.
:arg contextmanager: an instance of :class:`~.dmhooks.transfer_operators`.
:arg contextmanagers: instances of :class:`~.dmhooks.transfer_operators`.
:raises RuntimeError: if called after calling solve.
"""
if self._setup:
raise RuntimeError("Cannot set transfer operators after solve")
self._transfer_operators = contextmanager
self._transfer_operators = tuple(contextmanagers)

def solve(self, bounds=None):
r"""Solve the variational problem.
Expand All @@ -227,17 +229,16 @@ def solve(self, bounds=None):
with lower.dat.vec_ro as lb, upper.dat.vec_ro as ub:
self.snes.setVariableBounds(lb, ub)
work = self._work
# Ensure options database has full set of options (so monitors work right)
with self.inserted_options(), dmhooks.appctx(dm, self._ctx):
with self._problem.u.dat.vec as u:
u.copy(work)
if self._transfer_operators is not None:
with self._transfer_operators:
self.snes.solve(None, work)
else:
self.snes.solve(None, work)
work.copy(u)

with self._problem.u.dat.vec as u:
u.copy(work)
with ExitStack() as stack:
# Ensure options database has full set of options (so monitors
# work right)
for ctx in chain((self.inserted_options(), dmhooks.appctx(dm, self._ctx)),
self._transfer_operators):
stack.enter_context(ctx)
self.snes.solve(None, work)
work.copy(u)
self._setup = True
solving_utils.check_snes_convergence(self.snes)

Expand Down
Loading

0 comments on commit 8c39d71

Please sign in to comment.