From 1bfbcc46474eba34adfd63946f10e78bdb8dcd08 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 27 Aug 2024 14:59:18 +0100 Subject: [PATCH] Add PYOP2_SPMD_STRICT environment variable for checking MPI correctness --- pyop2/caching.py | 203 +++++++++++++++++------------- pyop2/compilation.py | 2 +- pyop2/configuration.py | 10 +- pyop2/mpi.py | 68 ++++++++-- test/unit/test_updated_caching.py | 7 +- 5 files changed, 187 insertions(+), 103 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 6e0d47a36..96c64de75 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -150,6 +150,8 @@ def make_obj(): def cache_filter(comm=None, comm_name=None, alive=True, function=None, cache_type=None): + """ Filter PyOP2 caches based on communicator, function or cache type. + """ caches = _KNOWN_CACHES if comm is not None: with temp_internal_comm(comm) as icomm: @@ -175,6 +177,8 @@ def cache_filter(comm=None, comm_name=None, alive=True, function=None, cache_typ class _CacheRecord: + """ Object for keeping a record of Pyop2 Cache statistics. + """ def __init__(self, cidx, comm, func, cache): self.cidx = cidx self.comm = comm @@ -220,6 +224,8 @@ def finalize(self, cache): def print_cache_stats(*args, **kwargs): + """ Print out the cache hit/miss/size/maxsize stats for PyOP2 caches. + """ data = defaultdict(lambda: defaultdict(list)) for entry in cache_filter(*args, **kwargs): active = (entry.comm != MPI.COMM_NULL) @@ -278,6 +284,8 @@ def _as_hexdigest(*args): class DictLikeDiskAccess(MutableMapping): + """ A Dictionary like interface for storing and retrieving objects from a disk cache. + """ def __init__(self, cachedir, extension=".pickle"): """ @@ -344,6 +352,8 @@ def write(self, filehandle, value): def default_comm_fetcher(*args, **kwargs): + """ A sensible default comm fetcher for use with `parallel_cache`. + """ comms = filter( lambda arg: isinstance(arg, MPI.Comm), args + tuple(kwargs.values()) @@ -356,8 +366,10 @@ def default_comm_fetcher(*args, **kwargs): def default_parallel_hashkey(*args, **kwargs): - """ We now want to actively remove any comms from args and kwargs to get the same disk cache key + """ A sensible default hash key for use with `parallel_cache`. """ + # We now want to actively remove any comms from args and kwargs to get + # the same disk cache key. hash_args = tuple(filter( lambda arg: not isinstance(arg, MPI.Comm), args @@ -370,6 +382,8 @@ def default_parallel_hashkey(*args, **kwargs): def instrument(cls): + """ Class decorator for dict-like objects for counting cache hits/misses. + """ @wraps(cls, updated=()) class _wrapper(cls): instrument__ = True @@ -410,106 +424,123 @@ class DEFAULT_CACHE(dict): DictLikeDiskAccess = instrument(DictLikeDiskAccess) -# JBTODO: This functionality should only be enabled with a PYOP2_SPMD_STRICT -# environment variable. -def parallel_cache( - hashkey=default_parallel_hashkey, - comm_fetcher=default_comm_fetcher, - cache_factory=lambda: DEFAULT_CACHE(), - broadcast=True -): - """Memory only cache decorator. - - Decorator for wrapping a function to be called over a communicator in a - cache that stores broadcastable values in memory. If the value is found in - the cache of rank 0 it is broadcast to all other ranks. - - :arg key: Callable returning the cache key for the function inputs. This - function must return a 2-tuple where the first entry is the - communicator to be collective over and the second is the key. This is - required to ensure that deadlocks do not occur when using different - subcommunicators. 
- """ - def decorator(func): - @PETSc.Log.EventDecorator("PyOP2 Cache Wrapper") - @wraps(func) - def wrapper(*args, **kwargs): - """ Extract the key and then try the memory cache before falling back - on calling the function and populating the cache. - """ - k = hashkey(*args, **kwargs) - key = _as_hexdigest(*k), func.__qualname__ - # Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits - with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm: - # Fetch the per-comm cache_collection or set it up if not present - # A collection is required since different types of cache can be set up on the same comm - cache_collection = comm.Get_attr(comm_cache_keyval) - if cache_collection is None: - cache_collection = {} - comm.Set_attr(comm_cache_keyval, cache_collection) - # If this kind of cache is already present on the - # cache_collection, get it, otherwise create it - local_cache = cache_collection.setdefault( - (cf := cache_factory()).__class__.__name__, - cf - ) - local_cache = cache_collection[cf.__class__.__name__] - - # If this is a new cache or function add it to the list of known caches - if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]: - # When a comm is freed we do not hold a reference to the cache. - # We attach a finalizer that extracts the stats before the cache - # is deleted. - _KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache)) - - # JBTODO: Replace everything below here with: - # value = local_cache.get(key, CACHE_MISS) - # and add an optional PYOP2_SPMD_STRICT environment variable - - if broadcast: - # Grab value from rank 0 memory cache and broadcast result - if comm.rank == 0: - value = local_cache.get(key, CACHE_MISS) - if value is CACHE_MISS: - debug( - f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " - f"{k} {local_cache.__class__.__name__} cache miss" - ) - else: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit') - # JBTODO: Add communication tags to avoid cross-broadcasting - comm.bcast(value, root=0) - else: - value = comm.bcast(CACHE_MISS, root=0) - if isinstance(value, _CacheMiss): - # We might have the CACHE_MISS from rank 0 and - # `(value is CACHE_MISS) == False` which is confusing, - # so we set it back to the local value - value = CACHE_MISS - else: +if configuration["spmd_strict"]: + def parallel_cache( + hashkey=default_parallel_hashkey, + comm_fetcher=default_comm_fetcher, + cache_factory=lambda: DEFAULT_CACHE(), + ): + """Parallel cache decorator (SPMD strict-enabled). + """ + def decorator(func): + @PETSc.Log.EventDecorator("PyOP2 Cache Wrapper") + @wraps(func) + def wrapper(*args, **kwargs): + """ Extract the key and then try the memory cache before falling back + on calling the function and populating the cache. SPMD strict ensures + that all ranks cache hit or miss to ensure that the function evaluation + always occurs in parallel. 
+ """ + k = hashkey(*args, **kwargs) + key = _as_hexdigest(*k), func.__qualname__ + # Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits + with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm: + # Fetch the per-comm cache_collection or set it up if not present + # A collection is required since different types of cache can be set up on the same comm + cache_collection = comm.Get_attr(comm_cache_keyval) + if cache_collection is None: + cache_collection = {} + comm.Set_attr(comm_cache_keyval, cache_collection) + # If this kind of cache is already present on the + # cache_collection, get it, otherwise create it + local_cache = cache_collection.setdefault( + (cf := cache_factory()).__class__.__name__, + cf + ) + local_cache = cache_collection[cf.__class__.__name__] + + # If this is a new cache or function add it to the list of known caches + if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]: + # When a comm is freed we do not hold a reference to the cache. + # We attach a finalizer that extracts the stats before the cache + # is deleted. + _KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache)) + # Grab value from all ranks cache and broadcast cache hit/miss value = local_cache.get(key, CACHE_MISS) + debug_string = f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " + debug_string += f"key={k} in cache: {local_cache.__class__.__name__} cache " if value is CACHE_MISS: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache miss') + debug(debug_string + "miss") cache_hit = False else: - debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit') + debug(debug_string + "hit") cache_hit = True all_present = comm.allgather(cache_hit) - # If not present in the cache of all ranks we need to recompute on all ranks + # If not present in the cache of all ranks we force re-evaluation on all ranks if not min(all_present): value = CACHE_MISS - if value is CACHE_MISS: - value = func(*args, **kwargs) - return local_cache.setdefault(key, value) + if value is CACHE_MISS: + value = func(*args, **kwargs) + return local_cache.setdefault(key, value) + + return wrapper + return decorator +else: + def parallel_cache( + hashkey=default_parallel_hashkey, + comm_fetcher=default_comm_fetcher, + cache_factory=lambda: DEFAULT_CACHE(), + ): + """Parallel cache decorator. + """ + def decorator(func): + @PETSc.Log.EventDecorator("PyOP2 Cache Wrapper") + @wraps(func) + def wrapper(*args, **kwargs): + """ Extract the key and then try the memory cache before falling back + on calling the function and populating the cache. 
+ """ + k = hashkey(*args, **kwargs) + key = _as_hexdigest(*k), func.__qualname__ + # Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits + with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm: + # Fetch the per-comm cache_collection or set it up if not present + # A collection is required since different types of cache can be set up on the same comm + cache_collection = comm.Get_attr(comm_cache_keyval) + if cache_collection is None: + cache_collection = {} + comm.Set_attr(comm_cache_keyval, cache_collection) + # If this kind of cache is already present on the + # cache_collection, get it, otherwise create it + local_cache = cache_collection.setdefault( + (cf := cache_factory()).__class__.__name__, + cf + ) + local_cache = cache_collection[cf.__class__.__name__] + + # If this is a new cache or function add it to the list of known caches + if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]: + # When a comm is freed we do not hold a reference to the cache. + # We attach a finalizer that extracts the stats before the cache + # is deleted. + _KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache)) - return wrapper - return decorator + value = local_cache.get(key, CACHE_MISS) + + if value is CACHE_MISS: + value = func(*args, **kwargs) + return local_cache.setdefault(key, value) + + return wrapper + return decorator def clear_memory_cache(comm): + """ Completely remove all PyOP2 caches on a given communicator. + """ with temp_internal_comm(comm) as icomm: if icomm.Get_attr(comm_cache_keyval) is not None: icomm.Set_attr(comm_cache_keyval, {}) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index c7a278feb..86db95b9e 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -425,7 +425,7 @@ def load_hashkey(*args, **kwargs): @mpi.collective -@memory_cache(hashkey=load_hashkey, broadcast=False) +@memory_cache(hashkey=load_hashkey) @PETSc.Log.EventDecorator() def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), argtypes=None, restype=None, comm=None): diff --git a/pyop2/configuration.py b/pyop2/configuration.py index dc4db1679..0005ceeca 100644 --- a/pyop2/configuration.py +++ b/pyop2/configuration.py @@ -40,7 +40,6 @@ from pyop2.exceptions import ConfigurationError -# JBTODO: Add a PYOP2_SPMD_STRICT environment variable to add various SPMD checks. class Configuration(dict): r"""PyOP2 configuration parameters @@ -68,13 +67,16 @@ class Configuration(dict): to a node-local filesystem too. :param log_level: How chatty should PyOP2 be? Valid values are "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL". - :param print_cache_size: Should PyOP2 print the size of caches at + :param print_cache_size: Should PyOP2 print the cache information at program exit? :param matnest: Should matrices on mixed maps be built as nests? (Default yes) :param block_sparsity: Should sparsity patterns on datasets with cdim > 1 be built as block sparsities, or dof sparsities. The former saves memory but changes which preconditioners are available for the resulting matrices. (Default yes) + :param spmd_strict: Enable barriers for calls marked with @collective and + for cache access. This adds considerable overhead, but is useful for + tracking down deadlocks. 
(Default no) """ # name, env variable, type, default, write once cache_dir = os.path.join(gettempdir(), "pyop2-cache-uid%s" % os.getuid()) @@ -114,7 +116,9 @@ class Configuration(dict): "matnest": ("PYOP2_MATNEST", bool, True), "block_sparsity": - ("PYOP2_BLOCK_SPARSITY", bool, True) + ("PYOP2_BLOCK_SPARSITY", bool, True), + "spmd_strict": + ("PYOP2_SPMD_STRICT", bool, False), } """Default values for PyOP2 configuration parameters""" diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 2831bc04f..7e88b8dd0 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -37,6 +37,7 @@ from petsc4py import PETSc from mpi4py import MPI # noqa from itertools import count +from functools import wraps import atexit import gc import glob @@ -160,15 +161,64 @@ class PyOP2CommError(ValueError): # PYOP2_FINALISED flag. -# JBTODO: Make this decorator infinitely more useful by adding barriers before -# and after the function call, if being run with PYOP2_SPMD_STRICT=1. -def collective(fn): - extra = trim(""" - This function is logically collective over MPI ranks, it is an - error to call it on fewer than all the ranks in MPI communicator. - """) - fn.__doc__ = "%s\n\n%s" % (trim(fn.__doc__), extra) if fn.__doc__ else extra - return fn +if configuration["spmd_strict"]: + def collective(fn): + extra = trim(""" + This function is logically collective over MPI ranks, it is an + error to call it on fewer than all the ranks in MPI communicator. + PYOP2_SPMD_STRICT=1 is in your environment and function calls will be + guarded by a barrier where possible. + """) + + @wraps(fn) + def wrapper(*args, **kwargs): + comms = filter( + lambda arg: isinstance(arg, MPI.Comm), + args + tuple(kwargs.values()) + ) + try: + comm = next(comms) + except StopIteration: + if args and hasattr(args[0], "comm"): + comm = args[0].comm + else: + comm = None + + if comm is None: + debug( + "`@collective` wrapper found no communicators in args or kwargs, " + "this means that the call is implicitly collective over an " + "unknown communicator. " + f"The following call to {fn.__module__}.{fn.__qualname__} is " + "not protected by an MPI barrier." + ) + subcomm = ", UNKNOWN Comm" + else: + subcomm = f", {comm.name} R{comm.rank}" + + debug_string_pt1 = f"{COMM_WORLD.name} R{COMM_WORLD.rank}{subcomm}: " + debug_string_pt2 = f" {fn.__module__}.{fn.__qualname__}" + debug(debug_string_pt1 + "Entering" + debug_string_pt2) + if comm is not None: + comm.Barrier() + value = fn(*args, **kwargs) + debug(debug_string_pt1 + "Leaving" + debug_string_pt2) + if comm is not None: + comm.Barrier() + return value + + wrapper.__doc__ = f"{trim(fn.__doc__)}\n\n{extra}" if fn.__doc__ else extra + return wrapper +else: + def collective(fn): + extra = trim(""" + This function is logically collective over MPI ranks, it is an + error to call it on fewer than all the ranks in MPI communicator. + You can set PYOP2_SPMD_STRICT=1 in your environment to try and catch + non-collective calls. 
+ """) + fn.__doc__ = f"{trim(fn.__doc__)}\n\n{extra}" if fn.__doc__ else extra + return fn def delcomm_outer(comm, keyval, icomm): diff --git a/test/unit/test_updated_caching.py b/test/unit/test_updated_caching.py index 5066554a1..1d9424b05 100644 --- a/test/unit/test_updated_caching.py +++ b/test/unit/test_updated_caching.py @@ -2,7 +2,6 @@ import pytest import os import tempfile -from functools import partial from itertools import chain from textwrap import dedent @@ -63,7 +62,7 @@ def state(): @pytest.mark.parametrize("decorator, uncached_function", [ (memory_cache, twople), - (partial(memory_cache, broadcast=False), n_comms), + (memory_cache, n_comms), (memory_and_disk_cache, twople), (disk_only_cache, twople) ]) @@ -89,7 +88,7 @@ def test_function_args_twice_caches(request, state, decorator, uncached_function @pytest.mark.parametrize("decorator, uncached_function", [ (memory_cache, twople), - (partial(memory_cache, broadcast=False), n_comms), + (memory_cache, n_comms), (memory_and_disk_cache, twople), (disk_only_cache, twople) ]) @@ -114,7 +113,7 @@ def test_function_args_different(request, state, decorator, uncached_function, t @pytest.mark.parallel(nprocs=3) @pytest.mark.parametrize("decorator, uncached_function", [ (memory_cache, twople), - (partial(memory_cache, broadcast=False), n_comms), + (memory_cache, n_comms), (memory_and_disk_cache, twople), (disk_only_cache, twople) ])
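
Note (not part of the patch): a minimal usage sketch of the new option, intended to be run on two or more ranks with PYOP2_SPMD_STRICT=1 set in the environment. The import path for memory_cache and the exact decorator call signature are assumptions based on the code touched above, not something this patch defines.

    # demo_spmd_strict.py -- hypothetical example, run with e.g.:
    #   PYOP2_SPMD_STRICT=1 mpiexec -n 2 python demo_spmd_strict.py
    # With spmd_strict enabled, functions marked @mpi.collective are guarded by
    # barriers where possible, and the parallel caches allgather the hit/miss
    # flag so that every rank either hits or re-evaluates the function together.
    from mpi4py import MPI
    from pyop2.caching import memory_cache  # assumed import path


    @memory_cache()  # uses default_parallel_hashkey / default_comm_fetcher
    def expensive(comm, x):
        # Must be called collectively over `comm`; the comm is stripped from
        # the hash key, so all ranks compute the same cache key for the same x.
        return x * x


    print(expensive(MPI.COMM_WORLD, 3))

Because the configuration is read when pyop2 is imported, PYOP2_SPMD_STRICT must be set before the import; exporting it in the shell, as in the run command above, is the simplest way to do that.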