diff --git a/devito/core/cpu.py b/devito/core/cpu.py index 8f7cff18b6..386b81fdb4 100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -15,6 +15,8 @@ from devito.core.operator import CoreOperator, CustomOperator, ParTile from devito.exceptions import InvalidOperator from devito.ir.iet import Callable, MetaCall +from devito.ir.iet.nodes import Section +from devito.ir.iet.visitors import FindNodes from devito.logger import info, perf from devito.mpi import MPI from devito.operator.profiling import create_profile @@ -277,6 +279,7 @@ def _build(cls, expressions, **kwargs): op._dimensions = set().union(*[e.dimensions for e in irs.expressions]) op._dtype, op._dspace = irs.clusters.meta op._profiler = profiler + kwargs['xdsl_num_sections'] = len(FindNodes(Section).visit(irs.iet)) module = cls._lower_stencil(irs.expressions, **kwargs) op._module = module diff --git a/devito/ir/ietxdsl/cluster_to_ssa.py b/devito/ir/ietxdsl/cluster_to_ssa.py index 602b6f56ee..aa17e2de0f 100644 --- a/devito/ir/ietxdsl/cluster_to_ssa.py +++ b/devito/ir/ietxdsl/cluster_to_ssa.py @@ -49,10 +49,12 @@ class ExtractDevitoStencilConversion: eqs: list[LoweredEq] block: Block temps: dict[tuple[DiscreteFunction, int], SSAValue] + symbol_values: dict[str, SSAValue] time_offs: int def __init__(self): self.temps = dict() + self.symbol_values = dict() time_offs: int @@ -75,6 +77,13 @@ def convert_function_eq(self, eq: LoweredEq, **kwargs): # Get the function carriers of the equation self._build_step_body(step_dim, eq) + def convert_symbol_eq(self, symbol: Symbol, rhs: LoweredEq, **kwargs): + """ + Convert a symbol equation to xDSL. + """ + self.symbol_values[symbol.name] = self._visit_math_nodes(None, rhs, None) + self.symbol_values[symbol.name].name_hint = symbol.name + def _convert_eq(self, eq: LoweredEq, **kwargs): """ # Docs here Need rewriting @@ -122,7 +131,7 @@ def _convert_eq(self, eq: LoweredEq, **kwargs): f"Function of type {type(write_function.function)} not supported" # noqa ) case Symbol(): - self.convert_symbol_eq(write_function, eq, **kwargs) + self.convert_symbol_eq(write_function, eq.rhs, **kwargs) case _: raise NotImplementedError(f"LHS of type {type(write_function)} not supported") # noqa diff --git a/devito/ir/ietxdsl/profiling.py b/devito/ir/ietxdsl/profiling.py index aa02ee5006..15f0130979 100644 --- a/devito/ir/ietxdsl/profiling.py +++ b/devito/ir/ietxdsl/profiling.py @@ -56,6 +56,8 @@ def apply_timers(module, **kwargs): """ Apply timers to a module """ + if kwargs['xdsl_num_sections'] < 1: + return name = kwargs.get("name", "Kernel") grpa = GreedyRewritePatternApplier([MakeFunctionTimed(name)]) PatternRewriteWalker(grpa, walk_regions_first=True).rewrite_module(module) diff --git a/tests/test_xdsl_base.py b/tests/test_xdsl_base.py index 2dda69081a..3b9db7b30d 100644 --- a/tests/test_xdsl_base.py +++ b/tests/test_xdsl_base.py @@ -3,7 +3,7 @@ from devito import (Grid, TensorTimeFunction, VectorTimeFunction, div, grad, diag, solve, Operator, Eq, Constant, norm, SpaceDimension) -from devito.types import Symbol, Array, Function, TimeFunction +from devito.types import Array, Function, TimeFunction from xdsl.dialects.scf import For, Yield from xdsl.dialects.arith import Addi @@ -789,19 +789,6 @@ def test_function_IV(): class TestOperatorUnsupported(object): - @pytest.mark.xfail(reason="Symbols are not supported in xDSL yet") - def test_symbol_I(self): - # Define a simple Devito a = 1 operator - - a = Symbol('a') - eq0 = Eq(a, 1) - - op = Operator([eq0], opt='xdsl') - - op.apply() - - assert a == 1 - @pytest.mark.xfail(reason="stencil.return operation does not verify for i64") def test_forward_assignment(self): # simple forward assignment diff --git a/tests/test_xdsl_op_correctness.py b/tests/test_xdsl_op_correctness.py index 7c7fc44d42..e8f3b8b538 100644 --- a/tests/test_xdsl_op_correctness.py +++ b/tests/test_xdsl_op_correctness.py @@ -4,12 +4,14 @@ # flake8: noqa from xdsl.dialects.scf import For, Yield -from xdsl.dialects.arith import Addi +from xdsl.dialects.arith import Addi, Constant from xdsl.dialects.func import Call, Return from xdsl.dialects.stencil import FieldType, ApplyOp, LoadOp, StoreOp from xdsl.dialects.llvm import LLVMPointerType from xdsl.printer import Printer +from devito.types.basic import Symbol + def test_udx(): # Define a simple Devito Operator @@ -106,6 +108,23 @@ def test_u_and_v_conversion(): assert type(ops[8] == StoreOp) assert type(ops[9] == Return) +def test_symbol_I(): + # Define a simple Devito a = 1 operator + + a = Symbol('a') + eq0 = Eq(a, 1) + + op = Operator([eq0], opt='xdsl') + + op.apply() + + assert len(op._module.regions[0].blocks[0].ops.first.body.blocks[0].ops) == 2 + + ops = list(op._module.regions[0].blocks[0].ops.first.body.blocks[0].ops) + assert isinstance(ops[0], Constant) + assert ops[0].result.name_hint == a.name + assert type(ops[0] == Return) + # This test should fail, as we are trying to use an inplace operation @pytest.mark.xfail(reason="Cannot store to a field that is loaded from")