Skip to content

Commit

Permalink
Merge pull request #90 from xdslproject/emilien/symbol
Browse files Browse the repository at this point in the history
compiler: Add support for Devito's `Symbol` in LHS
  • Loading branch information
georgebisbas authored May 21, 2024
2 parents d118fbb + 55132c6 commit 5f30fdd
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 16 deletions.
3 changes: 3 additions & 0 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from devito.core.operator import CoreOperator, CustomOperator, ParTile
from devito.exceptions import InvalidOperator
from devito.ir.iet import Callable, MetaCall
from devito.ir.iet.nodes import Section
from devito.ir.iet.visitors import FindNodes
from devito.logger import info, perf
from devito.mpi import MPI
from devito.operator.profiling import create_profile
Expand Down Expand Up @@ -277,6 +279,7 @@ def _build(cls, expressions, **kwargs):
op._dimensions = set().union(*[e.dimensions for e in irs.expressions])
op._dtype, op._dspace = irs.clusters.meta
op._profiler = profiler
kwargs['xdsl_num_sections'] = len(FindNodes(Section).visit(irs.iet))
module = cls._lower_stencil(irs.expressions, **kwargs)
op._module = module

Expand Down
11 changes: 10 additions & 1 deletion devito/ir/ietxdsl/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ class ExtractDevitoStencilConversion:
eqs: list[LoweredEq]
block: Block
temps: dict[tuple[DiscreteFunction, int], SSAValue]
symbol_values: dict[str, SSAValue]
time_offs: int

def __init__(self):
self.temps = dict()
self.symbol_values = dict()

time_offs: int

Expand All @@ -75,6 +77,13 @@ def convert_function_eq(self, eq: LoweredEq, **kwargs):
# Get the function carriers of the equation
self._build_step_body(step_dim, eq)

def convert_symbol_eq(self, symbol: Symbol, rhs: LoweredEq, **kwargs):
"""
Convert a symbol equation to xDSL.
"""
self.symbol_values[symbol.name] = self._visit_math_nodes(None, rhs, None)
self.symbol_values[symbol.name].name_hint = symbol.name

def _convert_eq(self, eq: LoweredEq, **kwargs):
"""
# Docs here Need rewriting
Expand Down Expand Up @@ -122,7 +131,7 @@ def _convert_eq(self, eq: LoweredEq, **kwargs):
f"Function of type {type(write_function.function)} not supported" # noqa
)
case Symbol():
self.convert_symbol_eq(write_function, eq, **kwargs)
self.convert_symbol_eq(write_function, eq.rhs, **kwargs)
case _:
raise NotImplementedError(f"LHS of type {type(write_function)} not supported") # noqa

Expand Down
2 changes: 2 additions & 0 deletions devito/ir/ietxdsl/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def apply_timers(module, **kwargs):
"""
Apply timers to a module
"""
if kwargs['xdsl_num_sections'] < 1:
return
name = kwargs.get("name", "Kernel")
grpa = GreedyRewritePatternApplier([MakeFunctionTimed(name)])
PatternRewriteWalker(grpa, walk_regions_first=True).rewrite_module(module)
15 changes: 1 addition & 14 deletions tests/test_xdsl_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from devito import (Grid, TensorTimeFunction, VectorTimeFunction, div, grad, diag, solve,
Operator, Eq, Constant, norm, SpaceDimension)
from devito.types import Symbol, Array, Function, TimeFunction
from devito.types import Array, Function, TimeFunction

from xdsl.dialects.scf import For, Yield
from xdsl.dialects.arith import Addi
Expand Down Expand Up @@ -789,19 +789,6 @@ def test_function_IV():

class TestOperatorUnsupported(object):

@pytest.mark.xfail(reason="Symbols are not supported in xDSL yet")
def test_symbol_I(self):
# Define a simple Devito a = 1 operator

a = Symbol('a')
eq0 = Eq(a, 1)

op = Operator([eq0], opt='xdsl')

op.apply()

assert a == 1

@pytest.mark.xfail(reason="stencil.return operation does not verify for i64")
def test_forward_assignment(self):
# simple forward assignment
Expand Down
21 changes: 20 additions & 1 deletion tests/test_xdsl_op_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
# flake8: noqa

from xdsl.dialects.scf import For, Yield
from xdsl.dialects.arith import Addi
from xdsl.dialects.arith import Addi, Constant
from xdsl.dialects.func import Call, Return
from xdsl.dialects.stencil import FieldType, ApplyOp, LoadOp, StoreOp
from xdsl.dialects.llvm import LLVMPointerType
from xdsl.printer import Printer

from devito.types.basic import Symbol

def test_udx():

# Define a simple Devito Operator
Expand Down Expand Up @@ -106,6 +108,23 @@ def test_u_and_v_conversion():
assert type(ops[8] == StoreOp)
assert type(ops[9] == Return)

def test_symbol_I():
# Define a simple Devito a = 1 operator

a = Symbol('a')
eq0 = Eq(a, 1)

op = Operator([eq0], opt='xdsl')

op.apply()

assert len(op._module.regions[0].blocks[0].ops.first.body.blocks[0].ops) == 2

ops = list(op._module.regions[0].blocks[0].ops.first.body.blocks[0].ops)
assert isinstance(ops[0], Constant)
assert ops[0].result.name_hint == a.name
assert type(ops[0] == Return)


# This test should fail, as we are trying to use an inplace operation
@pytest.mark.xfail(reason="Cannot store to a field that is loaded from")
Expand Down

0 comments on commit 5f30fdd

Please sign in to comment.