Add PYOP2_SPMD_STRICT environment variable for checking MPI correctness
JDBetteridge committed Aug 27, 2024
1 parent ec62654 commit 1bfbcc4
Showing 5 changed files with 187 additions and 103 deletions.
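Before the file-by-file diff, a rough sketch of the behaviour this flag enables may help. The helper below is illustrative only (the function name, arguments and sentinel handling are invented; it mirrors rather than reproduces the logic added to pyop2/caching.py): every rank reports whether it hit its local cache, and if any rank missed, all ranks re-evaluate, so a collective function body is never executed by only a subset of ranks.

from mpi4py import MPI

def spmd_strict_lookup(comm, local_cache, key, compute):
    # Each rank consults its own cache...
    hit = key in local_cache
    # ...then the hit/miss flags are gathered so that ranks cannot diverge.
    if not min(comm.allgather(hit)):
        # Any rank missing forces re-evaluation on every rank, keeping a
        # (potentially collective) `compute` call SPMD-safe.
        local_cache[key] = compute()
    return local_cache[key]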
203 changes: 117 additions & 86 deletions pyop2/caching.py
@@ -150,6 +150,8 @@ def make_obj():


def cache_filter(comm=None, comm_name=None, alive=True, function=None, cache_type=None):
""" Filter PyOP2 caches based on communicator, function or cache type.
"""
caches = _KNOWN_CACHES
if comm is not None:
with temp_internal_comm(comm) as icomm:
@@ -175,6 +177,8 @@ def cache_filter(comm=None, comm_name=None, alive=True, function=None, cache_typ


class _CacheRecord:
""" Object for keeping a record of Pyop2 Cache statistics.
"""
def __init__(self, cidx, comm, func, cache):
self.cidx = cidx
self.comm = comm
@@ -220,6 +224,8 @@ def finalize(self, cache):


def print_cache_stats(*args, **kwargs):
""" Print out the cache hit/miss/size/maxsize stats for PyOP2 caches.
"""
data = defaultdict(lambda: defaultdict(list))
for entry in cache_filter(*args, **kwargs):
active = (entry.comm != MPI.COMM_NULL)
@@ -278,6 +284,8 @@ def _as_hexdigest(*args):


class DictLikeDiskAccess(MutableMapping):
""" A Dictionary like interface for storing and retrieving objects from a disk cache.
"""
def __init__(self, cachedir, extension=".pickle"):
"""
Expand Down Expand Up @@ -344,6 +352,8 @@ def write(self, filehandle, value):


def default_comm_fetcher(*args, **kwargs):
""" A sensible default comm fetcher for use with `parallel_cache`.
"""
comms = filter(
lambda arg: isinstance(arg, MPI.Comm),
args + tuple(kwargs.values())
@@ -356,8 +366,10 @@ def default_parallel_hashkey(*args, **kwargs):


def default_parallel_hashkey(*args, **kwargs):
""" We now want to actively remove any comms from args and kwargs to get the same disk cache key
""" A sensible default hash key for use with `parallel_cache`.
"""
# We now want to actively remove any comms from args and kwargs to get
# the same disk cache key.
hash_args = tuple(filter(
lambda arg: not isinstance(arg, MPI.Comm),
args
@@ -370,6 +382,8 @@ def instrument(cls):


def instrument(cls):
""" Class decorator for dict-like objects for counting cache hits/misses.
"""
@wraps(cls, updated=())
class _wrapper(cls):
instrument__ = True
@@ -410,106 +424,123 @@ class DEFAULT_CACHE(dict):
DictLikeDiskAccess = instrument(DictLikeDiskAccess)


# JBTODO: This functionality should only be enabled with a PYOP2_SPMD_STRICT
# environment variable.
def parallel_cache(
hashkey=default_parallel_hashkey,
comm_fetcher=default_comm_fetcher,
cache_factory=lambda: DEFAULT_CACHE(),
broadcast=True
):
"""Memory only cache decorator.
Decorator for wrapping a function to be called over a communicator in a
cache that stores broadcastable values in memory. If the value is found in
the cache of rank 0 it is broadcast to all other ranks.
:arg key: Callable returning the cache key for the function inputs. This
function must return a 2-tuple where the first entry is the
communicator to be collective over and the second is the key. This is
required to ensure that deadlocks do not occur when using different
subcommunicators.
"""
def decorator(func):
@PETSc.Log.EventDecorator("PyOP2 Cache Wrapper")
@wraps(func)
def wrapper(*args, **kwargs):
""" Extract the key and then try the memory cache before falling back
on calling the function and populating the cache.
"""
k = hashkey(*args, **kwargs)
key = _as_hexdigest(*k), func.__qualname__
# Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits
with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm:
# Fetch the per-comm cache_collection or set it up if not present
# A collection is required since different types of cache can be set up on the same comm
cache_collection = comm.Get_attr(comm_cache_keyval)
if cache_collection is None:
cache_collection = {}
comm.Set_attr(comm_cache_keyval, cache_collection)
# If this kind of cache is already present on the
# cache_collection, get it, otherwise create it
local_cache = cache_collection.setdefault(
(cf := cache_factory()).__class__.__name__,
cf
)
local_cache = cache_collection[cf.__class__.__name__]

# If this is a new cache or function add it to the list of known caches
if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]:
# When a comm is freed we do not hold a reference to the cache.
# We attach a finalizer that extracts the stats before the cache
# is deleted.
_KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache))

# JBTODO: Replace everything below here with:
# value = local_cache.get(key, CACHE_MISS)
# and add an optional PYOP2_SPMD_STRICT environment variable

if broadcast:
# Grab value from rank 0 memory cache and broadcast result
if comm.rank == 0:
value = local_cache.get(key, CACHE_MISS)
if value is CACHE_MISS:
debug(
f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: "
f"{k} {local_cache.__class__.__name__} cache miss"
)
else:
debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit')
# JBTODO: Add communication tags to avoid cross-broadcasting
comm.bcast(value, root=0)
else:
value = comm.bcast(CACHE_MISS, root=0)
if isinstance(value, _CacheMiss):
# We might have the CACHE_MISS from rank 0 and
# `(value is CACHE_MISS) == False` which is confusing,
# so we set it back to the local value
value = CACHE_MISS
else:
if configuration["spmd_strict"]:
def parallel_cache(
hashkey=default_parallel_hashkey,
comm_fetcher=default_comm_fetcher,
cache_factory=lambda: DEFAULT_CACHE(),
):
"""Parallel cache decorator (SPMD strict-enabled).
"""
def decorator(func):
@PETSc.Log.EventDecorator("PyOP2 Cache Wrapper")
@wraps(func)
def wrapper(*args, **kwargs):
""" Extract the key and then try the memory cache before falling back
on calling the function and populating the cache. SPMD strict ensures
that all ranks cache hit or miss to ensure that the function evaluation
always occurs in parallel.
"""
k = hashkey(*args, **kwargs)
key = _as_hexdigest(*k), func.__qualname__
# Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits
with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm:
# Fetch the per-comm cache_collection or set it up if not present
# A collection is required since different types of cache can be set up on the same comm
cache_collection = comm.Get_attr(comm_cache_keyval)
if cache_collection is None:
cache_collection = {}
comm.Set_attr(comm_cache_keyval, cache_collection)
# If this kind of cache is already present on the
# cache_collection, get it, otherwise create it
local_cache = cache_collection.setdefault(
(cf := cache_factory()).__class__.__name__,
cf
)
local_cache = cache_collection[cf.__class__.__name__]

# If this is a new cache or function add it to the list of known caches
if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]:
# When a comm is freed we do not hold a reference to the cache.
# We attach a finalizer that extracts the stats before the cache
# is deleted.
_KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache))

# Grab value from all ranks cache and broadcast cache hit/miss
value = local_cache.get(key, CACHE_MISS)
debug_string = f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: "
debug_string += f"key={k} in cache: {local_cache.__class__.__name__} cache "
if value is CACHE_MISS:
debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache miss')
debug(debug_string + "miss")
cache_hit = False
else:
debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit')
debug(debug_string + "hit")
cache_hit = True
all_present = comm.allgather(cache_hit)

# If not present in the cache of all ranks we need to recompute on all ranks
# If not present in the cache of all ranks we force re-evaluation on all ranks
if not min(all_present):
value = CACHE_MISS

if value is CACHE_MISS:
value = func(*args, **kwargs)
return local_cache.setdefault(key, value)
if value is CACHE_MISS:
value = func(*args, **kwargs)
return local_cache.setdefault(key, value)

return wrapper
return decorator
else:
def parallel_cache(
hashkey=default_parallel_hashkey,
comm_fetcher=default_comm_fetcher,
cache_factory=lambda: DEFAULT_CACHE(),
):
"""Parallel cache decorator.
"""
def decorator(func):
@PETSc.Log.EventDecorator("PyOP2 Cache Wrapper")
@wraps(func)
def wrapper(*args, **kwargs):
""" Extract the key and then try the memory cache before falling back
on calling the function and populating the cache.
"""
k = hashkey(*args, **kwargs)
key = _as_hexdigest(*k), func.__qualname__
# Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits
with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm:
# Fetch the per-comm cache_collection or set it up if not present
# A collection is required since different types of cache can be set up on the same comm
cache_collection = comm.Get_attr(comm_cache_keyval)
if cache_collection is None:
cache_collection = {}
comm.Set_attr(comm_cache_keyval, cache_collection)
# If this kind of cache is already present on the
# cache_collection, get it, otherwise create it
local_cache = cache_collection.setdefault(
(cf := cache_factory()).__class__.__name__,
cf
)
local_cache = cache_collection[cf.__class__.__name__]

# If this is a new cache or function add it to the list of known caches
if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]:
# When a comm is freed we do not hold a reference to the cache.
# We attach a finalizer that extracts the stats before the cache
# is deleted.
_KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache))

return wrapper
return decorator
value = local_cache.get(key, CACHE_MISS)

if value is CACHE_MISS:
value = func(*args, **kwargs)
return local_cache.setdefault(key, value)

return wrapper
return decorator


def clear_memory_cache(comm):
""" Completely remove all PyOP2 caches on a given communicator.
"""
with temp_internal_comm(comm) as icomm:
if icomm.Get_attr(comm_cache_keyval) is not None:
icomm.Set_attr(comm_cache_keyval, {})
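For orientation, a hypothetical use of the decorator defined above (the decorated function, its arguments and the choice of communicator are invented for illustration and are not part of this commit): default_comm_fetcher picks the communicator out of the call arguments, and default_parallel_hashkey strips it from the key, so every rank of the communicator computes the same cache key.

from mpi4py import MPI
from pyop2.caching import parallel_cache

@parallel_cache()
def expensive_setup(degree, comm=None):
    # Must be called collectively by every rank of `comm`.
    return comm.allreduce(degree ** 2)

first = expensive_setup(3, comm=MPI.COMM_WORLD)   # miss on all ranks: function runs
second = expensive_setup(3, comm=MPI.COMM_WORLD)  # hit; with PYOP2_SPMD_STRICT=1 the
                                                  # hit/miss flags are allgathered first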
2 changes: 1 addition & 1 deletion pyop2/compilation.py
@@ -425,7 +425,7 @@ def load_hashkey(*args, **kwargs):


@mpi.collective
@memory_cache(hashkey=load_hashkey, broadcast=False)
@memory_cache(hashkey=load_hashkey)
@PETSc.Log.EventDecorator()
def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(),
argtypes=None, restype=None, comm=None):
10 changes: 7 additions & 3 deletions pyop2/configuration.py
@@ -40,7 +40,6 @@
from pyop2.exceptions import ConfigurationError


# JBTODO: Add a PYOP2_SPMD_STRICT environment variable to add various SPMD checks.
class Configuration(dict):
r"""PyOP2 configuration parameters
@@ -68,13 +67,16 @@ class Configuration(dict):
to a node-local filesystem too.
:param log_level: How chatty should PyOP2 be? Valid values
are "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL".
:param print_cache_size: Should PyOP2 print the size of caches at
:param print_cache_size: Should PyOP2 print the cache information at
program exit?
:param matnest: Should matrices on mixed maps be built as nests? (Default yes)
:param block_sparsity: Should sparsity patterns on datasets with
cdim > 1 be built as block sparsities, or dof sparsities. The
former saves memory but changes which preconditioners are
available for the resulting matrices. (Default yes)
:param spmd_strict: Enable barriers for calls marked with @collective and
for cache access. This adds considerable overhead, but is useful for
tracking down deadlocks. (Default no)
"""
# name, env variable, type, default, write once
cache_dir = os.path.join(gettempdir(), "pyop2-cache-uid%s" % os.getuid())
@@ -114,7 +116,9 @@ class Configuration(dict):
"matnest":
("PYOP2_MATNEST", bool, True),
"block_sparsity":
("PYOP2_BLOCK_SPARSITY", bool, True)
("PYOP2_BLOCK_SPARSITY", bool, True),
"spmd_strict":
("PYOP2_SPMD_STRICT", bool, False),
}
"""Default values for PyOP2 configuration parameters"""

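A minimal sketch of how the new entry is consumed at run time, assuming the usual PyOP2 pattern of reading PYOP2_* environment variables when the module-level configuration object is constructed (the script and the interpretation of "1" as true are assumptions, not taken from this commit):

import os

# Must be set before pyop2.configuration is first imported; assumption: a
# value of "1" is read as True for a bool-typed option.
os.environ["PYOP2_SPMD_STRICT"] = "1"

from pyop2.configuration import configuration

if configuration["spmd_strict"]:
    print("SPMD strict checking enabled: collective calls and cache accesses are verified")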
