diff --git a/devito/ir/xdsl_iet/cluster_to_ssa.py b/devito/ir/xdsl_iet/cluster_to_ssa.py index e4b1c2963f..37713b5322 100644 --- a/devito/ir/xdsl_iet/cluster_to_ssa.py +++ b/devito/ir/xdsl_iet/cluster_to_ssa.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from sympy import (Add, And, Expr, Float, GreaterThan, Indexed, Integer, LessThan, Number, Pow, StrictGreaterThan, StrictLessThan, Symbol, floor, - Mul, sin, cos) + Mul, sin, cos, tan) from sympy.core.relational import Relational from sympy.logic.boolalg import BooleanFunction from devito.operations.interpolators import Injection @@ -288,17 +288,22 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, for arg in node.args) return reduce(lambda x, y : arith.AndI(x, y).result, SSAargs) + # Trigonometric functions elif isinstance(node, sin): assert len(node.args) == 1, "Expected single argument for sin." return math.SinOp(self._visit_math_nodes(dim, node.args[0], output_indexed)).result elif isinstance(node, cos): - assert len(node.args) == 1, "Expected single argument for cos." - + assert len(node.args) == 1, "Expected single argument for cos." return math.CosOp(self._visit_math_nodes(dim, node.args[0], output_indexed)).result + elif isinstance(node, tan): + assert len(node.args) == 1, "Expected single argument for TanOp." + + return math.TanOp(self._visit_math_nodes(dim, node.args[0], + output_indexed)).result elif isinstance(node, Relational): if isinstance(node, GreaterThan): @@ -311,9 +316,7 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, mnemonic = "slt" else: raise NotImplementedError(f"Unimplemented comparison {type(node)}") - - # import pdb; - # pdb.set_trace() + SSAargs = (self._visit_math_nodes(dim, arg, output_indexed) for arg in node.args) # Operands must have the same type # TODO: look at here if index stuff does not make sense diff --git a/tests/test_xdsl_base.py b/tests/test_xdsl_base.py index bfd793770b..39aa97828b 100644 --- a/tests/test_xdsl_base.py +++ b/tests/test_xdsl_base.py @@ -1,8 +1,9 @@ import numpy as np import pytest -from devito import (Grid, TensorTimeFunction, VectorTimeFunction, div, grad, diag, solve, - Operator, Eq, Constant, norm, SpaceDimension, switchconfig, sin, cos) +from devito import (Grid, TensorTimeFunction, VectorTimeFunction, div, grad, + diag, solve, Operator, Eq, Constant, norm, SpaceDimension, + switchconfig, sin, cos, tan) from devito.types import Array, Function, TimeFunction from devito.tools import as_tuple @@ -14,6 +15,7 @@ from xdsl.dialects.stencil import FieldType, ApplyOp, LoadOp, StoreOp from xdsl.dialects.llvm import LLVMPointerType from xdsl.dialects.memref import Load +from xdsl.dialects.experimental import math def test_xdsl_I(): @@ -980,10 +982,13 @@ def test_sine(self, deg, exp): u = Function(name="u", grid=grid) u.data[:, :] = 0 - eq0 = Eq(u, sin(deg)) + deg0 = Constant(name='deg', value=deg) + eq0 = Eq(u, sin(deg0)) - op = Operator([eq0], opt='xdsl') - op.apply() + opx = Operator([eq0], opt='xdsl') + opx.apply() + + assert len([op for op in opx._module.walk() if isinstance(op, math.SinOp)]) == 1 assert np.isclose(norm(u), exp, rtol=1e-4) @pytest.mark.parametrize('deg, exp', ([90.0, 1.7922944], [30.0, 0.6170056], @@ -994,10 +999,32 @@ def test_cosine(self, deg, exp): u = Function(name="u", grid=grid) u.data[:, :] = 0 - eq0 = Eq(u, cos(deg)) + deg0 = Constant(name='deg', value=deg) + eq0 = Eq(u, cos(deg0)) + + opx = Operator([eq0], opt='xdsl') + opx.apply() + + assert len([op for op in opx._module.walk() if isinstance(op, math.CosOp)]) == 1 + + assert np.isclose(norm(u), exp, rtol=1e-4) + + @pytest.mark.parametrize('deg, exp', ([2.0, 8.74016], [30.0, 25.621325], + [45.0, 6.4791])) + def test_tan(self, deg, exp): + grid = Grid(shape=(4, 4)) + + u = Function(name="u", grid=grid) + u.data[:, :] = 0 + + deg0 = Constant(name='deg', value=deg) + eq0 = Eq(u, tan(deg0)) + + opx = Operator([eq0], opt='xdsl') + opx.apply() + + assert len([op for op in opx._module.walk() if isinstance(op, math.TanOp)]) == 1 - op = Operator([eq0], opt='xdsl') - op.apply() assert np.isclose(norm(u), exp, rtol=1e-4)