From dc53d5a4e4ae86ce088918e567ed1b5a8e064ccd Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Fri, 17 May 2024 02:53:42 +0100 Subject: [PATCH 1/3] Symbols! --- devito/core/cpu.py | 1 + devito/ir/ietxdsl/cluster_to_ssa.py | 9 ++++++++- tests/test_xdsl_base.py | 13 ------------- tests/test_xdsl_op_correctness.py | 21 ++++++++++++++++++++- 4 files changed, 29 insertions(+), 15 deletions(-) diff --git a/devito/core/cpu.py b/devito/core/cpu.py index 8f7cff18b6..5c71af9bf0 100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -277,6 +277,7 @@ def _build(cls, expressions, **kwargs): op._dimensions = set().union(*[e.dimensions for e in irs.expressions]) op._dtype, op._dspace = irs.clusters.meta op._profiler = profiler + kwargs['xdsl_num_sections'] = len(FindNodes(Section).visit(irs.iet)) module = cls._lower_stencil(irs.expressions, **kwargs) op._module = module diff --git a/devito/ir/ietxdsl/cluster_to_ssa.py b/devito/ir/ietxdsl/cluster_to_ssa.py index 602b6f56ee..6516e229b3 100644 --- a/devito/ir/ietxdsl/cluster_to_ssa.py +++ b/devito/ir/ietxdsl/cluster_to_ssa.py @@ -75,6 +75,13 @@ def convert_function_eq(self, eq: LoweredEq, **kwargs): # Get the function carriers of the equation self._build_step_body(step_dim, eq) + def convert_symbol_eq(self, symbol: Symbol, rhs: LoweredEq, **kwargs): + """ + Convert a symbol equation to xDSL. + """ + self.symbol_values[symbol.name] = self._visit_math_nodes(None, rhs, None) + self.symbol_values[symbol.name].name_hint = symbol.name + def _convert_eq(self, eq: LoweredEq, **kwargs): """ # Docs here Need rewriting @@ -122,7 +129,7 @@ def _convert_eq(self, eq: LoweredEq, **kwargs): f"Function of type {type(write_function.function)} not supported" # noqa ) case Symbol(): - self.convert_symbol_eq(write_function, eq, **kwargs) + self.convert_symbol_eq(write_function, eq.rhs, **kwargs) case _: raise NotImplementedError(f"LHS of type {type(write_function)} not supported") # noqa diff --git a/tests/test_xdsl_base.py b/tests/test_xdsl_base.py index 2dda69081a..d4a4dfd543 100644 --- a/tests/test_xdsl_base.py +++ b/tests/test_xdsl_base.py @@ -789,19 +789,6 @@ def test_function_IV(): class TestOperatorUnsupported(object): - @pytest.mark.xfail(reason="Symbols are not supported in xDSL yet") - def test_symbol_I(self): - # Define a simple Devito a = 1 operator - - a = Symbol('a') - eq0 = Eq(a, 1) - - op = Operator([eq0], opt='xdsl') - - op.apply() - - assert a == 1 - @pytest.mark.xfail(reason="stencil.return operation does not verify for i64") def test_forward_assignment(self): # simple forward assignment diff --git a/tests/test_xdsl_op_correctness.py b/tests/test_xdsl_op_correctness.py index 7c7fc44d42..e8f3b8b538 100644 --- a/tests/test_xdsl_op_correctness.py +++ b/tests/test_xdsl_op_correctness.py @@ -4,12 +4,14 @@ # flake8: noqa from xdsl.dialects.scf import For, Yield -from xdsl.dialects.arith import Addi +from xdsl.dialects.arith import Addi, Constant from xdsl.dialects.func import Call, Return from xdsl.dialects.stencil import FieldType, ApplyOp, LoadOp, StoreOp from xdsl.dialects.llvm import LLVMPointerType from xdsl.printer import Printer +from devito.types.basic import Symbol + def test_udx(): # Define a simple Devito Operator @@ -106,6 +108,23 @@ def test_u_and_v_conversion(): assert type(ops[8] == StoreOp) assert type(ops[9] == Return) +def test_symbol_I(): + # Define a simple Devito a = 1 operator + + a = Symbol('a') + eq0 = Eq(a, 1) + + op = Operator([eq0], opt='xdsl') + + op.apply() + + assert len(op._module.regions[0].blocks[0].ops.first.body.blocks[0].ops) == 2 + + ops = list(op._module.regions[0].blocks[0].ops.first.body.blocks[0].ops) + assert isinstance(ops[0], Constant) + assert ops[0].result.name_hint == a.name + assert type(ops[0] == Return) + # This test should fail, as we are trying to use an inplace operation @pytest.mark.xfail(reason="Cannot store to a field that is loaded from") From 8a20f8b2efc675dd4a9e2eb610c3b0a18743b1c6 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Sun, 19 May 2024 11:46:55 +0100 Subject: [PATCH 2/3] Missing bit. --- devito/ir/ietxdsl/cluster_to_ssa.py | 1 + 1 file changed, 1 insertion(+) diff --git a/devito/ir/ietxdsl/cluster_to_ssa.py b/devito/ir/ietxdsl/cluster_to_ssa.py index 6516e229b3..c6aa1db16a 100644 --- a/devito/ir/ietxdsl/cluster_to_ssa.py +++ b/devito/ir/ietxdsl/cluster_to_ssa.py @@ -49,6 +49,7 @@ class ExtractDevitoStencilConversion: eqs: list[LoweredEq] block: Block temps: dict[tuple[DiscreteFunction, int], SSAValue] + symbol_values: dict[str, SSAValue] time_offs: int def __init__(self): From 55132c68729f7d4241c8ed10a9cc81ba6cf95c61 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Tue, 21 May 2024 12:47:54 +0100 Subject: [PATCH 3/3] Post-rebase tweaks. --- devito/core/cpu.py | 2 ++ devito/ir/ietxdsl/cluster_to_ssa.py | 1 + devito/ir/ietxdsl/profiling.py | 2 ++ tests/test_xdsl_base.py | 2 +- 4 files changed, 6 insertions(+), 1 deletion(-) diff --git a/devito/core/cpu.py b/devito/core/cpu.py index 5c71af9bf0..386b81fdb4 100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -15,6 +15,8 @@ from devito.core.operator import CoreOperator, CustomOperator, ParTile from devito.exceptions import InvalidOperator from devito.ir.iet import Callable, MetaCall +from devito.ir.iet.nodes import Section +from devito.ir.iet.visitors import FindNodes from devito.logger import info, perf from devito.mpi import MPI from devito.operator.profiling import create_profile diff --git a/devito/ir/ietxdsl/cluster_to_ssa.py b/devito/ir/ietxdsl/cluster_to_ssa.py index c6aa1db16a..aa17e2de0f 100644 --- a/devito/ir/ietxdsl/cluster_to_ssa.py +++ b/devito/ir/ietxdsl/cluster_to_ssa.py @@ -54,6 +54,7 @@ class ExtractDevitoStencilConversion: def __init__(self): self.temps = dict() + self.symbol_values = dict() time_offs: int diff --git a/devito/ir/ietxdsl/profiling.py b/devito/ir/ietxdsl/profiling.py index aa02ee5006..15f0130979 100644 --- a/devito/ir/ietxdsl/profiling.py +++ b/devito/ir/ietxdsl/profiling.py @@ -56,6 +56,8 @@ def apply_timers(module, **kwargs): """ Apply timers to a module """ + if kwargs['xdsl_num_sections'] < 1: + return name = kwargs.get("name", "Kernel") grpa = GreedyRewritePatternApplier([MakeFunctionTimed(name)]) PatternRewriteWalker(grpa, walk_regions_first=True).rewrite_module(module) diff --git a/tests/test_xdsl_base.py b/tests/test_xdsl_base.py index d4a4dfd543..3b9db7b30d 100644 --- a/tests/test_xdsl_base.py +++ b/tests/test_xdsl_base.py @@ -3,7 +3,7 @@ from devito import (Grid, TensorTimeFunction, VectorTimeFunction, div, grad, diag, solve, Operator, Eq, Constant, norm, SpaceDimension) -from devito.types import Symbol, Array, Function, TimeFunction +from devito.types import Array, Function, TimeFunction from xdsl.dialects.scf import For, Yield from xdsl.dialects.arith import Addi