Skip to content

Commit

Permalink
cleanup: Restructure
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Jun 19, 2024
1 parent 756d79f commit 4812168
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 36 deletions.
33 changes: 5 additions & 28 deletions devito/core/cpu_xdsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,18 @@
from devito.mpi import MPI
from devito.operator.profiling import create_profile
from devito.tools import filter_sorted, flatten, as_tuple
from devito.types import TimeFunction
from devito.types.dense import DiscreteFunction, Function
from devito.types.mlir_types import f32, ptr_of

from xdsl.printer import Printer
from xdsl.xdsl_opt_main import xDSLOptMain

from devito.ir.ietxdsl.cluster_to_ssa import (ExtractDevitoStencilConversion,
finalize_module_with_globals) # noqa
finalize_module_with_globals,
setup_memref_args) # noqa

from devito.ir.ietxdsl.profiling import apply_timers
from devito.passes.iet import CTarget, OmpTarget
from devito.core.cpu import Cpu64OperatorMixin

from examples.seismic.source import PointSource

__all__ = ['XdslnoopOperator', 'XdslAdvOperator']

Expand Down Expand Up @@ -304,7 +301,7 @@ def cfunction(self):
suffix=".o", delete=delete)
self._make_interop_o()
self._jit_compile()
self.setup_memref_args()
self._jit_kernel_constants.update(setup_memref_args(self.functions))
self._lib = self._compiler.load(self._tf.name)
self._lib.name = self._tf.name

Expand Down Expand Up @@ -347,35 +344,15 @@ def compile(self, cmd, stdout=None):

return stdout

def setup_memref_args(self):
"""
Add memrefs to args dictionary so they can be passed to the cfunction
"""
args = dict()
for arg in self.functions:
# For every TimeFunction add memref
if isinstance(arg, TimeFunction):
data = arg._data
for t in range(data.shape[0]):
args[f'{arg._C_name}{t}'] = data[t, ...].ctypes.data_as(ptr_of(f32))
elif isinstance(arg, Function):
args[arg._C_name] = arg._data[...].ctypes.data_as(ptr_of(f32))

elif isinstance(arg, PointSource):
args[arg._C_name] = arg._data[...].ctypes.data_as(ptr_of(f32))
else:
raise NotImplementedError(f"type {type(arg)} not implemented")

self._jit_kernel_constants.update(args)

def _construct_cfunction_types(self, args):
# Unused, maybe drop
ps = {p._C_name: p._C_ctype for p in self.parameters}

objects_types = []
for name in get_arg_names_from_module(self._module):
if name in ps:
object_type = ps[name]
if object_type == DiscreteFunction._C_ctype:
if object_type == DiscreteFunction._C_ctype: # noqa
object_type = dict(object_type._type_._fields_)['data']
objects_types.append(object_type)
else:
Expand Down
1 change: 1 addition & 0 deletions devito/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from devito.data.meta import * # noqa
from devito.data.allocators import * # noqa
from devito.data.allocators_xdsl import * # noqa
from devito.data.decomposition import * # noqa
from devito.data.data import * # noqa
from devito.data.utils import * # noqa
34 changes: 26 additions & 8 deletions devito/ir/ietxdsl/cluster_to_ssa.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from functools import reduce
import numpy as np

# ------------- General imports -------------#

from typing import Any, Iterable
Expand Down Expand Up @@ -45,19 +47,14 @@

# ------------- devito-xdsl SSA imports -------------#
from devito.ir.ietxdsl import iet_ssa
from devito.ir.ietxdsl.utils import is_int, is_float
import numpy as np
from devito.ir.ietxdsl.utils import is_int, is_float, dtypes_to_xdsltypes
from devito.types.mlir_types import f32, ptr_of


from examples.seismic.source import PointSource
from tests.test_interpolation import points
from tests.test_timestepping import d

dtypes_to_xdsltypes = {
np.float32: builtin.f32,
np.float64: builtin.f64,
np.int32: builtin.i32,
np.int64: builtin.i64,
}

# flake8: noqa

Expand All @@ -73,6 +70,27 @@ def field_from_function(f: DiscreteFunction) -> stencil.FieldType:
return stencil.FieldType(bounds, element_type=dtypes_to_xdsltypes[f.dtype])


def setup_memref_args(functions):
"""
Add memrefs to args dictionary so they can be passed to the cfunction
"""
args = dict()
for arg in functions:
# For every TimeFunction add memref
if isinstance(arg, TimeFunction):
data = arg._data
for t in range(data.shape[0]):
args[f'{arg._C_name}{t}'] = data[t, ...].ctypes.data_as(ptr_of(f32))
elif isinstance(arg, Function):
args[arg._C_name] = arg._data[...].ctypes.data_as(ptr_of(f32))

elif isinstance(arg, PointSource):
args[arg._C_name] = arg._data[...].ctypes.data_as(ptr_of(f32))
else:
raise NotImplementedError(f"type {type(arg)} not implemented")

return args

class ExtractDevitoStencilConversion:
"""
Lower Devito equations to the stencil dialect
Expand Down
10 changes: 10 additions & 0 deletions devito/ir/ietxdsl/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

from xdsl.dialects import builtin
from xdsl.ir import SSAValue

Expand All @@ -8,3 +10,11 @@ def is_int(val: SSAValue):

def is_float(val: SSAValue):
return val.type in (builtin.f32, builtin.f64)


dtypes_to_xdsltypes = {
np.float32: builtin.f32,
np.float64: builtin.f64,
np.int32: builtin.i32,
np.int64: builtin.i64,
}

0 comments on commit 4812168

Please sign in to comment.