Skip to content

Commit

Permalink
Merge pull request #95 from xdslproject/emilien/source
Browse files Browse the repository at this point in the history
api: Introduce Source injection
  • Loading branch information
georgebisbas authored Jun 19, 2024
2 parents a9843db + aa82eb6 commit 756d79f
Show file tree
Hide file tree
Showing 12 changed files with 554 additions and 121 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-mlir-mpi-openmp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: |
pip install -e .[tests]
pip install mpi4py
pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997
pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33
- name: Test with MPI + openmp
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-mlir-mpi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: |
pip install -e .[tests]
pip install mpi4py
pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997
pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33
- name: Test with MPI - no Openmp
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/ci-mlir-openmp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ jobs:
run: |
pip install -e .[tests]
pip install mpi4py
pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997
pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33
- name: Test no-MPI, Openmp
run: |
export DEVITO_MPI=0
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-mlir.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: |
pip install -e .[tests]
pip install mpi4py
pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997
pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33
- name: Test no-MPI, no-Openmp
run: |
Expand Down
11 changes: 5 additions & 6 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
from devito.core.operator import CoreOperator, CustomOperator, ParTile
from devito.exceptions import InvalidOperator
from devito.passes.equations import collect_derivatives
from devito.tools import timed_pass

from devito.passes.clusters import (Lift, blocking, buffering, cire, cse,
factorize, fission, fuse, optimize_hyperplanes,
optimize_pows)
from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, hoist_prodders,
linearize, mpiize, relax_incr_dimensions)
factorize, fission, fuse, optimize_pows,
optimize_hyperplanes)
from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize, mpiize,
hoist_prodders, relax_incr_dimensions)
from devito.tools import timed_pass


__all__ = ['Cpu64NoopCOperator', 'Cpu64NoopOmpOperator', 'Cpu64AdvCOperator',
Expand Down
68 changes: 37 additions & 31 deletions devito/core/cpu_xdsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from devito.logger import info, perf
from devito.mpi import MPI
from devito.operator.profiling import create_profile
from devito.tools import filter_sorted, flatten, OrderedSet
from devito.tools import filter_sorted, flatten, as_tuple
from devito.types import TimeFunction
from devito.types.dense import DiscreteFunction, Function
from devito.types.mlir_types import f32, ptr_of
Expand All @@ -33,6 +33,8 @@
from devito.passes.iet import CTarget, OmpTarget
from devito.core.cpu import Cpu64OperatorMixin

from examples.seismic.source import PointSource

__all__ = ['XdslnoopOperator', 'XdslAdvOperator']


Expand All @@ -57,12 +59,12 @@ def _build(cls, expressions, **kwargs):
Callable.__init__(op, **op.args)

# Header files, etc.
op._headers = OrderedSet(*cls._default_headers)
op._headers.update(byproduct.headers)
op._globals = OrderedSet(*cls._default_globals)
op._includes = OrderedSet(*cls._default_includes)
op._includes.update(profiler._default_includes)
op._includes.update(byproduct.includes)
# op._headers = OrderedSet(*cls._default_headers)
# op._headers.update(byproduct.headers)
# op._globals = OrderedSet(*cls._default_globals)
# op._includes = OrderedSet(*cls._default_includes)
# op._includes.update(profiler._default_includes)
# op._includes.update(byproduct.includes)

# Required for the jit-compilation
op._compiler = kwargs['compiler']
Expand Down Expand Up @@ -94,7 +96,7 @@ def _build(cls, expressions, **kwargs):
op._dtype, op._dspace = irs.clusters.meta
op._profiler = profiler
kwargs['xdsl_num_sections'] = len(FindNodes(Section).visit(irs.iet))
module = cls._lower_stencil(irs.expressions, **kwargs)
module = cls._lower_stencil(expressions, **kwargs)
op._module = module

return op
Expand All @@ -107,8 +109,8 @@ def _lower_stencil(cls, expressions, **kwargs):
Apply timers to the module
"""

conv = ExtractDevitoStencilConversion()
module = conv.convert(expressions, **kwargs)
conv = ExtractDevitoStencilConversion(cls)
module = conv.convert(as_tuple(expressions), **kwargs)
# print(module)
apply_timers(module, timed=True, **kwargs)

Expand Down Expand Up @@ -309,9 +311,9 @@ def cfunction(self):
if self._cfunction is None:
self._cfunction = getattr(self._lib, self.name)
# Associate a C type to each argument for runtime type check
argtypes = self._construct_cfunction_args(self._jit_kernel_constants,
get_types=True)
self._cfunction.argtypes = argtypes
# argtypes = self._construct_cfunction_args(self._jit_kernel_constants,
# get_types=True)
# self._cfunction.argtypes = argtypes

return self._cfunction

Expand Down Expand Up @@ -356,38 +358,42 @@ def setup_memref_args(self):
data = arg._data
for t in range(data.shape[0]):
args[f'{arg._C_name}{t}'] = data[t, ...].ctypes.data_as(ptr_of(f32))
if isinstance(arg, Function):
args[f'{arg._C_name}'] = arg._data[...].ctypes.data_as(ptr_of(f32))
elif isinstance(arg, Function):
args[arg._C_name] = arg._data[...].ctypes.data_as(ptr_of(f32))

elif isinstance(arg, PointSource):
args[arg._C_name] = arg._data[...].ctypes.data_as(ptr_of(f32))
else:
raise NotImplementedError(f"type {type(arg)} not implemented")

self._jit_kernel_constants.update(args)

def _construct_cfunction_args(self, args, get_types=False):
"""
Either construct the args for the cfunction, or construct the
arg types for it.
"""
ps = {
p._C_name: p._C_ctype for p in self.parameters
}
def _construct_cfunction_types(self, args):
ps = {p._C_name: p._C_ctype for p in self.parameters}

objects = []
objects_types = []

for name in get_arg_names_from_module(self._module):
object = args[name]
objects.append(object)
if name in ps:
object_type = ps[name]
if object_type == DiscreteFunction._C_ctype:
object_type = dict(object_type._type_._fields_)['data']
objects_types.append(object_type)
else:
objects_types.append(type(object))
return objects_types

def _construct_cfunction_args(self, args):
"""
Either construct the args for the cfunction, or construct the
arg types for it.
"""

objects = []
for name in get_arg_names_from_module(self._module):
object = args[name]
objects.append(object)

if get_types:
return objects_types
else:
return objects
return objects


class XdslAdvOperator(XdslnoopOperator):
Expand Down
Loading

0 comments on commit 756d79f

Please sign in to comment.