Skip to content

Commit

Permalink
tti: Add tests in MFE
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Aug 8, 2024
1 parent f1e48d9 commit a9cf049
Showing 1 changed file with 90 additions and 0 deletions.
90 changes: 90 additions & 0 deletions tests/test_xdsl_op_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,93 @@ def test_dt2_sources_script(shape, nt):
norm1 = norm(p)

assert np.isclose(norm0, norm1, atol=1e-5, rtol=0)

@pytest.mark.xfail(reason="time_m is required in ambiguous setting")
@pytest.mark.parametrize('shape, rtol', [(5, 1e-5), (100, 1e-6)])
def test_precision_I(shape, rtol):
# Define an upper bound of precision

nt = 1
grid = Grid(shape=(shape,))
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)

init0 = 1.
init1 = 2.
init2 = 3.

dt = 0.1
u.data[0, :] = init0
u.data[1, :] = init1
u.data[2, :] = init2

pde = u.dt2 - u.laplace
stencil = Eq(u.forward, solve(pde, u.forward))

opd = Operator([stencil])
opd.apply(time_m=0, time_M=nt, dt=dt)
dnorm = norm(u)
dnorm0 = np.linalg.norm(u.data[0])
dnorm1 = np.linalg.norm(u.data[1])
dnorm2 = np.linalg.norm(u.data[2])

u.data[0, :] = init0
u.data[1, :] = init1
u.data[2, :] = init2

opx = Operator([stencil], opt='xdsl')
opx.apply(time_M=nt, dt=dt)
xnorm = norm(u)
xnorm0 = np.linalg.norm(u.data[0])
xnorm1 = np.linalg.norm(u.data[1])
xnorm2 = np.linalg.norm(u.data[2])

assert np.isclose(dnorm0, xnorm0, atol=0, rtol=rtol)
assert np.isclose(dnorm1, xnorm1, atol=0, rtol=rtol)
assert np.isclose(dnorm2, xnorm2, atol=0, rtol=rtol)

assert np.isclose(dnorm, xnorm, atol=0, rtol=rtol)


@pytest.mark.parametrize('shape, rtol', [(5, 1e-5), (100, 1e-6)])
def test_precision_II(shape, rtol):
# Define an upper bound of precision

nt = 1
grid = Grid(shape=(shape,))
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)

init0 = 1.
init1 = 2.
init2 = 3.

dt = 0.1
u.data[0, :] = init0
u.data[1, :] = init1
u.data[2, :] = init2

pde = u.dt2 - u.laplace
stencil = Eq(u.forward, solve(pde, u.forward))

opd = Operator([stencil])
opd.apply(time_m=0, time_M=nt, dt=dt)
dnorm = norm(u)
dnorm0 = np.linalg.norm(u.data[0])
dnorm1 = np.linalg.norm(u.data[1])
dnorm2 = np.linalg.norm(u.data[2])

u.data[0, :] = init0
u.data[1, :] = init1
u.data[2, :] = init2

opx = Operator([stencil], opt='xdsl')
opx.apply(time_m=0, time_M=nt, dt=dt)
xnorm = norm(u)
xnorm0 = np.linalg.norm(u.data[0])
xnorm1 = np.linalg.norm(u.data[1])
xnorm2 = np.linalg.norm(u.data[2])

assert np.isclose(dnorm0, xnorm0, atol=0, rtol=rtol)
assert np.isclose(dnorm1, xnorm1, atol=0, rtol=rtol)
assert np.isclose(dnorm2, xnorm2, atol=0, rtol=rtol)

assert np.isclose(dnorm, xnorm, atol=0, rtol=rtol)

0 comments on commit a9cf049

Please sign in to comment.