Skip to content

Commit

Permalink
Merge branch 'master' into dependabot/pip/distributed-lt-2024.7
Browse files (browse the repository at this point in the history)
  • Loading branch information
georgebisbas authored Jun 19, 2024
2 parents b180bb1 + 9fcc4e3 commit b3393fc
Show file tree
Hide file tree
Showing 12 changed files with 580 additions and 142 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-mlir-mpi-openmp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: |
pip install -e .[tests]
pip install mpi4py
pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997
pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33
- name: Test with MPI + openmp
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-mlir-mpi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: |
pip install -e .[tests]
pip install mpi4py
pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997
pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33
- name: Test with MPI - no Openmp
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/ci-mlir-openmp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ jobs:
run: |
pip install -e .[tests]
pip install mpi4py
pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997
pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33
- name: Test no-MPI, Openmp
run: |
export DEVITO_MPI=0
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-mlir.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: |
pip install -e .[tests]
pip install mpi4py
pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997
pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33
- name: Test no-MPI, no-Openmp
run: |
Expand Down
11 changes: 5 additions & 6 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
from devito.core.operator import CoreOperator, CustomOperator, ParTile
from devito.exceptions import InvalidOperator
from devito.passes.equations import collect_derivatives
from devito.tools import timed_pass

from devito.passes.clusters import (Lift, blocking, buffering, cire, cse,
factorize, fission, fuse, optimize_hyperplanes,
optimize_pows)
from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, hoist_prodders,
linearize, mpiize, relax_incr_dimensions)
factorize, fission, fuse, optimize_pows,
optimize_hyperplanes)
from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize, mpiize,
hoist_prodders, relax_incr_dimensions)
from devito.tools import timed_pass


__all__ = ['Cpu64NoopCOperator', 'Cpu64NoopOmpOperator', 'Cpu64AdvCOperator',
Expand Down
85 changes: 34 additions & 51 deletions devito/core/cpu_xdsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,20 @@
from devito.logger import info, perf
from devito.mpi import MPI
from devito.operator.profiling import create_profile
from devito.tools import filter_sorted, flatten, OrderedSet
from devito.types import TimeFunction
from devito.types.dense import DiscreteFunction, Function
from devito.types.mlir_types import f32, ptr_of
from devito.tools import filter_sorted, flatten, as_tuple

from xdsl.printer import Printer
from xdsl.xdsl_opt_main import xDSLOptMain

from devito.ir.ietxdsl.cluster_to_ssa import (ExtractDevitoStencilConversion,
finalize_module_with_globals) # noqa
finalize_module_with_globals,
setup_memref_args) # noqa

from devito.ir.ietxdsl.profiling import apply_timers
from devito.passes.iet import CTarget, OmpTarget
from devito.core.cpu import Cpu64OperatorMixin


__all__ = ['XdslnoopOperator', 'XdslAdvOperator']


Expand All @@ -57,12 +56,12 @@ def _build(cls, expressions, **kwargs):
Callable.__init__(op, **op.args)

# Header files, etc.
op._headers = OrderedSet(*cls._default_headers)
op._headers.update(byproduct.headers)
op._globals = OrderedSet(*cls._default_globals)
op._includes = OrderedSet(*cls._default_includes)
op._includes.update(profiler._default_includes)
op._includes.update(byproduct.includes)
# op._headers = OrderedSet(*cls._default_headers)
# op._headers.update(byproduct.headers)
# op._globals = OrderedSet(*cls._default_globals)
# op._includes = OrderedSet(*cls._default_includes)
# op._includes.update(profiler._default_includes)
# op._includes.update(byproduct.includes)

# Required for the jit-compilation
op._compiler = kwargs['compiler']
Expand Down Expand Up @@ -94,7 +93,7 @@ def _build(cls, expressions, **kwargs):
op._dtype, op._dspace = irs.clusters.meta
op._profiler = profiler
kwargs['xdsl_num_sections'] = len(FindNodes(Section).visit(irs.iet))
module = cls._lower_stencil(irs.expressions, **kwargs)
module = cls._lower_stencil(expressions, **kwargs)
op._module = module

return op
Expand All @@ -107,8 +106,8 @@ def _lower_stencil(cls, expressions, **kwargs):
Apply timers to the module
"""

conv = ExtractDevitoStencilConversion()
module = conv.convert(expressions, **kwargs)
conv = ExtractDevitoStencilConversion(cls)
module = conv.convert(as_tuple(expressions), **kwargs)
# print(module)
apply_timers(module, timed=True, **kwargs)

Expand Down Expand Up @@ -302,16 +301,16 @@ def cfunction(self):
suffix=".o", delete=delete)
self._make_interop_o()
self._jit_compile()
self.setup_memref_args()
self._jit_kernel_constants.update(setup_memref_args(self.functions))
self._lib = self._compiler.load(self._tf.name)
self._lib.name = self._tf.name

if self._cfunction is None:
self._cfunction = getattr(self._lib, self.name)
# Associate a C type to each argument for runtime type check
argtypes = self._construct_cfunction_args(self._jit_kernel_constants,
get_types=True)
self._cfunction.argtypes = argtypes
# argtypes = self._construct_cfunction_args(self._jit_kernel_constants,
# get_types=True)
# self._cfunction.argtypes = argtypes

return self._cfunction

Expand Down Expand Up @@ -345,49 +344,33 @@ def compile(self, cmd, stdout=None):

return stdout

def setup_memref_args(self):
"""
Add memrefs to args dictionary so they can be passed to the cfunction
"""
args = dict()
for arg in self.functions:
# For every TimeFunction add memref
if isinstance(arg, TimeFunction):
data = arg._data
for t in range(data.shape[0]):
args[f'{arg._C_name}{t}'] = data[t, ...].ctypes.data_as(ptr_of(f32))
if isinstance(arg, Function):
args[f'{arg._C_name}'] = arg._data[...].ctypes.data_as(ptr_of(f32))

self._jit_kernel_constants.update(args)

def _construct_cfunction_args(self, args, get_types=False):
"""
Either construct the args for the cfunction, or construct the
arg types for it.
"""
ps = {
p._C_name: p._C_ctype for p in self.parameters
}
def _construct_cfunction_types(self, args):
# NOTE: currently unused -- the argtypes assignment in `cfunction` is commented out; candidate for removal
ps = {p._C_name: p._C_ctype for p in self.parameters}

objects = []
objects_types = []

for name in get_arg_names_from_module(self._module):
object = args[name]
objects.append(object)
if name in ps:
object_type = ps[name]
if object_type == DiscreteFunction._C_ctype:
if object_type == DiscreteFunction._C_ctype: # noqa
object_type = dict(object_type._type_._fields_)['data']
objects_types.append(object_type)
else:
objects_types.append(type(object))
return objects_types

def _construct_cfunction_args(self, args):
"""
Construct the args for the cfunction, ordered by the argument names
extracted from the module.
"""

objects = []
for name in get_arg_names_from_module(self._module):
object = args[name]
objects.append(object)

if get_types:
return objects_types
else:
return objects
return objects


class XdslAdvOperator(XdslnoopOperator):
Expand Down
Loading

0 comments on commit b3393fc

Please sign in to comment.