Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Comm reference fixes + Remove __del__ method and add weakref.finalize #712

Merged
merged 15 commits into from
Jan 17, 2024
Merged
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:
PETSC_ARCH: default
PETSC_CONFIGURE_OPTIONS: --with-debugging=1 --with-shared-libraries=1 --with-c2html=0 --with-fortran-bindings=0
RDMAV_FORK_SAFE: 1
PYOP2_CI_TESTS: 1
timeout-minutes: 60

steps:
Expand Down
7 changes: 0 additions & 7 deletions pyop2/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,6 @@ class ObjectCached(object):
details). The object on which the cache is stored should contain
a dict in its ``_cache`` attribute.

.. warning ::

This kind of cache sets up a circular reference. If either of
the objects implements ``__del__``, the Python garbage
collector will not be able to collect this cycle, and hence
the cache will never be evicted.

.. warning::

The derived class' :meth:`__init__` is still called if the
Expand Down
10 changes: 2 additions & 8 deletions pyop2/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,8 @@ def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, co
self._debug = configuration["debug"]

# Compilation communicators are reference counted on the PyOP2 comm
self.pcomm = mpi.internal_comm(comm)
self.comm = mpi.compilation_comm(self.pcomm)

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)
if hasattr(self, "pcomm"):
mpi.decref(self.pcomm)
self.pcomm = mpi.internal_comm(comm, self)
self.comm = mpi.compilation_comm(self.pcomm, self)

def __repr__(self):
return f"<{self._name} compiler, version {self.version or 'unknown'}>"
Expand Down
139 changes: 92 additions & 47 deletions pyop2/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import glob
import os
import tempfile
import weakref

from pyop2.configuration import configuration
from pyop2.exceptions import CompilationError
Expand Down Expand Up @@ -74,6 +75,8 @@
_DUPED_COMM_DICT = {}
# Flag to indicate whether we are in cleanup (at exit)
PYOP2_FINALIZED = False
# Flag for outputting information at the end of testing (do not abuse!)
_running_on_ci = bool(os.environ.get('PYOP2_CI_TESTS'))


class PyOP2CommError(ValueError):
Expand Down Expand Up @@ -175,28 +178,46 @@ def delcomm_outer(comm, keyval, icomm):
:arg icomm: The inner communicator, should have a reference to
``comm``.
"""
# This will raise errors at cleanup time as some objects are already
# deleted, so we just skip
if not PYOP2_FINALIZED:
if keyval not in (innercomm_keyval, compilationcomm_keyval):
raise PyOP2CommError("Unexpected keyval")
ocomm = icomm.Get_attr(outercomm_keyval)
if ocomm is None:
raise PyOP2CommError("Inner comm does not have expected reference to outer comm")

if ocomm != comm:
raise PyOP2CommError("Inner comm has reference to non-matching outer comm")
icomm.Delete_attr(outercomm_keyval)

# Once we have removed the reference to the inner/compilation comm we can free it
cidx = icomm.Get_attr(cidx_keyval)
cidx = cidx[0]
del _DUPED_COMM_DICT[cidx]
gc.collect()
refcount = icomm.Get_attr(refcount_keyval)
if refcount[0] > 1:
raise PyOP2CommError("References to comm still held, this will cause deadlock")
icomm.Free()
# Use debug printer that is safe to use at exit time
debug = finalize_safe_debug()
if keyval not in (innercomm_keyval, compilationcomm_keyval):
raise PyOP2CommError("Unexpected keyval")

if keyval == innercomm_keyval:
debug(f'Deleting innercomm keyval on {comm.name}')
if keyval == compilationcomm_keyval:
debug(f'Deleting compilationcomm keyval on {comm.name}')

ocomm = icomm.Get_attr(outercomm_keyval)
if ocomm is None:
raise PyOP2CommError("Inner comm does not have expected reference to outer comm")

if ocomm != comm:
raise PyOP2CommError("Inner comm has reference to non-matching outer comm")
icomm.Delete_attr(outercomm_keyval)

# An inner comm may or may not hold a reference to a compilation comm
comp_comm = icomm.Get_attr(compilationcomm_keyval)
if comp_comm is not None:
debug('Removing compilation comm on inner comm')
decref(comp_comm)
icomm.Delete_attr(compilationcomm_keyval)

# Once we have removed the reference to the inner/compilation comm we can free it
cidx = icomm.Get_attr(cidx_keyval)
cidx = cidx[0]
del _DUPED_COMM_DICT[cidx]
gc.collect()
refcount = icomm.Get_attr(refcount_keyval)
if refcount[0] > 1:
# In the case where `comm` is a custom user communicator there may be references
# to the inner comm still held and this is not an issue, but there is not an
# easy way to distinguish this case, so we just log the event.
debug(
f"There are still {refcount[0]} references to {comm.name}, "
"this will cause deadlock if the communicator has been incorrectly freed"
)
icomm.Free()


# Reference count, creation index, inner/outer/compilation communicator
Expand All @@ -215,14 +236,10 @@ def is_pyop2_comm(comm):

:arg comm: Communicator to query
"""
global PYOP2_FINALIZED
if isinstance(comm, PETSc.Comm):
ispyop2comm = False
elif comm == MPI.COMM_NULL:
if not PYOP2_FINALIZED:
raise PyOP2CommError("Communicator passed to is_pyop2_comm() is COMM_NULL")
else:
ispyop2comm = True
raise PyOP2CommError("Communicator passed to is_pyop2_comm() is COMM_NULL")
elif isinstance(comm, MPI.Comm):
ispyop2comm = bool(comm.Get_attr(refcount_keyval))
else:
Expand All @@ -231,7 +248,8 @@ def is_pyop2_comm(comm):


def pyop2_comm_status():
""" Prints the reference counts for all comms PyOP2 has duplicated
""" Return string containing a table of the reference counts for all
communicators PyOP2 has duplicated.
"""
status_string = 'PYOP2 Communicator reference counts:\n'
status_string += '| Communicator name | Count |\n'
Expand All @@ -255,10 +273,7 @@ class temp_internal_comm:
"""
def __init__(self, comm):
self.user_comm = comm
self.internal_comm = internal_comm(self.user_comm)

def __del__(self):
decref(self.internal_comm)
self.internal_comm = internal_comm(self.user_comm, self)

def __enter__(self):
""" Returns an internal comm that will be safely decref'd
Expand All @@ -272,10 +287,12 @@ def __exit__(self, exc_type, exc_value, traceback):
pass


def internal_comm(comm):
def internal_comm(comm, obj):
""" Creates an internal comm from the user comm.
If comm is None, create an internal communicator from COMM_WORLD
:arg comm: A communicator or None
:arg obj: The object which the comm is an attribute of
(usually `self`)

:returns pyop2_comm: A PyOP2 internal communicator
"""
Expand All @@ -298,6 +315,7 @@ def internal_comm(comm):
pyop2_comm = comm
else:
pyop2_comm = dup_comm(comm)
weakref.finalize(obj, decref, pyop2_comm)
return pyop2_comm


Expand All @@ -312,19 +330,18 @@ def incref(comm):
def decref(comm):
""" Decrement communicator reference count
"""
if not PYOP2_FINALIZED:
if comm == MPI.COMM_NULL:
# This case occurs if the outer communicator has already been freed by
# the user
debug("Cannot decref an already freed communicator")
else:
assert is_pyop2_comm(comm)
refcount = comm.Get_attr(refcount_keyval)
refcount[0] -= 1
if refcount[0] == 1:
# Freeing the comm is handled by the destruction of the user comm
pass
elif refcount[0] < 1:
# Freeing the internal comm is handled by the destruction of the user comm
if refcount[0] < 1:
raise PyOP2CommError("Reference count is less than 1, decref called too many times")

elif comm != MPI.COMM_NULL:
comm.Free()


def dup_comm(comm_in):
"""Given a communicator return a communicator for internal use.
Expand Down Expand Up @@ -440,10 +457,13 @@ def set_compilation_comm(comm, comp_comm):


@collective
def compilation_comm(comm):
def compilation_comm(comm, obj):
"""Get a communicator for compilation.

:arg comm: The input communicator, must be a PyOP2 comm.
:arg obj: The object which the comm is an attribute of
(usually `self`)

:returns: A communicator used for compilation (may be smaller)
"""
if not is_pyop2_comm(comm):
Expand All @@ -465,29 +485,54 @@ def compilation_comm(comm):
else:
comp_comm = comm
incref(comp_comm)
weakref.finalize(obj, decref, comp_comm)
return comp_comm


def finalize_safe_debug():
    """Return the callable to use for debug output.

    While Python is finalizing, the logging module may already have been
    torn down before we have finished writing debug information. Once
    ``PYOP2_FINALIZED`` is set, the module-level ``debug`` name is
    therefore rebound to a fallback built on the Python ``print``
    function, or to a no-op when debug output is not wanted.

    Finalization information is always printed when running the CI
    tests, regardless of the logger level.

    :returns: A callable taking a single string argument.
    """
    # The rebinding below intentionally replaces the module-level
    # ``debug`` so later callers in this module pick up the same printer.
    global debug
    if not PYOP2_FINALIZED:
        # Not finalizing yet: hand back whatever ``debug`` currently is.
        return debug

    # Stay quiet only if the logger is above DEBUG *and* we are not on CI.
    suppress_output = logger.level > DEBUG and not _running_on_ci
    if suppress_output:
        debug = lambda string: None
    else:
        debug = lambda string: print(string)
    return debug


@atexit.register
def _free_comms():
"""Free all outstanding communicators."""
global PYOP2_FINALIZED
PYOP2_FINALIZED = True
if logger.level > DEBUG:
debug = lambda string: None
else:
debug = lambda string: print(string)
debug = finalize_safe_debug()
debug("PyOP2 Finalizing")
# Collect garbage as it may hold on to communicator references

debug("Calling gc.collect()")
gc.collect()
debug("STATE0")
debug(pyop2_comm_status())

debug("Freeing PYOP2_COMM_WORLD")
COMM_WORLD.Free()
debug("STATE1")
debug(pyop2_comm_status())

debug("Freeing PYOP2_COMM_SELF")
COMM_SELF.Free()
debug("STATE2")
debug(pyop2_comm_status())
debug(f"Freeing comms in list (length {len(_DUPED_COMM_DICT)})")
for key in sorted(_DUPED_COMM_DICT.keys()):
for key in sorted(_DUPED_COMM_DICT.keys(), reverse=True):
comm = _DUPED_COMM_DICT[key]
if comm != MPI.COMM_NULL:
refcount = comm.Get_attr(refcount_keyval)
Expand Down
6 changes: 1 addition & 5 deletions pyop2/parloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,9 @@ def __init__(self, global_knl, iterset, arguments):

self.global_kernel = global_knl
self.iterset = iterset
self.comm = mpi.internal_comm(iterset.comm)
self.comm = mpi.internal_comm(iterset.comm, self)
self.arguments, self.reduced_globals = self.prepare_reduced_globals(arguments, global_knl)

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)

@property
def local_kernel(self):
return self.global_kernel.local_kernel
Expand Down
8 changes: 2 additions & 6 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,13 @@ def __init__(self, dataset, data=None, dtype=None, name=None):
EmptyDataMixin.__init__(self, data, dtype, self._shape)

self._dataset = dataset
self.comm = mpi.internal_comm(dataset.comm)
self.comm = mpi.internal_comm(dataset.comm, self)
self.halo_valid = True
self._name = name or "dat_#x%x" % id(self)

self._halo_frozen = False
self._frozen_access_mode = None

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)

@utils.cached_property
def _kernel_args_(self):
return (self._data.ctypes.data, )
Expand Down Expand Up @@ -823,7 +819,7 @@ def what(x):
if not all(d.dtype == self._dats[0].dtype for d in self._dats):
raise ex.DataValueError('MixedDat with different dtypes is not supported')
# TODO: Think about different communicators on dats (c.f. MixedSet)
self.comm = mpi.internal_comm(self._dats[0].comm)
self.comm = mpi.internal_comm(self._dats[0].comm, self)

@property
def dat_version(self):
Expand Down
12 changes: 3 additions & 9 deletions pyop2/types/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,13 @@ def __init__(self, iter_set, dim=1, name=None):
return
if isinstance(iter_set, Subset):
raise NotImplementedError("Deriving a DataSet from a Subset is unsupported")
self.comm = mpi.internal_comm(iter_set.comm)
self.comm = mpi.internal_comm(iter_set.comm, self)
self._set = iter_set
self._dim = utils.as_tuple(dim, numbers.Integral)
self._cdim = np.prod(self._dim).item()
self._name = name or "dset_#x%x" % id(self)
self._initialized = True

def __del__(self):
# Cannot use hasattr here, since we define `__getattr__`
# This causes infinite recursion when looked up!
if "comm" in self.__dict__:
mpi.decref(self.comm)

@classmethod
def _process_args(cls, *args, **kwargs):
return (args[0], ) + args, kwargs
Expand Down Expand Up @@ -211,7 +205,7 @@ def __init__(self, global_):
if self._initialized:
return
self._global = global_
self.comm = mpi.internal_comm(global_.comm)
self.comm = mpi.internal_comm(global_.comm, self)
self._globalset = GlobalSet(comm=self.comm)
self._name = "gdset_#x%x" % id(self)
self._initialized = True
Expand Down Expand Up @@ -360,7 +354,7 @@ def __init__(self, arg, dims=None):
comm = self._process_args(arg, dims)[0][0].comm
except AttributeError:
comm = None
self.comm = mpi.internal_comm(comm)
self.comm = mpi.internal_comm(comm, self)
self._initialized = True

@classmethod
Expand Down
10 changes: 1 addition & 9 deletions pyop2/types/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ def __init__(self, dim, data=None, dtype=None, name=None):
self._buf = np.empty(self.shape, dtype=self.dtype)
self._name = name or "%s_#x%x" % (self.__class__.__name__.lower(), id(self))

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)

@utils.cached_property
def _kernel_args_(self):
return (self._data.ctypes.data, )
Expand Down Expand Up @@ -247,16 +243,12 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None):
super().__init__(dim, data, dtype, name)
if comm is None:
warnings.warn("PyOP2.Global has no comm, this is likely to break in parallel!")
self.comm = mpi.internal_comm(comm)
self.comm = mpi.internal_comm(comm, self)

# Object versioning setup
petsc_counter = (comm and self.dtype == PETSc.ScalarType)
VecAccessMixin.__init__(self, petsc_counter=petsc_counter)

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)

def __str__(self):
return "OP2 Global Argument: %s with dim %s and value %s" \
% (self._name, self._dim, self._data)
Expand Down
Loading
Loading