Skip to content

Commit

Permalink
tests: Add trigonometric
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Jul 16, 2024
1 parent b4fee6e commit 9d6024c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
1 change: 0 additions & 1 deletion devito/ir/xdsl_iet/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,6 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr,

elif isinstance(node, sin):
assert len(node.args) == 1, "Expected single argument for sin."
#import pdb; pdb.set_trace()
return math.SinOp(self._visit_math_nodes(dim, node.args[0],
output_indexed)).result

Expand Down
34 changes: 33 additions & 1 deletion tests/test_xdsl_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
import pytest
from sympy import S

from devito import (Grid, TensorTimeFunction, VectorTimeFunction, div, grad, diag, solve,
Operator, Eq, Constant, norm, SpaceDimension, switchconfig)
Operator, Eq, Constant, norm, SpaceDimension, switchconfig, sin, cos)
from devito.types import Array, Function, TimeFunction
from devito.tools import as_tuple

Expand Down Expand Up @@ -970,6 +971,37 @@ def test_function_IV():
assert np.isclose(norm(u), devito_norm_u)


class TestTrigonometric(object):

@pytest.mark.parametrize('deg, exp', ([90.0, 3.5759869], [30.0, 3.9521265],
[45.0, 3.403614]))
def test_sine(self, deg, exp):
grid = Grid(shape=(4, 4))

u = Function(name="u", grid=grid)
u.data[:, :] = 0

eq0 = Eq(u, sin(deg))

op = Operator([eq0], opt='xdsl')
op.apply()
assert np.isclose(norm(u), exp, rtol=1e-4)

@pytest.mark.parametrize('deg, exp', ([90.0, 1.7922944], [30.0, 0.6170056],
[45.0, 2.101288]))
def test_cosine(self, deg, exp):
grid = Grid(shape=(4, 4))

u = Function(name="u", grid=grid)
u.data[:, :] = 0

eq0 = Eq(u, cos(deg))

op = Operator([eq0], opt='xdsl')
op.apply()
assert np.isclose(norm(u), exp, rtol=1e-4)


class TestOperatorUnsupported(object):

@pytest.mark.xfail(reason="stencil.return operation does not verify for i64")
Expand Down

0 comments on commit 9d6024c

Please sign in to comment.