From 4e6cba30bf590643ed244d6c0e7dbcdeb8fe2ec6 Mon Sep 17 00:00:00 2001 From: Jack Betteridge <43041811+JDBetteridge@users.noreply.github.com> Date: Wed, 9 Oct 2024 16:46:13 +0100 Subject: [PATCH] JDBetteridge/update caching (#3730) --------- Co-authored-by: David A. Ham --- firedrake/extrusion_utils.py | 6 +- firedrake/interpolation.py | 13 +- firedrake/logging.py | 3 +- firedrake/parameters.py | 5 +- firedrake/preconditioners/pmg.py | 8 +- firedrake/slate/slac/compiler.py | 41 +++-- firedrake/tsfc_interface.py | 159 +++++++----------- requirements-git.txt | 1 - tests/regression/test_ensembleparallelism.py | 2 +- tests/test_tsfc_interface.py | 128 +++++--------- tests/vertexonly/test_vertex_only_fs.py | 7 + .../test_vertex_only_mesh_generation.py | 31 +++- 12 files changed, 186 insertions(+), 218 deletions(-) diff --git a/firedrake/extrusion_utils.py b/firedrake/extrusion_utils.py index 0a69eecb41..b038d904af 100644 --- a/firedrake/extrusion_utils.py +++ b/firedrake/extrusion_utils.py @@ -5,7 +5,7 @@ import finat from pyop2 import op2 -from pyop2.caching import cached +from pyop2.caching import serial_cache from firedrake.petsc import PETSc from firedrake.utils import IntType, RealType, ScalarType from tsfc.finatinterface import create_element @@ -338,7 +338,7 @@ def make_offset_key(finat_element): return entity_dofs_key(finat_element.entity_dofs()), is_real_tensor_product_element(finat_element) -@cached({}, key=make_offset_key) +@serial_cache(hashkey=make_offset_key) def calculate_dof_offset(finat_element): """Return the offset between the neighbouring cells of a column for each DoF. @@ -366,7 +366,7 @@ def calculate_dof_offset(finat_element): return dof_offset -@cached({}, key=make_offset_key) +@serial_cache(hashkey=make_offset_key) def calculate_dof_offset_quotient(finat_element): """Return the offset quotient for each DoF within the base cell. diff --git a/firedrake/interpolation.py b/firedrake/interpolation.py index a7442d7820..f1e9fa93d0 100644 --- a/firedrake/interpolation.py +++ b/firedrake/interpolation.py @@ -12,7 +12,7 @@ from ufl.domain import as_domain, extract_unique_domain from pyop2 import op2 -from pyop2.caching import disk_cached +from pyop2.caching import memory_and_disk_cache from tsfc.finatinterface import create_element, as_fiat_cell from tsfc import compile_expression_dual_evaluation @@ -1204,14 +1204,15 @@ def _interpolator(V, tensor, expr, subset, arguments, access, bcs=None): def _compile_expression_key(comm, expr, to_element, ufl_element, domain, parameters, log): """Generate a cache key suitable for :func:`tsfc.compile_expression_dual_evaluation`.""" - # Since the caching is collective, this function must return a 2-tuple of - # the form (comm, key) where comm is the communicator the cache is collective over. - # FIXME FInAT elements are not safely hashable so we ignore them here key = hash_expr(expr), hash(ufl_element), utils.tuplify(parameters), log - return comm, key + return key -@disk_cached({}, _expr_cachedir, key=_compile_expression_key, collective=True) +@memory_and_disk_cache( + hashkey=_compile_expression_key, + cachedir=tsfc_interface._cachedir +) +@PETSc.Log.EventDecorator() def compile_expression(comm, *args, **kwargs): return compile_expression_dual_evaluation(*args, **kwargs) diff --git a/firedrake/logging.py b/firedrake/logging.py index 7b524fbe9c..1841f7a4e4 100644 --- a/firedrake/logging.py +++ b/firedrake/logging.py @@ -5,6 +5,7 @@ import tsfc.logging # noqa: F401 import pyop2.logger # noqa: F401 +from pyop2.configuration import configuration from pyop2.mpi import COMM_WORLD @@ -79,7 +80,7 @@ def set_log_handlers(handlers=None, comm=COMM_WORLD): handler = logging.StreamHandler() handler.setFormatter(logging.Formatter(fmt="%(name)s:%(levelname)s %(message)s")) - if comm is not None and comm.rank != 0: + if comm is not None and comm.rank != 0 and not configuration["spmd_strict"]: handler = logging.NullHandler() logger.addHandler(handler) diff --git a/firedrake/parameters.py b/firedrake/parameters.py index 9d379daab0..5863e76a77 100644 --- a/firedrake/parameters.py +++ b/firedrake/parameters.py @@ -37,7 +37,10 @@ def rename(self, name): def __getstate__(self): # Remove non-picklable update function slot d = self.__dict__.copy() - del d["_update_function"] + try: + del d["_update_function"] + except KeyError: + pass return d def set_update_function(self, callable): diff --git a/firedrake/preconditioners/pmg.py b/firedrake/preconditioners/pmg.py index 7673403017..426e71c030 100644 --- a/firedrake/preconditioners/pmg.py +++ b/firedrake/preconditioners/pmg.py @@ -12,7 +12,7 @@ from tsfc.finatinterface import create_element from tsfc import compile_expression_dual_evaluation from pyop2 import op2 -from pyop2.caching import cached +from pyop2.caching import serial_cache from pyop2.utils import as_tuple import firedrake @@ -589,7 +589,7 @@ def get_readonly_view(arr): return result -@cached({}, key=generate_key_evaluate_dual) +@serial_cache(hashkey=generate_key_evaluate_dual) def evaluate_dual(source, target, derivative=None): """Evaluate the action of a set of dual functionals of the target element on the (derivative of the) basis functions of the source element. @@ -627,7 +627,7 @@ def evaluate_dual(source, target, derivative=None): return get_readonly_view(numpy.dot(A, B)) -@cached({}, key=generate_key_evaluate_dual) +@serial_cache(hashkey=generate_key_evaluate_dual) def compare_element(e1, e2): """Numerically compare two :class:`FIAT.elements`. Equality is satisfied if e2.dual_basis(e1.primal_basis) == identity.""" @@ -639,7 +639,7 @@ def compare_element(e1, e2): return numpy.allclose(B, numpy.eye(B.shape[0]), rtol=1E-14, atol=1E-14) -@cached({}, key=lambda V: V.ufl_element()) +@serial_cache(hashkey=lambda V: V.ufl_element()) @PETSc.Log.EventDecorator("GetLineElements") def get_permutation_to_line_elements(V): """Find DOF permutation to factor out the EnrichedElement expansion diff --git a/firedrake/slate/slac/compiler.py b/firedrake/slate/slac/compiler.py index 2fccb2cfca..1333a04039 100644 --- a/firedrake/slate/slac/compiler.py +++ b/firedrake/slate/slac/compiler.py @@ -8,7 +8,6 @@ expressions (finite element variational forms written in UFL). """ import time -from hashlib import md5 from firedrake_citations import Citations from firedrake.tsfc_interface import SplitKernel, KernelInfo, TSFCKernel @@ -28,6 +27,7 @@ from pyop2.utils import get_petsc_dir from pyop2.mpi import COMM_WORLD from pyop2.codegen.rep2loopy import SolveCallable, INVCallable +from pyop2.caching import memory_and_disk_cache import firedrake.slate.slate as slate import numpy as np @@ -67,18 +67,30 @@ class SlateKernel(TSFCKernel): - @classmethod - def _cache_key(cls, expr, compiler_parameters): - return md5( - (expr.expression_hash + str(sorted(compiler_parameters.items()))).encode()).hexdigest(), expr.ufl_domains()[0].comm - def __init__(self, expr, compiler_parameters): - if self._initialized: - return self.split_kernel = generate_loopy_kernel(expr, compiler_parameters) - self._initialized = True +def _compile_expression_hashkey(slate_expr, compiler_parameters=None): + params = copy.deepcopy(parameters) + if compiler_parameters and "slate_compiler" in compiler_parameters.keys(): + params["slate_compiler"].update(compiler_parameters.pop("slate_compiler")) + if compiler_parameters: + params["form_compiler"].update(compiler_parameters) + return getattr(slate_expr, "expression_hash", "ERROR") + str(sorted(params.items())) + + +def _compile_expression_comm(*args, **kwargs): + # args[0] is a slate_expr + return args[0].ufl_domains()[0].comm + + +@memory_and_disk_cache( + hashkey=_compile_expression_hashkey, + comm_fetcher=_compile_expression_comm, + cachedir=tsfc_interface._cachedir +) +@PETSc.Log.EventDecorator() def compile_expression(slate_expr, compiler_parameters=None): """Takes a Slate expression `slate_expr` and returns the appropriate ``pyop2.op2.Kernel`` object representing the Slate expression. @@ -102,15 +114,8 @@ def compile_expression(slate_expr, compiler_parameters=None): if compiler_parameters: params["form_compiler"].update(compiler_parameters) - # If the expression has already been symbolically compiled, then - # simply reuse the produced kernel. - cache = slate_expr._metakernel_cache - key = str(sorted(params.items())) - try: - return cache[key] - except KeyError: - kernel = SlateKernel(slate_expr, params).split_kernel - return cache.setdefault(key, kernel) + kernel = SlateKernel(slate_expr, params).split_kernel + return kernel def get_temp_info(loopy_kernel): diff --git a/firedrake/tsfc_interface.py b/firedrake/tsfc_interface.py index a9e84a8827..a4a57ae0cb 100644 --- a/firedrake/tsfc_interface.py +++ b/firedrake/tsfc_interface.py @@ -4,28 +4,23 @@ passing to the backends. """ -import pickle - -from hashlib import md5 from os import path, environ, getuid, makedirs -import gzip -import os -import zlib import tempfile import collections +import cachetools import ufl import finat.ufl from ufl import Form, conj from .ufl_expr import TestFunction -from tsfc import compile_form as tsfc_compile_form +from tsfc import compile_form as original_tsfc_compile_form from tsfc.parameters import PARAMETERS as tsfc_default_parameters from tsfc.ufl_utils import extract_firedrake_constants from pyop2 import op2 -from pyop2.caching import Cached -from pyop2.mpi import COMM_WORLD, MPI +from pyop2.caching import memory_and_disk_cache, default_parallel_hashkey +from pyop2.mpi import COMM_WORLD from firedrake.formmanipulation import split_form from firedrake.parameters import parameters as default_parameters @@ -52,75 +47,31 @@ "events"]) -class TSFCKernel(Cached): - - _cache = {} - - _cachedir = environ.get('FIREDRAKE_TSFC_KERNEL_CACHE_DIR', - path.join(tempfile.gettempdir(), - 'firedrake-tsfc-kernel-cache-uid%d' % getuid())) - - @classmethod - def _cache_lookup(cls, key): - key, comm = key - # comm has to be part of the in memory key so that when - # compiling the same code on different subcommunicators we - # don't get deadlocks. But MPI_Comm objects are not hashable, - # so use comm.py2f() since this is an internal communicator and - # hence the C handle is stable. - commkey = comm.py2f() - assert commkey != MPI.COMM_NULL.py2f() - return cls._cache.get((key, commkey)) or cls._read_from_disk(key, comm) - - @classmethod - def _read_from_disk(cls, key, comm): - if comm.rank == 0: - cache = cls._cachedir - shard, disk_key = key[:2], key[2:] - filepath = os.path.join(cache, shard, disk_key) - val = None - if os.path.exists(filepath): - try: - with gzip.open(filepath, 'rb') as f: - val = f.read() - except zlib.error: - pass - - comm.bcast(val, root=0) - else: - val = comm.bcast(None, root=0) +_cachedir = environ.get( + 'FIREDRAKE_TSFC_KERNEL_CACHE_DIR', + path.join(tempfile.gettempdir(), f'firedrake-tsfc-kernel-cache-uid{getuid()}') +) - if val is None: - raise KeyError(f"Object with key {key} not found") - return cls._cache.setdefault((key, comm.py2f()), pickle.loads(val)) - @classmethod - def _cache_store(cls, key, val): - key, comm = key - cls._cache[(key, comm.py2f())] = val - _ensure_cachedir(comm=comm) - if comm.rank == 0: - val._key = key - shard, disk_key = key[:2], key[2:] - filepath = os.path.join(cls._cachedir, shard, disk_key) - tempfile = os.path.join(cls._cachedir, shard, "%s_p%d.tmp" % (disk_key, os.getpid())) - # No need for a barrier after this, since non root - # processes will never race on this file. - os.makedirs(os.path.join(cls._cachedir, shard), exist_ok=True) - with gzip.open(tempfile, 'wb') as f: - pickle.dump(val, f, 0) - os.rename(tempfile, filepath) - comm.barrier() - - @classmethod - def _cache_key(cls, form, name, parameters, coefficient_numbers, constant_numbers, interface, diagonal=False): - return md5((form.signature() + name - + str(sorted(parameters.items())) - + str(coefficient_numbers) - + str(constant_numbers) - + str(type(interface)) - + str(diagonal)).encode()).hexdigest(), form.ufl_domains()[0].comm +def tsfc_compile_form_hashkey(form, prefix, parameters, interface, diagonal, log): + # Drop prefix as it's only used for naming and log + return default_parallel_hashkey(form.signature(), prefix, parameters, interface, diagonal) + + +def tsfc_compile_form_comm_fetcher(*args, **kwargs): + # args[0] is a form + return args[0].ufl_domains()[0].comm + + +# Decorate the original tsfc.compile_form with a cache +tsfc_compile_form = memory_and_disk_cache( + hashkey=tsfc_compile_form_hashkey, + comm_fetcher=tsfc_compile_form_comm_fetcher, + cachedir=_cachedir +)(original_tsfc_compile_form) + +class TSFCKernel: def __init__( self, form, @@ -141,8 +92,6 @@ def __init__( :arg interface: the KernelBuilder interface for TSFC (may be None) :arg diagonal: If assembling a matrix is it diagonal? """ - if self._initialized: - return tree = tsfc_compile_form(form, prefix=name, parameters=parameters, interface=interface, diagonal=diagonal, log=PETSc.Log.isActive()) @@ -179,13 +128,33 @@ def __init__( arguments=kernel.arguments, events=events)) self.kernels = tuple(kernels) - self._initialized = True -SplitKernel = collections.namedtuple("SplitKernel", ["indices", - "kinfo"]) +SplitKernel = collections.namedtuple("SplitKernel", ["indices", "kinfo"]) +def _compile_form_hashkey(*args, **kwargs): + # form, name, parameters, split, diagonal + parameters = kwargs.pop("parameters", None) + key = cachetools.keys.hashkey( + args[0].signature(), + *args[1:], + utils.tuplify(parameters), + **kwargs + ) + kwargs.setdefault("parameters", parameters) + return key + + +def _compile_form_comm(*args, **kwargs): + return args[0].ufl_domains()[0].comm + + +@memory_and_disk_cache( + hashkey=_compile_form_hashkey, + comm_fetcher=_compile_form_comm, + cachedir=_cachedir +) @PETSc.Log.EventDecorator() def compile_form(form, name, parameters=None, split=True, interface=None, diagonal=False): """Compile a form using TSFC. @@ -222,16 +191,6 @@ def compile_form(form, name, parameters=None, split=True, interface=None, diagon parameters = default_parameters["form_compiler"].copy() parameters.update(_) - # We stash the compiled kernels on the form so we don't have to recompile - # if we assemble the same form again with the same optimisations - cache = form._cache.setdefault("firedrake_kernels", {}) - - key = (name, utils.tuplify(parameters), split, diagonal) - try: - return cache[key] - except KeyError: - pass - kernels = [] numbering = form.terminal_numbering() if split: @@ -258,15 +217,19 @@ def compile_form(form, name, parameters=None, split=True, interface=None, diagon numbering[c] for c in extract_firedrake_constants(f) ) prefix = name + "".join(map(str, (i for i in idx if i is not None))) - kinfos = TSFCKernel(f, prefix, parameters, - coefficient_numbers, - constant_numbers, - interface, diagonal).kernels - for kinfo in kinfos: + tsfc_kernel = TSFCKernel( + f, + prefix, + parameters, + coefficient_numbers, + constant_numbers, + interface, diagonal + ) + for kinfo in tsfc_kernel.kernels: kernels.append(SplitKernel(idx, kinfo)) kernels = tuple(kernels) - return cache.setdefault(key, kernels) + return kernels def _real_mangle(form): @@ -291,7 +254,7 @@ def clear_cache(comm=None): comm = comm or COMM_WORLD if comm.rank == 0: import shutil - shutil.rmtree(TSFCKernel._cachedir, ignore_errors=True) + shutil.rmtree(_cachedir, ignore_errors=True) _ensure_cachedir(comm=comm) @@ -299,7 +262,7 @@ def _ensure_cachedir(comm=None): """Ensure that the TSFC kernel cache directory exists.""" comm = comm or COMM_WORLD if comm.rank == 0: - makedirs(TSFCKernel._cachedir, exist_ok=True) + makedirs(_cachedir, exist_ok=True) def gather_integer_subdomain_ids(knls): diff --git a/requirements-git.txt b/requirements-git.txt index 8bf05aad21..b6e9e8e1dd 100644 --- a/requirements-git.txt +++ b/requirements-git.txt @@ -5,4 +5,3 @@ git+https://github.com/firedrakeproject/tsfc.git#egg=tsfc git+https://github.com/OP2/PyOP2.git#egg=pyop2 git+https://github.com/dolfin-adjoint/pyadjoint.git#egg=pyadjoint git+https://github.com/firedrakeproject/petsc.git@firedrake#egg=petsc -git+https://github.com/firedrakeproject/pytest-mpi.git@main#egg=pytest-mpi diff --git a/tests/regression/test_ensembleparallelism.py b/tests/regression/test_ensembleparallelism.py index dd7a8a7157..faa3db99dc 100644 --- a/tests/regression/test_ensembleparallelism.py +++ b/tests/regression/test_ensembleparallelism.py @@ -67,7 +67,7 @@ def ensemble(): def mesh(ensemble): if COMM_WORLD.size == 1: return - return UnitSquareMesh(10, 10, comm=ensemble.comm) + return UnitSquareMesh(10, 10, comm=ensemble.comm, distribution_parameters={"partitioner_type": "simple"}) # mixed function space diff --git a/tests/test_tsfc_interface.py b/tests/test_tsfc_interface.py index 39218e4e64..97a2525e14 100644 --- a/tests/test_tsfc_interface.py +++ b/tests/test_tsfc_interface.py @@ -1,8 +1,5 @@ import pytest from firedrake import * -import os -import subprocess -import sys import loopy @@ -48,87 +45,52 @@ def rhs2(fs): return inner(f, v) * dx + inner(g, v) * ds -@pytest.fixture -def cache_key(mass): - return tsfc_interface.TSFCKernel(mass, 'mass', parameters["form_compiler"], (), (), None).cache_key +def test_tsfc_same_form(mass): + """Compiling the same form twice should load kernels from cache.""" + k1 = tsfc_interface.compile_form(mass, 'mass') + k2 = tsfc_interface.compile_form(mass, 'mass') + assert k1 is k2 + assert all(k1_[-1] is k2_[-1] for k1_, k2_ in zip(k1, k2)) -class TestTSFCCache: - """TSFC code generation cache tests.""" +def test_tsfc_same_mixed_form(mixed_mass): + """Compiling a mixed form twice should load kernels from cache.""" + k1 = tsfc_interface.compile_form(mixed_mass, 'mixed_mass') + k2 = tsfc_interface.compile_form(mixed_mass, 'mixed_mass') - def test_cache_key_persistent_across_invocations(self, tmpdir): - code = """ -from firedrake import * -mesh = UnitSquareMesh(1, 1) -V = FunctionSpace(mesh, "CG", 1) -u = TrialFunction(V) -v = TestFunction(V) -key = tsfc_interface.TSFCKernel(inner(u,v)*dx, "mass", parameters["form_compiler"], (), (), None).cache_key -with open("{file}", "w") as f: - f.write(key) - """ - filea = tmpdir.join("a") - fileb = tmpdir.join("b") - subprocess.check_call([sys.executable, "-c", code.format(file=filea)]) - subprocess.check_call([sys.executable, "-c", code.format(file=fileb)]) - with filea.open("r") as f: - key1 = f.read() - with fileb.open("r") as f: - key2 = f.read() - assert key1 == key2 - - def test_tsfc_cache_persist_on_disk(self, cache_key): - """TSFCKernel should be persisted on disk.""" - shard, key = cache_key[:2], cache_key[2:] - assert os.path.exists( - os.path.join(tsfc_interface.TSFCKernel._cachedir, shard, key)) - - def test_tsfc_cache_read_from_disk(self, cache_key): - """Loading an TSFCKernel from disk should yield the right object.""" - assert tsfc_interface.TSFCKernel._read_from_disk( - cache_key, COMM_WORLD).cache_key == cache_key - - def test_tsfc_same_form(self, mass): - """Compiling the same form twice should load kernels from cache.""" - k1 = tsfc_interface.compile_form(mass, 'mass') - k2 = tsfc_interface.compile_form(mass, 'mass') - - assert k1 is k2 - assert all(k1_[-1] is k2_[-1] for k1_, k2_ in zip(k1, k2)) - - def test_tsfc_same_mixed_form(self, mixed_mass): - """Compiling a mixed form twice should load kernels from cache.""" - k1 = tsfc_interface.compile_form(mixed_mass, 'mixed_mass') - k2 = tsfc_interface.compile_form(mixed_mass, 'mixed_mass') - - assert k1 is k2 - assert all(k1_[-1] is k2_[-1] for k1_, k2_ in zip(k1, k2)) - - def test_tsfc_different_forms(self, mass, laplace): - """Compiling different forms should not load kernels from cache.""" - k1, = tsfc_interface.compile_form(mass, 'mass') - k2, = tsfc_interface.compile_form(laplace, 'mass') - - assert k1[-1] is not k2[-1] - - def test_tsfc_different_names(self, mass): - """Compiling different forms should not load kernels from cache.""" - k1, = tsfc_interface.compile_form(mass, 'mass') - k2, = tsfc_interface.compile_form(mass, 'laplace') - - assert k1[-1] is not k2[-1] - - def test_tsfc_cell_kernel(self, mass): - k = tsfc_interface.compile_form(mass, 'mass') - assert len(k) == 1 and 'cell_integral' in loopy.generate_code_v2(k[0][1][0].code).device_code() - - def test_tsfc_exterior_facet_kernel(self, rhs): - k = tsfc_interface.compile_form(rhs, 'rhs') - assert len(k) == 1 and 'exterior_facet_integral' in loopy.generate_code_v2(k[0][1][0].code).device_code() - - def test_tsfc_cell_exterior_facet_kernel(self, rhs2): - k = tsfc_interface.compile_form(rhs2, 'rhs2') - kernel_name = sorted(k_[1][0].name for k_ in k) - assert len(k) == 2 and 'cell_integral' in kernel_name[0] and \ - 'exterior_facet_integral' in kernel_name[1] + assert k1 is k2 + assert all(k1_[-1] is k2_[-1] for k1_, k2_ in zip(k1, k2)) + + +def test_tsfc_different_forms(mass, laplace): + """Compiling different forms should not load kernels from cache.""" + k1, = tsfc_interface.compile_form(mass, 'mass') + k2, = tsfc_interface.compile_form(laplace, 'mass') + + assert k1[-1] is not k2[-1] + + +def test_tsfc_different_names(mass): + """Compiling different forms should not load kernels from cache.""" + k1, = tsfc_interface.compile_form(mass, 'mass') + k2, = tsfc_interface.compile_form(mass, 'laplace') + + assert k1[-1] is not k2[-1] + + +def test_tsfc_cell_kernel(mass): + k = tsfc_interface.compile_form(mass, 'mass') + assert len(k) == 1 and 'cell_integral' in loopy.generate_code_v2(k[0][1][0].code).device_code() + + +def test_tsfc_exterior_facet_kernel(rhs): + k = tsfc_interface.compile_form(rhs, 'rhs') + assert len(k) == 1 and 'exterior_facet_integral' in loopy.generate_code_v2(k[0][1][0].code).device_code() + + +def test_tsfc_cell_exterior_facet_kernel(rhs2): + k = tsfc_interface.compile_form(rhs2, 'rhs2') + kernel_name = sorted(k_[1][0].name for k_ in k) + assert len(k) == 2 and 'cell_integral' in kernel_name[0] and \ + 'exterior_facet_integral' in kernel_name[1] diff --git a/tests/vertexonly/test_vertex_only_fs.py b/tests/vertexonly/test_vertex_only_fs.py index 866a31b810..a44791d0ea 100644 --- a/tests/vertexonly/test_vertex_only_fs.py +++ b/tests/vertexonly/test_vertex_only_fs.py @@ -323,10 +323,14 @@ def test_input_ordering_missing_point(): # put data on the input ordering P0DG_input_ordering = FunctionSpace(vm.input_ordering, "DG", 0) data_input_ordering = Function(P0DG_input_ordering) + if vm.comm.rank == 0: data_input_ordering.dat.data_wo[:] = data + # Accessing data_ro [*here] is collective, hence this redundant call + _ = len(data_input_ordering.dat.data_ro) else: data_input_ordering.dat.data_wo[:] = [] + # [*here] assert not len(data_input_ordering.dat.data_ro) # shouldn't have any halos @@ -348,6 +352,9 @@ def test_input_ordering_missing_point(): data_input_ordering.interpolate(data_on_vm) if vm.comm.rank == 0: assert np.allclose(data_input_ordering.dat.data_ro[0:3], 2*data[0:3]) + # [*here] assert np.allclose(data_input_ordering.dat.data_ro[3], data[3]) else: assert not len(data_input_ordering.dat.data_ro) + # Accessing data_ro [*here] is collective, hence this redundant call + _ = len(data_input_ordering.dat.data_ro) diff --git a/tests/vertexonly/test_vertex_only_mesh_generation.py b/tests/vertexonly/test_vertex_only_mesh_generation.py index c3fe5dda12..3ac5bb43ba 100644 --- a/tests/vertexonly/test_vertex_only_mesh_generation.py +++ b/tests/vertexonly/test_vertex_only_mesh_generation.py @@ -144,24 +144,40 @@ def verify_vertexonly_mesh(m, vm, inputvertexcoords, name): total_cells = MPI.COMM_WORLD.allreduce(len(vm.coordinates.dat.data_ro), op=MPI.SUM) total_in_bounds = MPI.COMM_WORLD.allreduce(len(in_bounds), op=MPI.SUM) skip_in_bounds_checks = False + local_cells = len(vm.coordinates.dat.data_ro) if total_cells != total_in_bounds: assert MPI.COMM_WORLD.size > 1 # i.e. we're in parallel assert total_cells < total_in_bounds # i.e. some points are duplicated - local_cells = len(vm.coordinates.dat.data_ro) local_in_bounds = len(in_bounds) if not local_cells == local_in_bounds and local_in_bounds > 0: - assert max(ref_cell_dists_l1) > 0.5*m.tolerance + # This assertion needs to happen in parallel! + assertion = (max(ref_cell_dists_l1) > 0.5*m.tolerance) skip_in_bounds_checks = True + else: + assertion = True + else: + assertion = True + # FIXME: Replace with parallel assert when it's merged into pytest-mpi + assert min(MPI.COMM_WORLD.allgather([assertion])) + # Correct local coordinates (though not guaranteed to be in same order) if not skip_in_bounds_checks: # Correct local coordinates (though not guaranteed to be in same order) + # [*here] np.allclose(np.sort(vm.coordinates.dat.data_ro), np.sort(inputvertexcoords[in_bounds])) + else: + # Accessing data_ro [*here] is collective, hence this redundant call + _ = len(vm.coordinates.dat.data_ro) # Correct parent topology assert vm._parent_mesh is m assert vm.topology._parent_mesh is m.topology # Correct generic cell properties if not skip_in_bounds_checks: + # [*here] assert vm.cell_closure.shape == (len(vm.coordinates.dat.data_ro_with_halos), 1) + else: + # Accessing data_ro [*here] is collective, hence this redundant call + _ = len(vm.coordinates.dat.data_ro_with_halos) with pytest.raises(AttributeError): vm.exterior_facets() with pytest.raises(AttributeError): @@ -169,8 +185,13 @@ def verify_vertexonly_mesh(m, vm, inputvertexcoords, name): with pytest.raises(AttributeError): vm.cell_to_facets if not skip_in_bounds_checks: + # [*here] assert vm.num_cells() == vm.cell_closure.shape[0] == len(vm.coordinates.dat.data_ro_with_halos) == vm.cell_set.total_size assert vm.cell_set.size == len(inputvertexcoords[in_bounds]) == len(vm.coordinates.dat.data_ro) + else: + # Accessing data_ro and data_ro_with_halos [*here] is collective, hence this redundant call + _ = len(vm.coordinates.dat.data_ro_with_halos) + _ = len(vm.coordinates.dat.data_ro) assert vm.num_facets() == 0 assert vm.num_faces() == vm.num_entities(2) == 0 assert vm.num_edges() == vm.num_entities(1) == 0 @@ -257,11 +278,17 @@ def test_generate_cell_midpoints(parentmesh, redundant): out_of_mesh_point = np.full((1, parentmesh.geometric_dimension()), np.inf) for i in range(max_len): if i < len(vm.coordinates.dat.data_ro): + # [*here] cell_num = parentmesh.locate_cell(vm.coordinates.dat.data_ro[i]) else: cell_num = parentmesh.locate_cell(out_of_mesh_point) # should return None + # Accessing data_ro [*here] is collective, hence this redundant call + _ = len(vm.coordinates.dat.data_ro) if cell_num is not None: assert (f.dat.data_ro[cell_num] == vm.coordinates.dat.data_ro[i]).all() + else: + _ = len(f.dat.data_ro) + _ = len(vm.coordinates.dat.data_ro) # Have correct pyop2 labels as implied by cell set sizes if parentmesh.extruded: