diff --git a/.github/workflows/ci-mlir-mpi-openmp.yml b/.github/workflows/ci-mlir-mpi-openmp.yml index 75217e08d3..5277b7f19a 100644 --- a/.github/workflows/ci-mlir-mpi-openmp.yml +++ b/.github/workflows/ci-mlir-mpi-openmp.yml @@ -36,7 +36,7 @@ jobs: run: | pip install -e .[tests] pip install mpi4py - pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997 + pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33 - name: Test with MPI + openmp run: | diff --git a/.github/workflows/ci-mlir-mpi.yml b/.github/workflows/ci-mlir-mpi.yml index b22e68bebd..f5f6be2638 100644 --- a/.github/workflows/ci-mlir-mpi.yml +++ b/.github/workflows/ci-mlir-mpi.yml @@ -36,7 +36,7 @@ jobs: run: | pip install -e .[tests] pip install mpi4py - pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997 + pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33 - name: Test with MPI - no Openmp run: | diff --git a/.github/workflows/ci-mlir-openmp.yml b/.github/workflows/ci-mlir-openmp.yml index bfadf023cb..f8a8e6f04e 100644 --- a/.github/workflows/ci-mlir-openmp.yml +++ b/.github/workflows/ci-mlir-openmp.yml @@ -36,8 +36,8 @@ jobs: run: | pip install -e .[tests] pip install mpi4py - pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997 - + pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33 + - name: Test no-MPI, Openmp run: | export DEVITO_MPI=0 diff --git a/.github/workflows/ci-mlir.yml b/.github/workflows/ci-mlir.yml index a1eb856b18..3ed33c97cc 100644 --- a/.github/workflows/ci-mlir.yml +++ b/.github/workflows/ci-mlir.yml @@ -36,7 +36,7 @@ jobs: run: | pip install -e .[tests] pip install mpi4py - pip install git+https://github.com/xdslproject/xdsl@210181350d926f91ee5fdb27f0eb5d1cf53a8997 + pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33 - name: Test no-MPI, no-Openmp run: | diff --git a/devito/core/cpu.py b/devito/core/cpu.py index 3f80d5e04e..ecf6dce42f 100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -3,13 +3,12 @@ from devito.core.operator import CoreOperator, CustomOperator, ParTile from devito.exceptions import InvalidOperator from devito.passes.equations import collect_derivatives -from devito.tools import timed_pass - from devito.passes.clusters import (Lift, blocking, buffering, cire, cse, - factorize, fission, fuse, optimize_hyperplanes, - optimize_pows) -from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, hoist_prodders, - linearize, mpiize, relax_incr_dimensions) + factorize, fission, fuse, optimize_pows, + optimize_hyperplanes) +from devito.passes.iet import (CTarget, OmpTarget, avoid_denormals, linearize, mpiize, + hoist_prodders, relax_incr_dimensions) +from devito.tools import timed_pass __all__ = ['Cpu64NoopCOperator', 'Cpu64NoopOmpOperator', 'Cpu64AdvCOperator', diff --git a/devito/core/cpu_xdsl.py b/devito/core/cpu_xdsl.py index 91f52b2ffd..6c4414a63a 100644 --- a/devito/core/cpu_xdsl.py +++ b/devito/core/cpu_xdsl.py @@ -18,21 +18,20 @@ from devito.logger import info, perf from devito.mpi import MPI from devito.operator.profiling import create_profile -from devito.tools import filter_sorted, flatten, OrderedSet -from devito.types import TimeFunction -from devito.types.dense import DiscreteFunction, Function -from devito.types.mlir_types import f32, ptr_of +from devito.tools import filter_sorted, flatten, as_tuple from xdsl.printer import Printer from xdsl.xdsl_opt_main import xDSLOptMain from devito.ir.ietxdsl.cluster_to_ssa import (ExtractDevitoStencilConversion, - finalize_module_with_globals) # noqa + finalize_module_with_globals, + setup_memref_args) # noqa from devito.ir.ietxdsl.profiling import apply_timers from devito.passes.iet import CTarget, OmpTarget from devito.core.cpu import Cpu64OperatorMixin + __all__ = ['XdslnoopOperator', 'XdslAdvOperator'] @@ -57,12 +56,12 @@ def _build(cls, expressions, **kwargs): Callable.__init__(op, **op.args) # Header files, etc. - op._headers = OrderedSet(*cls._default_headers) - op._headers.update(byproduct.headers) - op._globals = OrderedSet(*cls._default_globals) - op._includes = OrderedSet(*cls._default_includes) - op._includes.update(profiler._default_includes) - op._includes.update(byproduct.includes) + # op._headers = OrderedSet(*cls._default_headers) + # op._headers.update(byproduct.headers) + # op._globals = OrderedSet(*cls._default_globals) + # op._includes = OrderedSet(*cls._default_includes) + # op._includes.update(profiler._default_includes) + # op._includes.update(byproduct.includes) # Required for the jit-compilation op._compiler = kwargs['compiler'] @@ -94,7 +93,7 @@ def _build(cls, expressions, **kwargs): op._dtype, op._dspace = irs.clusters.meta op._profiler = profiler kwargs['xdsl_num_sections'] = len(FindNodes(Section).visit(irs.iet)) - module = cls._lower_stencil(irs.expressions, **kwargs) + module = cls._lower_stencil(expressions, **kwargs) op._module = module return op @@ -107,8 +106,8 @@ def _lower_stencil(cls, expressions, **kwargs): Apply timers to the module """ - conv = ExtractDevitoStencilConversion() - module = conv.convert(expressions, **kwargs) + conv = ExtractDevitoStencilConversion(cls) + module = conv.convert(as_tuple(expressions), **kwargs) # print(module) apply_timers(module, timed=True, **kwargs) @@ -302,16 +301,16 @@ def cfunction(self): suffix=".o", delete=delete) self._make_interop_o() self._jit_compile() - self.setup_memref_args() + self._jit_kernel_constants.update(setup_memref_args(self.functions)) self._lib = self._compiler.load(self._tf.name) self._lib.name = self._tf.name if self._cfunction is None: self._cfunction = getattr(self._lib, self.name) # Associate a C type to each argument for runtime type check - argtypes = self._construct_cfunction_args(self._jit_kernel_constants, - get_types=True) - self._cfunction.argtypes = argtypes + # argtypes = self._construct_cfunction_args(self._jit_kernel_constants, + # get_types=True) + # self._cfunction.argtypes = argtypes return self._cfunction @@ -345,49 +344,33 @@ def compile(self, cmd, stdout=None): return stdout - def setup_memref_args(self): - """ - Add memrefs to args dictionary so they can be passed to the cfunction - """ - args = dict() - for arg in self.functions: - # For every TimeFunction add memref - if isinstance(arg, TimeFunction): - data = arg._data - for t in range(data.shape[0]): - args[f'{arg._C_name}{t}'] = data[t, ...].ctypes.data_as(ptr_of(f32)) - if isinstance(arg, Function): - args[f'{arg._C_name}'] = arg._data[...].ctypes.data_as(ptr_of(f32)) - - self._jit_kernel_constants.update(args) - - def _construct_cfunction_args(self, args, get_types=False): - """ - Either construct the args for the cfunction, or construct the - arg types for it. - """ - ps = { - p._C_name: p._C_ctype for p in self.parameters - } + def _construct_cfunction_types(self, args): + # Unused, maybe drop + ps = {p._C_name: p._C_ctype for p in self.parameters} - objects = [] objects_types = [] - for name in get_arg_names_from_module(self._module): - object = args[name] - objects.append(object) if name in ps: object_type = ps[name] - if object_type == DiscreteFunction._C_ctype: + if object_type == DiscreteFunction._C_ctype: # noqa object_type = dict(object_type._type_._fields_)['data'] objects_types.append(object_type) else: objects_types.append(type(object)) + return objects_types + + def _construct_cfunction_args(self, args): + """ + Either construct the args for the cfunction, or construct the + arg types for it. + """ + + objects = [] + for name in get_arg_names_from_module(self._module): + object = args[name] + objects.append(object) - if get_types: - return objects_types - else: - return objects + return objects class XdslAdvOperator(XdslnoopOperator): diff --git a/devito/ir/ietxdsl/cluster_to_ssa.py b/devito/ir/ietxdsl/cluster_to_ssa.py index 4b00463285..c8c1cbe26b 100644 --- a/devito/ir/ietxdsl/cluster_to_ssa.py +++ b/devito/ir/ietxdsl/cluster_to_ssa.py @@ -1,8 +1,24 @@ +from functools import reduce +import numpy as np + # ------------- General imports -------------# from typing import Any, Iterable from dataclasses import dataclass, field -from sympy import Add, Expr, Float, Indexed, Integer, Mod, Mul, Pow, Symbol +from sympy import Add, And, Expr, Float, GreaterThan, Indexed, Integer, LessThan, Mod, Mul, Number, Pow, StrictGreaterThan, StrictLessThan, Symbol, floor +from sympy.core.relational import Relational +from sympy.logic.boolalg import BooleanFunction +from devito.ir.equations.equation import OpInc +from devito.operations.interpolators import Injection +from devito.operator.operator import Operator +from devito.symbolics.search import retrieve_dimensions, retrieve_functions +from devito.symbolics.extended_sympy import INT +from devito.tools.data_structures import OrderedSet +from devito.tools.utils import as_tuple +from devito.types.basic import Scalar +from devito.types.dense import DiscreteFunction, Function, TimeFunction +from devito.types.dimension import SpaceDimension, TimeDimension +from devito.types.equation import Eq # ------------- xdsl imports -------------# from xdsl.dialects import (arith, builtin, func, memref, scf, @@ -15,8 +31,10 @@ PatternRewriteWalker, RewritePattern, op_type_rewrite_pattern, + InsertPoint ) from xdsl.builder import ImplicitBuilder +from xdsl.transforms.experimental.convert_stencil_to_ll_mlir import StencilToMemRefType # ------------- devito imports -------------# from devito import Grid, SteppingDimension @@ -29,23 +47,56 @@ # ------------- devito-xdsl SSA imports -------------# from devito.ir.ietxdsl import iet_ssa -from devito.ir.ietxdsl.utils import is_int, is_float +from devito.ir.ietxdsl.utils import is_int, is_float, dtypes_to_xdsltypes +from devito.types.mlir_types import f32, ptr_of + + +from examples.seismic.source import PointSource +from tests.test_interpolation import points +from tests.test_timestepping import d + # flake8: noqa def field_from_function(f: DiscreteFunction) -> stencil.FieldType: - halo = [f.halo[d] for d in f.grid.dimensions] - shape = f.grid.shape + # import pdb;pdb.set_trace() + halo = [f.halo[d] for d in f.dimensions] + shape = f.shape bounds = [(-h[0], s+h[1]) for h, s in zip(halo, shape)] - return stencil.FieldType(bounds, element_type=dtype_to_xdsltype(f.dtype)) + if isinstance(f, TimeFunction): + bounds = bounds[1:] + + return stencil.FieldType(bounds, element_type=dtypes_to_xdsltypes[f.dtype]) +def setup_memref_args(functions): + """ + Add memrefs to args dictionary so they can be passed to the cfunction + """ + args = dict() + for arg in functions: + # For every TimeFunction add memref + if isinstance(arg, TimeFunction): + data = arg._data + for t in range(data.shape[0]): + args[f'{arg._C_name}{t}'] = data[t, ...].ctypes.data_as(ptr_of(f32)) + elif isinstance(arg, Function): + args[arg._C_name] = arg._data[...].ctypes.data_as(ptr_of(f32)) + + elif isinstance(arg, PointSource): + args[arg._C_name] = arg._data[...].ctypes.data_as(ptr_of(f32)) + else: + raise NotImplementedError(f"type {type(arg)} not implemented") + + return args + class ExtractDevitoStencilConversion: """ Lower Devito equations to the stencil dialect """ + operator: type[Operator] eqs: list[LoweredEq] block: Block temps: dict[tuple[DiscreteFunction, int], SSAValue] @@ -58,6 +109,10 @@ def __init__(self): time_offs: int + def __init__(self, operator: type[Operator]): + self.temps = dict() + self.operator = operator + def convert_function_eq(self, eq: LoweredEq, **kwargs): # Read the grid containing necessary discretization information # (size, halo width, ...) @@ -74,8 +129,20 @@ def convert_function_eq(self, eq: LoweredEq, **kwargs): else: raise NotImplementedError(f"Function of type {type(write_function)} not supported") - # Get the function carriers of the equation - self._build_step_body(step_dim, eq) + dims = retrieve_dimensions(eq.lhs.indices) + + if not all(isinstance(d, (SteppingDimension, SpaceDimension)) for d in dims): + self.build_generic_step(step_dim, eq) + else: + # Get the function carriers of the equation + self.build_stencil_step(step_dim, eq) + + def convert_symbol_eq(self, symbol: Symbol, rhs: LoweredEq, **kwargs): + """ + Convert a symbol equation to xDSL. + """ + self.symbol_values[symbol.name] = self._visit_math_nodes(None, rhs, None) + self.symbol_values[symbol.name].name_hint = symbol.name def convert_symbol_eq(self, symbol: Symbol, rhs: LoweredEq, **kwargs): """ @@ -116,7 +183,6 @@ def _convert_eq(self, eq: LoweredEq, **kwargs): "stencil.store"(%4, %u_t1) {"lb" = #stencil.index<0>, "ub" = #stencil.index<3>} : (!stencil.temp, !stencil.field<[-1,4]xf32>) -> () ``` """ - # Get the left hand side, called "output function" here because it tells us # Where to write the results of each step. write_function = eq.lhs @@ -139,18 +205,35 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, output_indexed: Indexed) -> SSAValue: # Handle Indexeds if isinstance(node, Indexed): - space_offsets = [] - for d in node.function.space_dimensions: - space_offsets.append(node.indices[d] - output_indexed.indices[d]) + # If we have a time function, we compute its time offset if isinstance(node.function, TimeFunction): time_offset = (node.indices[dim] - dim) % node.function.time_size - elif isinstance(node.function, Function): + elif isinstance(node.function, (Function, PointSource)): time_offset = 0 else: - raise NotImplementedError(f"reading function of type {type(node.func)} not supported") # noqa - temp = self.apply_temps[(node.function, time_offset)] - access = stencil.AccessOp.get(temp, space_offsets) - return access.res + raise NotImplementedError(f"reading function of type {type(node.func)} not supported") + # If we are in a stencil (encoded by having the output_indexed passed), we + # compute the relative space offsets and make it a stencil offset + if output_indexed is not None: + space_offsets = [node.indices[d] - output_indexed.indices[d] for d in node.function.space_dimensions] + temp = self.function_values[(node.function, time_offset)] + access = stencil.AccessOp.get(temp, space_offsets) + return access.res + # Otherwise, generate a load op + else: + temp = self.function_values[(node.function, time_offset)] + memtemp = builtin.UnrealizedConversionCastOp.get(temp, StencilToMemRefType(temp.type)).results[0] + memtemp.name_hint = temp.name_hint + "_mem" + indices = node.indices + if isinstance(node.function, TimeFunction): + indices = indices[1:] + ssa_indices = [self._visit_math_nodes(dim, i, output_indexed) for i in node.indices] + for i, ssa_i in enumerate(ssa_indices): + if isinstance(ssa_i.type, builtin.IntegerType): + ssa_indices[i] = arith.IndexCastOp(ssa_i, builtin.IndexType()) + return memref.Load.get(memtemp, ssa_indices).res + + import pdb; pdb.set_trace() # Handle Integers elif isinstance(node, Integer): cst = arith.Constant.from_int_and_width(int(node), builtin.i64) @@ -161,8 +244,12 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, return cst.result # Handle Symbols elif isinstance(node, Symbol): - symb = iet_ssa.LoadSymbolic.get(node.name, builtin.f32) - return symb.result + if node.name in self.symbol_values: + return self.symbol_values[node.name] + else: + mlir_dtype = dtypes_to_xdsltypes[node.dtype] + symb = iet_ssa.LoadSymbolic.get(node.name, mlir_dtype) + return symb.result # Handle Add Mul elif isinstance(node, (Add, Mul)): args = [self._visit_math_nodes(dim, arg, output_indexed) for arg in node.args] @@ -170,13 +257,13 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, # get first element out, store the rest in args # this makes the reduction easier carry, *args = self._ensure_same_type(*args) - # select the correct op from arith.Addi, arith.Addf, arith.Muli, arith.Mulf - if isinstance(carry.type, builtin.IntegerType): + # select the correct op from arith.addi, arith.addf, arith.muli, arith.mulf + if is_int(carry): op_cls = arith.Addi if isinstance(node, Add) else arith.Muli elif isinstance(carry.type, builtin.Float32Type): op_cls = arith.Addf if isinstance(node, Add) else arith.Mulf else: - raise("Add support for another type") + raise NotImplementedError(f"Add support for another type {carry.type}") for arg in args: op = op_cls(carry, arg) carry = op.result @@ -205,12 +292,40 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, op = op_cls(base, ex) return op.result # Handle Mod - elif isinstance(node, Mod): - raise NotImplementedError("Go away, no mod here. >:(") + elif isinstance(node, INT): + assert len(node.args) == 1, "Expected single argument for integer cast." + return arith.FPToSIOp(self._visit_math_nodes(dim, node.args[0], output_indexed), builtin.i64).result + elif isinstance(node, floor): + assert len(node.args) == 1, "Expected single argument for floor." + return math.FloorOp(self._visit_math_nodes(dim, node.args[0], output_indexed)).result + elif isinstance(node, And): + SSAargs = (self._visit_math_nodes(dim, arg, output_indexed) for arg in node.args) + return reduce(lambda x,y : arith.AndI(x,y).result, SSAargs) + elif isinstance(node, Relational): + if isinstance(node, GreaterThan): + mnemonic = "sge" + elif isinstance(node, LessThan): + mnemonic = "sle" + elif isinstance(node, StrictGreaterThan): + mnemonic = "sgt" + elif isinstance(node, StrictLessThan): + mnemonic = "slt" + else: + raise NotImplementedError(f"Unimplemented comparison {type(node)}") + + # import pdb; + # pdb.set_trace() + SSAargs = (self._visit_math_nodes(dim, arg, output_indexed) for arg in node.args) + # Operands must have the same type + # TODO: look at here if index stuff does not make sense + # The s in sgt means *signed* greater than + # Writer has no clue if this should rather be u for unsigned + return arith.Cmpi(*self._ensure_same_type(*SSAargs), mnemonic).result + else: - raise NotImplementedError(f"Unknown math: {node}", node) + raise NotImplementedError(f"Unknown math:{type(node)} {node}", node) - def _build_step_body(self, dim: SteppingDimension, eq: LoweredEq) -> None: + def build_stencil_step(self, dim: SteppingDimension, eq:LoweredEq) -> None: """ Builds the body of the step function for a given dimension and equation. @@ -223,13 +338,15 @@ def _build_step_body(self, dim: SteppingDimension, eq: LoweredEq) -> None: """ read_functions = set() for f in retrieve_function_carriers(eq.rhs): - if isinstance(f.function, TimeFunction): - time_offset = (f.indices[dim] - dim) % f.function.time_size + if isinstance(f.function, PointSource): + time_offset = 0 + elif isinstance(f.function, TimeFunction): + time_offset = (f.indices[dim]-dim) % f.function.time_size elif isinstance(f.function, Function): time_offset = 0 else: - raise NotImplementedError(f"reading function of type {type(f.func)}" - "not supported") + import pdb;pdb.set_trace() + raise NotImplementedError(f"reading function of type {type(f.function)} not supported") read_functions.add((f.function, time_offset)) for f, t in read_functions: @@ -258,7 +375,9 @@ def _build_step_body(self, dim: SteppingDimension, eq: LoweredEq) -> None: assert "temp" in apply_op.name_hint apply_arg.name_hint = apply_op.name_hint.replace("temp", "blk") - self.apply_temps = {k: v for k, v in zip(read_functions, apply.region.block.args)} + self.apply_temps = {k:v for k,v in zip(read_functions, apply.region.block.args)} + # Update the function values with the new temps + self.function_values |= self.apply_temps with ImplicitBuilder(apply.region.block): stencil.ReturnOp.get([self._visit_math_nodes(dim, eq.rhs, eq.lhs)]) @@ -277,8 +396,43 @@ def _build_step_body(self, dim: SteppingDimension, eq: LoweredEq) -> None: store.temp_with_halo.name_hint = f"{write_function.name}_t{self.out_time_buffer[1]}_temp" # noqa self.temps[self.out_time_buffer] = store.temp_with_halo + def build_generic_step_expression(self, dim: SteppingDimension, eq: LoweredEq): + # Sources + value = self._visit_math_nodes(dim, eq.rhs, None) + temp = self.function_values[self.out_time_buffer] + memtemp = builtin.UnrealizedConversionCastOp.get([temp], [StencilToMemRefType(temp.type)]).results[0] + memtemp.name_hint = temp.name_hint + "_mem" + indices = eq.lhs.indices + if isinstance(eq.lhs.function, TimeFunction): + indices = indices[1:] + ssa_indices = [self._visit_math_nodes(dim, i, None) for i in indices] + for i, ssa_i in enumerate(ssa_indices): + if isinstance(ssa_i.type, builtin.IntegerType): + ssa_indices[i] = arith.IndexCastOp(ssa_i, builtin.IndexType()) + + match eq.operation: + case None: + memref.Store.get(value, memtemp, ssa_indices) + case OpInc: + memref.AtomicRMWOp(operands=[value, memtemp, ssa_indices], result_types=[value.type], properties={"kind" : builtin.IntegerAttr(0, builtin.i64)}) + + def build_condition(self, dim: SteppingDimension, eq: BooleanFunction): + return self._visit_math_nodes(dim, eq, None) + + def build_generic_step(self, dim: SteppingDimension, eq: LoweredEq): + if eq.conditionals: + condition = And(*eq.conditionals.values(), evaluate=False) + cond = self.build_condition(dim, condition) + if_ = scf.If(cond, (), Region(Block())) + with ImplicitBuilder(if_.true_region.block): + self.build_generic_step_expression(dim, eq) + scf.Yield() + else: + # Build the expression + self.build_generic_step_expression(dim, eq) + def build_time_loop( - self, eqs: list[LoweredEq], step_dim: SteppingDimension, **kwargs + self, eqs: list[Any], step_dim: SteppingDimension, **kwargs ): # Bounds and step boilerpalte lb = iet_ssa.LoadSymbolic.get( @@ -319,7 +473,9 @@ def build_time_loop( ) # Name the 'time' step iterator - loop.body.block.args[0].name_hint = step_dim.root.name # 'time' + loop.body.block.args[0].name_hint = "time" + # Store for later reference + self.symbol_values["time"] = loop.body.block.args[0] # Store a mapping from time_buffers to their corresponding block # arguments for easier access later. @@ -341,10 +497,66 @@ def build_time_loop( ] scf.Yield(*yield_args) - def generate_equations(self, eqs: list[LoweredEq], **kwargs): + def generate_equations(self, eqs: list[Any], **kwargs): # Lower equations to their xDSL equivalent for eq in eqs: - self._convert_eq(eq, **kwargs) + if isinstance(eq, Eq): + # Nested lowering? TO re-think approach + lowered = self.operator._lower_exprs(as_tuple(eq), **kwargs) + for lo in lowered: + self._convert_eq(lo) + elif isinstance(eq, Injection): + lowered = self.operator._lower_exprs(as_tuple(eq), **kwargs) + self._lower_injection(lowered) + else: + raise NotImplementedError(f"Expression {eq} of type {type(eq)} not supported") + + def _lower_injection(self, eqs: list[LoweredEq]): + """ + Lower an injection to xDSL. + """ + # We assert that all equations of one Injection share the same iteration space! + ispaces = [e.ispace for e in eqs] + assert all(ispaces[0] == isp for isp in ispaces[1:]) + ispace = ispaces[0] + assert isinstance(ispace.dimensions[0], TimeDimension) + + lbs = [] + ubs = [] + for interval in ispace[1:]: + lower = interval.symbolic_min + if isinstance(lower, Scalar): + lb = iet_ssa.LoadSymbolic.get(lower._C_name, builtin.IndexType()) + elif isinstance(lower, (Number, int)): + lb = arith.Constant.from_int_and_width(int(lower), builtin.IndexType()) + else: + raise NotImplementedError(f"Lower bound of type {type(lower)} not supported") + lb.result.name_hint = f"{interval.dim.name}_m" + + upper = interval.symbolic_max + if isinstance(upper, Scalar): + ub = iet_ssa.LoadSymbolic.get(upper._C_name, builtin.IndexType()) + elif isinstance(upper, (Number, int)): + ub = arith.Constant.from_int_and_width(int(upper), builtin.IndexType()) + else: + raise NotImplementedError( + f"Upper bound of type {type(upper)} not supported" + ) + ub.result.name_hint = f"{interval.dim.name}_M" + lbs.append(lb) + ubs.append(ub) + + steps = [arith.Constant.from_int_and_width(1, builtin.IndexType()).result]*len(ubs) + ubs = [arith.Addi(ub, steps[0]) for ub in ubs] + + with ImplicitBuilder(scf.ParallelOp(lbs, ubs, steps, [pblock := Block(arg_types=[builtin.IndexType()]*len(ubs))]).body): + for arg, interval in zip(pblock.args, ispace[1:], strict=True): + arg.name_hint = interval.dim.name + self.symbol_values[interval.dim.name] = arg + for eq in eqs: + self._convert_eq(eq) + scf.Yield() + # raise NotImplementedError("Injections not supported yet") def convert(self, eqs: Iterable[Eq], **kwargs) -> builtin.ModuleOp: """ @@ -387,15 +599,33 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> builtin.ModuleOp: calling the operator. """ # Instantiate the module. - self.function_values: dict[tuple[Function, int], SSAValue] = {} + self.function_values : dict[tuple[Function, int], SSAValue] = {} + self.symbol_values : dict[str, SSAValue] = {} module = builtin.ModuleOp(Region([block := Block([])])) with ImplicitBuilder(block): # Get all functions used in the equations - functions = OrderedSet( - *(f.function for eq in eqs for f in retrieve_function_carriers(eq)) - ) - self.time_buffers: list[TimeFunction] = [] - self.functions: list[Function] = [] + functions = OrderedSet() + for eq in eqs: + if isinstance(eq, Eq): + # Use funcs not carriers + funcs = retrieve_functions(eq) + + for f in funcs: + functions.add(f.function) + + elif isinstance(eq, Injection): + # import pdb; pdb.set_trace() + functions.add(eq.field.function) + for f in retrieve_functions(eq.expr): + if isinstance(f, PointSource): + functions.add(f._coordinates) + functions.add(f.function) + + else: + raise NotImplementedError(f"Expression {eq} of type {type(eq)} not supported") + + self.time_buffers : list[TimeFunction] = [] + self.functions : list[Function] = [] for f in functions: match f: case TimeFunction(): @@ -403,6 +633,11 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> builtin.ModuleOp: self.time_buffers.append((f, i)) case Function(): self.functions.append(f) + case PointSource(): + self.functions.append(f.coordinates) + self.functions.append(f) + case _: + raise NotImplementedError(f"Function of type {type(f)} not supported") # For each used time_buffer, define a stencil.field type for the function. # Those represent DeVito's buffers in xDSL/stencil terms. @@ -418,33 +653,26 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> builtin.ModuleOp: self.function_args = {} for i, (f, t) in enumerate(self.time_buffers): # Also define argument names to help with debugging - xdsl_func.body.block.args[i].name_hint = f"{f.name}_vec{t}" + xdsl_func.body.block.args[i].name_hint = f._C_name + str(t) self.function_args[(f, t)] = xdsl_func.body.block.args[i] for i, f in enumerate(self.functions): - xdsl_func.body.block.args[len(self.time_buffers) + i].name_hint = f"{f.name}_vec" # noqa - # tofix what is this 0 in [(f, 0)] - self.function_args[(f, 0)] = xdsl_func.body.block.args[len(self.time_buffers) + i] # noqa + # Sources + xdsl_func.body.block.args[len(self.time_buffers)+i].name_hint = f._C_name + self.function_args[(f, 0)] = xdsl_func.body.block.args[len(self.time_buffers)+i] # Union operation? self.function_values |= self.function_args - # print(self.function_values) # Move on to generate the function body with ImplicitBuilder(xdsl_func.body.block): - # Start building the time loop - # TODO: This should be moved to the cluster codegen. In the meantime, - # we stick to similar assumptions and just use the first equation's grid - # for the time loop information. - - # Get the stepping dimension. It's usually time, and usually the first one. - # Getting it here; more readable and less input assumptions :) - time_functions = [f for (f, _) in self.time_buffers] + # Get the stepping dimension, if there is any in the whole input + time_functions = [f for (f,_) in self.time_buffers] dimensions = { d for f in (self.functions + time_functions) for d in f.dimensions } - step_dim = next((d for d in dimensions - if isinstance(d, SteppingDimension)), None) + + step_dim = next((d for d in dimensions if isinstance(d, SteppingDimension)), None) if step_dim is not None: self.build_time_loop(eqs, step_dim, **kwargs) else: @@ -458,12 +686,21 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> builtin.ModuleOp: def _ensure_same_type(self, *vals: SSAValue): if all(isinstance(val.type, builtin.IntegerAttr) for val in vals): return vals + if all(isinstance(val.type, builtin.IndexType) for val in vals): + # Sources + return vals if all(is_float(val) for val in vals): return vals # not everything homogeneous + cast_to_floats = True + if all(is_int(val) for val in vals): + cast_to_floats = False processed = [] for val in vals: - if is_float(val): + if cast_to_floats and is_float(val): + processed.append(val) + continue + if (not cast_to_floats) and isinstance(val.type, builtin.IndexType): processed.append(val) continue # if the val is the result of a arith.constant with no uses, @@ -473,14 +710,23 @@ def _ensure_same_type(self, *vals: SSAValue): and isinstance(val.op, arith.Constant) and val.uses == 0 ): - val.type = builtin.f32 - val.op.attributes["value"] = builtin.FloatAttr( - float(val.op.value.value.data), builtin.f32 - ) + if cast_to_floats: + val.type = builtin.f32 + val.op.attributes["value"] = builtin.FloatAttr( + float(val.op.value.value.data), builtin.f32 + ) + else: + val.type = builtin.IndexType() + val.op.value.type = builtin.IndexType() processed.append(val) continue - # insert an integer to float cast op - conv = arith.SIToFPOp(val, builtin.f32) + # insert a cast op + if cast_to_floats: + if val.type == builtin.IndexType(): + val = arith.IndexCastOp(val, builtin.i64).result + conv = arith.SIToFPOp(val, builtin.f32) + else: + conv = arith.IndexCastOp(val, builtin.IndexType()) processed.append(conv.result) return processed @@ -531,6 +777,49 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): rewriter.insert_op_after_matched_op(wrapper) +class TimerRewritePattern(RewritePattern): + """ + Base class for time benchmarking related rewrite patterns + """ + pass + + +@dataclass +class MakeFunctionTimed(TimerRewritePattern): + """ + Populate the section0 devito timer with the total runtime of the function + """ + func_name: str + seen_ops: set[func.Func] = field(default_factory=set) + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): + if op.sym_name.data != self.func_name or op in self.seen_ops: + return + + # only apply once + self.seen_ops.add(op) + + # Insert timer start and end calls + rewriter.insert_op([ + t0 := func.Call('timer_start', [], [builtin.f64]) + ], InsertPoint.at_start(op.body.block)) + + ret = op.get_return_op() + assert ret is not None + + rewriter.insert_op_before([ + timers := iet_ssa.LoadSymbolic.get('timers', llvm.LLVMPointerType.opaque()), + t1 := func.Call('timer_end', [t0], [builtin.f64]), + llvm.StoreOp(t1, timers), + ], ret) + + rewriter.insert_op([ + func.FuncOp.external('timer_start', [], [builtin.f64]), + func.FuncOp.external('timer_end', [builtin.f64], [builtin.f64]), + ], InsertPoint.after(rewriter.current_operation)) + + def get_containing_func(op: Operation) -> func.FuncOp | None: while op is not None and not isinstance(op, func.FuncOp): op = op.parent_op() diff --git a/devito/ir/ietxdsl/profiling.py b/devito/ir/ietxdsl/profiling.py index 15f0130979..f767a34ff1 100644 --- a/devito/ir/ietxdsl/profiling.py +++ b/devito/ir/ietxdsl/profiling.py @@ -6,7 +6,7 @@ from xdsl.pattern_rewriter import (RewritePattern, op_type_rewrite_pattern, GreedyRewritePatternApplier, PatternRewriter, - PatternRewriteWalker) + PatternRewriteWalker, InsertPoint) class TimerRewritePattern(RewritePattern): @@ -33,9 +33,10 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): self.seen_ops.add(op) # Insert timer start and end calls - rewriter.insert_op_at_start([ + # Insert timer start and end calls + rewriter.insert_op([ t0 := func.Call('timer_start', [], [builtin.f64]) - ], op.body.block) + ], InsertPoint.at_start(op.body.block)) ret = op.get_return_op() assert ret is not None @@ -46,10 +47,10 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): llvm.StoreOp(t1, timers), ], ret) - rewriter.insert_op_after_matched_op([ + rewriter.insert_op([ func.FuncOp.external('timer_start', [], [builtin.f64]), - func.FuncOp.external('timer_end', [builtin.f64], [builtin.f64]) - ]) + func.FuncOp.external('timer_end', [builtin.f64], [builtin.f64]), + ], InsertPoint.after(rewriter.current_operation)) def apply_timers(module, **kwargs): diff --git a/devito/ir/ietxdsl/utils.py b/devito/ir/ietxdsl/utils.py index 6d413c8898..da213d076a 100644 --- a/devito/ir/ietxdsl/utils.py +++ b/devito/ir/ietxdsl/utils.py @@ -1,10 +1,20 @@ +import numpy as np + from xdsl.dialects import builtin from xdsl.ir import SSAValue def is_int(val: SSAValue): - return isinstance(val.type, builtin.IntegerType) + return isinstance(val.type, (builtin.IntegerType, builtin.IndexType)) def is_float(val: SSAValue): return val.type in (builtin.f32, builtin.f64) + + +dtypes_to_xdsltypes = { + np.float32: builtin.f32, + np.float64: builtin.f64, + np.int32: builtin.i32, + np.int64: builtin.i64, +} diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 788e66bd94..1c828a88a0 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -859,8 +859,6 @@ def apply(self, **kwargs): # Output summary of performance achieved return self._emit_apply_profiling(args) - # Performance profiling - def _emit_build_profiling(self): if not is_log_enabled_for('PERF'): return diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index 759b86a3ae..85ab8f2fbc 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -79,13 +79,13 @@ class TestAdjoint(object): ('layers-viscoacoustic', (20, 25), 'maxwell', 4, 2, viscoacoustic_setup), ('layers-viscoacoustic', (20, 25), 'maxwell', 2, 2, viscoacoustic_setup), # 3D Deng Mcmechan Viscoacoustic tests with varying space and equation orders - ('layers-viscoacoustic', (20, 25, 20), 'maxwell', 4, 1, \ + ('layers-viscoacoustic', (20, 25, 20), 'maxwell', 4, 1, viscoacoustic_setup), - ('layers-viscoacoustic', (20, 25, 20), 'maxwell', 2, 1, \ + ('layers-viscoacoustic', (20, 25, 20), 'maxwell', 2, 1, viscoacoustic_setup), - ('layers-viscoacoustic', (20, 25, 20), 'maxwell', 4, 2, \ + ('layers-viscoacoustic', (20, 25, 20), 'maxwell', 4, 2, viscoacoustic_setup), - ('layers-viscoacoustic', (20, 25, 20), 'maxwell', 2, 2, \ + ('layers-viscoacoustic', (20, 25, 20), 'maxwell', 2, 2, viscoacoustic_setup), ]) def test_adjoint_F(self, mkey, shape, kernel, space_order, time_order, setup_func): diff --git a/tests/test_xdsl_base.py b/tests/test_xdsl_base.py index 0e96fba182..0625a8d1e6 100644 --- a/tests/test_xdsl_base.py +++ b/tests/test_xdsl_base.py @@ -2,17 +2,18 @@ import pytest from devito import (Grid, TensorTimeFunction, VectorTimeFunction, div, grad, diag, solve, - Operator, Eq, Constant, norm, SpaceDimension) + Operator, Eq, Constant, norm, SpaceDimension, switchconfig) from devito.types import Array, Function, TimeFunction from devito.tools import as_tuple +from examples.seismic.source import RickerSource, TimeAxis from xdsl.dialects.scf import For, Yield from xdsl.dialects.arith import Addi +from xdsl.dialects.arith import Constant as xdslconstant from xdsl.dialects.func import Call, Return from xdsl.dialects.stencil import FieldType, ApplyOp, LoadOp, StoreOp from xdsl.dialects.llvm import LLVMPointerType - -from examples.seismic.source import RickerSource, TimeAxis +from xdsl.dialects.memref import Load def test_xdsl_I(): @@ -251,6 +252,160 @@ def test_standard_mlir_rewrites(shape, so, to, nt): xdslop.apply(time=nt, dt=dt) +class TestSources: + + @switchconfig(openmp=False) + @pytest.mark.parametrize('shape', [(8, 8), (38, 38), ]) + @pytest.mark.parametrize('tn', [20, 80]) + @pytest.mark.parametrize('factor', [-0.1, 0.0, 0.1, 0.5, 1.1]) + @pytest.mark.parametrize('factor2', [-0.1, 0.1, 0.5, 1.1]) + def test_source_only(self, shape, tn, factor, factor2): + spacing = (10.0, 10.0) + extent = tuple(np.array(spacing) * (shape[0] - 1)) + origin = (0.0, 0.0) + + v = np.empty(shape, dtype=np.float32) + v[:, :51] = 1.5 + v[:, 51:] = 2.5 + + grid = Grid(shape=shape, extent=extent, origin=origin) + + t0 = 0.0 + # Comes from args + tn = tn + dt = 1.6 + time_range = TimeAxis(start=t0, stop=tn, step=dt) + + f0 = 0.010 + src = RickerSource(name="src", grid=grid, f0=f0, npoint=5, time_range=time_range) + + domain_size = np.array(extent) + + src.coordinates.data[0, :] = domain_size * factor + src.coordinates.data[0, -1] = 19.0 * factor2 + + u = TimeFunction(name="u", grid=grid, space_order=2) + m = Function(name='m', grid=grid) + m.data[:] = 1./(v*v) + + src_term = src.inject(field=u.forward, expr=src * dt**2 / m) + + op = Operator([src_term], opt="advanced") + op(time=time_range.num-1, dt=dt) + normdv = np.linalg.norm(u.data[0]) + u.data[:, :] = 0 + + opx = Operator([src_term], opt="xdsl") + opx(time=time_range.num-1, dt=dt) + normxdsl = np.linalg.norm(u.data[0]) + + assert np.isclose(normdv, normxdsl, rtol=1e-04) + + @switchconfig(openmp=False) + @pytest.mark.parametrize('shape', [(8, 8)]) + @pytest.mark.parametrize('tn', [20, 80]) + @pytest.mark.parametrize('factor', [-0.1, 0.0, 0.1, 0.5, 1.1]) + @pytest.mark.parametrize('factor2', [-0.1, 0.1, 0.5, 1.1]) + def test_source_structure(self, shape, tn, factor, factor2): + spacing = (10.0, 10.0) + extent = tuple(np.array(spacing) * (shape[0] - 1)) + origin = (0.0, 0.0) + + v = np.empty(shape, dtype=np.float32) + v[:, :51] = 1.5 + v[:, 51:] = 2.5 + + grid = Grid(shape=shape, extent=extent, origin=origin) + + t0 = 0.0 + # Comes from args + tn = tn + dt = 1.6 + time_range = TimeAxis(start=t0, stop=tn, step=dt) + + f0 = 0.010 + src = RickerSource(name="src", grid=grid, f0=f0, npoint=5, time_range=time_range) + + domain_size = np.array(extent) + + src.coordinates.data[0, :] = domain_size * factor + src.coordinates.data[0, -1] = 19.0 * factor2 + + u = TimeFunction(name="u", grid=grid, space_order=2) + m = Function(name='m', grid=grid) + m.data[:] = 1./(v*v) + + src_term = src.inject(field=u.forward, expr=src * dt**2 / m) + + opx = Operator([src_term], opt="xdsl") + opx(time=time_range.num-1, dt=dt) + + # Code structure + calls = sum(isinstance(op, Call) for op in opx._module.walk()) + assert calls == 2 + fors = sum(isinstance(op, For) for op in opx._module.walk()) + assert fors == 1 + loads = [op for op in opx._module.walk() if isinstance(op, Load)] + assert len(loads) == 8 + consts = [op for op in opx._module.walk() if isinstance(op, xdslconstant)] + assert len(consts) == 65 + + @switchconfig(openmp=False) + @pytest.mark.parametrize('shape', [(38, 38), ]) + @pytest.mark.parametrize('tn', [20, 80]) + @pytest.mark.parametrize('factor', [0.5, 0.8]) + @pytest.mark.parametrize('factor2', [0.5, 0.8]) + def test_forward_src_stencil(self, shape, tn, factor, factor2): + spacing = (10.0, 10.0) + extent = tuple(np.array(spacing) * (shape[0] - 1)) + origin = (0.0, 0.0) + + v = np.empty(shape, dtype=np.float32) + v[:, :51] = 1.5 + v[:, 51:] = 2.5 + + grid = Grid(shape=shape, extent=extent, origin=origin) + + t0 = 0.0 + # Comes from args + tn = tn + dt = 1.6 + time_range = TimeAxis(start=t0, stop=tn, step=dt) + + f0 = 0.010 + src = RickerSource(name="src", grid=grid, f0=f0, npoint=5, time_range=time_range) + + domain_size = np.array(extent) + + src.coordinates.data[0, :] = domain_size * factor + src.coordinates.data[0, -1] = 100.0 * factor2 + + u = TimeFunction(name="u", grid=grid, space_order=2, time_order=2) + m = Function(name='m', grid=grid) + m.data[:] = 1./(v*v) + + src_term = src.inject(field=u.forward, expr=src * dt**2 / m) + + pde = u.dt2 - u.laplace + eq0 = solve(pde, u.forward) + stencil = Eq(u.forward, eq0) + + op = Operator([stencil, src_term], opt="advanced") + op(time=time_range.num-1, dt=dt) + # normdv = norm(u) + normdv = np.linalg.norm(u.data) + + u.data[:, :] = 0 + + opx = Operator([stencil, src_term], opt="xdsl") + opx(time=time_range.num-1, dt=dt) + normxdsl = np.linalg.norm(u.data) + # normxdsl = norm(u) + + assert not np.isclose(normdv, 0.0, rtol=1e-04) + assert np.isclose(normdv, normxdsl, rtol=1e-04) + + def test_xdsl_mul_eqs_I(): # Define a Devito Operator with multiple eqs grid = Grid(shape=(4, 4)) @@ -907,8 +1062,11 @@ def test_elastic_2D(self, shape, so, nt): op = Operator([u_v] + [u_t] + src_xx + src_zz) op(dt=dt) - op = Operator([u_v] + [u_t], opt='xdsl') - op(dt=dt, time_M=nt) + opx = Operator([u_v] + [u_t], opt='xdsl') + opx(dt=dt, time_M=nt) + + store_ops = [op for op in opx._module.walk() if isinstance(op, StoreOp)] + assert len(store_ops) == 5 xdsl_norm_v0 = norm(v[0]) xdsl_norm_v1 = norm(v[1])