Skip to content

Commit

Permalink
compiler: Add tan
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Jul 16, 2024
1 parent ca681a4 commit 1bdf1f1
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 14 deletions.
15 changes: 9 additions & 6 deletions devito/ir/xdsl_iet/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass, field
from sympy import (Add, And, Expr, Float, GreaterThan, Indexed, Integer, LessThan,
Number, Pow, StrictGreaterThan, StrictLessThan, Symbol, floor,
Mul, sin, cos)
Mul, sin, cos, tan)
from sympy.core.relational import Relational
from sympy.logic.boolalg import BooleanFunction
from devito.operations.interpolators import Injection
Expand Down Expand Up @@ -288,17 +288,22 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr,
for arg in node.args)
return reduce(lambda x, y : arith.AndI(x, y).result, SSAargs)

# Trigonometric functions
elif isinstance(node, sin):
assert len(node.args) == 1, "Expected single argument for sin."
return math.SinOp(self._visit_math_nodes(dim, node.args[0],
output_indexed)).result

elif isinstance(node, cos):
assert len(node.args) == 1, "Expected single argument for cos."

assert len(node.args) == 1, "Expected single argument for cos."
return math.CosOp(self._visit_math_nodes(dim, node.args[0],
output_indexed)).result

elif isinstance(node, tan):
assert len(node.args) == 1, "Expected single argument for TanOp."

return math.TanOp(self._visit_math_nodes(dim, node.args[0],
output_indexed)).result

elif isinstance(node, Relational):
if isinstance(node, GreaterThan):
Expand All @@ -311,9 +316,7 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr,
mnemonic = "slt"
else:
raise NotImplementedError(f"Unimplemented comparison {type(node)}")

# import pdb;
# pdb.set_trace()

SSAargs = (self._visit_math_nodes(dim, arg, output_indexed) for arg in node.args)
# Operands must have the same type
# TODO: look at here if index stuff does not make sense
Expand Down
43 changes: 35 additions & 8 deletions tests/test_xdsl_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
import pytest

from devito import (Grid, TensorTimeFunction, VectorTimeFunction, div, grad, diag, solve,
Operator, Eq, Constant, norm, SpaceDimension, switchconfig, sin, cos)
from devito import (Grid, TensorTimeFunction, VectorTimeFunction, div, grad,
diag, solve, Operator, Eq, Constant, norm, SpaceDimension,
switchconfig, sin, cos, tan)
from devito.types import Array, Function, TimeFunction
from devito.tools import as_tuple

Expand All @@ -14,6 +15,7 @@
from xdsl.dialects.stencil import FieldType, ApplyOp, LoadOp, StoreOp
from xdsl.dialects.llvm import LLVMPointerType
from xdsl.dialects.memref import Load
from xdsl.dialects.experimental import math


def test_xdsl_I():
Expand Down Expand Up @@ -980,10 +982,13 @@ def test_sine(self, deg, exp):
u = Function(name="u", grid=grid)
u.data[:, :] = 0

eq0 = Eq(u, sin(deg))
deg0 = Constant(name='deg', value=deg)
eq0 = Eq(u, sin(deg0))

op = Operator([eq0], opt='xdsl')
op.apply()
opx = Operator([eq0], opt='xdsl')
opx.apply()

assert len([op for op in opx._module.walk() if isinstance(op, math.SinOp)]) == 1
assert np.isclose(norm(u), exp, rtol=1e-4)

@pytest.mark.parametrize('deg, exp', ([90.0, 1.7922944], [30.0, 0.6170056],
Expand All @@ -994,10 +999,32 @@ def test_cosine(self, deg, exp):
u = Function(name="u", grid=grid)
u.data[:, :] = 0

eq0 = Eq(u, cos(deg))
deg0 = Constant(name='deg', value=deg)
eq0 = Eq(u, cos(deg0))

opx = Operator([eq0], opt='xdsl')
opx.apply()

assert len([op for op in opx._module.walk() if isinstance(op, math.CosOp)]) == 1

assert np.isclose(norm(u), exp, rtol=1e-4)

@pytest.mark.parametrize('deg, exp', ([2.0, 8.74016], [30.0, 25.621325],
[45.0, 6.4791]))
def test_tan(self, deg, exp):
grid = Grid(shape=(4, 4))

u = Function(name="u", grid=grid)
u.data[:, :] = 0

deg0 = Constant(name='deg', value=deg)
eq0 = Eq(u, tan(deg0))

opx = Operator([eq0], opt='xdsl')
opx.apply()

assert len([op for op in opx._module.walk() if isinstance(op, math.TanOp)]) == 1

op = Operator([eq0], opt='xdsl')
op.apply()
assert np.isclose(norm(u), exp, rtol=1e-4)


Expand Down

0 comments on commit 1bdf1f1

Please sign in to comment.