From 481216847e410c0ccb8df384257e2409887b690a Mon Sep 17 00:00:00 2001 From: George Bisbas Date: Wed, 19 Jun 2024 13:20:19 +0100 Subject: [PATCH] cleanup: Restructure --- devito/core/cpu_xdsl.py | 33 +++++----------------------- devito/data/__init__.py | 1 + devito/ir/ietxdsl/cluster_to_ssa.py | 34 ++++++++++++++++++++++------- devito/ir/ietxdsl/utils.py | 10 +++++++++ 4 files changed, 42 insertions(+), 36 deletions(-) diff --git a/devito/core/cpu_xdsl.py b/devito/core/cpu_xdsl.py index b9ddc5546e..6c4414a63a 100644 --- a/devito/core/cpu_xdsl.py +++ b/devito/core/cpu_xdsl.py @@ -19,21 +19,18 @@ from devito.mpi import MPI from devito.operator.profiling import create_profile from devito.tools import filter_sorted, flatten, as_tuple -from devito.types import TimeFunction -from devito.types.dense import DiscreteFunction, Function -from devito.types.mlir_types import f32, ptr_of from xdsl.printer import Printer from xdsl.xdsl_opt_main import xDSLOptMain from devito.ir.ietxdsl.cluster_to_ssa import (ExtractDevitoStencilConversion, - finalize_module_with_globals) # noqa + finalize_module_with_globals, + setup_memref_args) # noqa from devito.ir.ietxdsl.profiling import apply_timers from devito.passes.iet import CTarget, OmpTarget from devito.core.cpu import Cpu64OperatorMixin -from examples.seismic.source import PointSource __all__ = ['XdslnoopOperator', 'XdslAdvOperator'] @@ -304,7 +301,7 @@ def cfunction(self): suffix=".o", delete=delete) self._make_interop_o() self._jit_compile() - self.setup_memref_args() + self._jit_kernel_constants.update(setup_memref_args(self.functions)) self._lib = self._compiler.load(self._tf.name) self._lib.name = self._tf.name @@ -347,35 +344,15 @@ def compile(self, cmd, stdout=None): return stdout - def setup_memref_args(self): - """ - Add memrefs to args dictionary so they can be passed to the cfunction - """ - args = dict() - for arg in self.functions: - # For every TimeFunction add memref - if isinstance(arg, TimeFunction): - data = arg._data - for t in range(data.shape[0]): - args[f'{arg._C_name}{t}'] = data[t, ...].ctypes.data_as(ptr_of(f32)) - elif isinstance(arg, Function): - args[arg._C_name] = arg._data[...].ctypes.data_as(ptr_of(f32)) - - elif isinstance(arg, PointSource): - args[arg._C_name] = arg._data[...].ctypes.data_as(ptr_of(f32)) - else: - raise NotImplementedError(f"type {type(arg)} not implemented") - - self._jit_kernel_constants.update(args) - def _construct_cfunction_types(self, args): + # Unused, maybe drop ps = {p._C_name: p._C_ctype for p in self.parameters} objects_types = [] for name in get_arg_names_from_module(self._module): if name in ps: object_type = ps[name] - if object_type == DiscreteFunction._C_ctype: + if object_type == DiscreteFunction._C_ctype: # noqa object_type = dict(object_type._type_._fields_)['data'] objects_types.append(object_type) else: diff --git a/devito/data/__init__.py b/devito/data/__init__.py index 09fd26f353..c481fdc294 100644 --- a/devito/data/__init__.py +++ b/devito/data/__init__.py @@ -1,5 +1,6 @@ from devito.data.meta import * # noqa from devito.data.allocators import * # noqa +from devito.data.allocators_xdsl import * # noqa from devito.data.decomposition import * # noqa from devito.data.data import * # noqa from devito.data.utils import * # noqa diff --git a/devito/ir/ietxdsl/cluster_to_ssa.py b/devito/ir/ietxdsl/cluster_to_ssa.py index c6d9b0ad3d..c8c1cbe26b 100644 --- a/devito/ir/ietxdsl/cluster_to_ssa.py +++ b/devito/ir/ietxdsl/cluster_to_ssa.py @@ -1,4 +1,6 @@ from functools import reduce +import numpy as np + # ------------- General imports -------------# from typing import Any, Iterable @@ -45,19 +47,14 @@ # ------------- devito-xdsl SSA imports -------------# from devito.ir.ietxdsl import iet_ssa -from devito.ir.ietxdsl.utils import is_int, is_float -import numpy as np +from devito.ir.ietxdsl.utils import is_int, is_float, dtypes_to_xdsltypes +from devito.types.mlir_types import f32, ptr_of + from examples.seismic.source import PointSource from tests.test_interpolation import points from tests.test_timestepping import d -dtypes_to_xdsltypes = { - np.float32: builtin.f32, - np.float64: builtin.f64, - np.int32: builtin.i32, - np.int64: builtin.i64, -} # flake8: noqa @@ -73,6 +70,27 @@ def field_from_function(f: DiscreteFunction) -> stencil.FieldType: return stencil.FieldType(bounds, element_type=dtypes_to_xdsltypes[f.dtype]) +def setup_memref_args(functions): + """ + Add memrefs to args dictionary so they can be passed to the cfunction + """ + args = dict() + for arg in functions: + # For every TimeFunction add memref + if isinstance(arg, TimeFunction): + data = arg._data + for t in range(data.shape[0]): + args[f'{arg._C_name}{t}'] = data[t, ...].ctypes.data_as(ptr_of(f32)) + elif isinstance(arg, Function): + args[arg._C_name] = arg._data[...].ctypes.data_as(ptr_of(f32)) + + elif isinstance(arg, PointSource): + args[arg._C_name] = arg._data[...].ctypes.data_as(ptr_of(f32)) + else: + raise NotImplementedError(f"type {type(arg)} not implemented") + + return args + class ExtractDevitoStencilConversion: """ Lower Devito equations to the stencil dialect diff --git a/devito/ir/ietxdsl/utils.py b/devito/ir/ietxdsl/utils.py index 04fcbd761b..da213d076a 100644 --- a/devito/ir/ietxdsl/utils.py +++ b/devito/ir/ietxdsl/utils.py @@ -1,3 +1,5 @@ +import numpy as np + from xdsl.dialects import builtin from xdsl.ir import SSAValue @@ -8,3 +10,11 @@ def is_int(val: SSAValue): def is_float(val: SSAValue): return val.type in (builtin.f32, builtin.f64) + + +dtypes_to_xdsltypes = { + np.float32: builtin.f32, + np.float64: builtin.f64, + np.int32: builtin.i32, + np.int64: builtin.i64, +}