diff --git a/tests/test_xdsl_op_correctness.py b/tests/test_xdsl_op_correctness.py index e0356d5fde..f0cb28d794 100644 --- a/tests/test_xdsl_op_correctness.py +++ b/tests/test_xdsl_op_correctness.py @@ -231,3 +231,94 @@ def test_dt2_sources_script(shape, nt): norm1 = norm(p) assert np.isclose(norm0, norm1, atol=1e-5, rtol=0) + + +@pytest.mark.xfail(reason="time_m is required in ambiguous setting") +@pytest.mark.parametrize('shape, rtol', [(5, 1e-5), (100, 1e-6)]) +def test_precision_I(shape, rtol): + # Define an upper bound of precision + + nt = 1 + grid = Grid(shape=(shape,)) + u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2) + + init0 = 1. + init1 = 2. + init2 = 3. + + dt = 0.1 + u.data[0, :] = init0 + u.data[1, :] = init1 + u.data[2, :] = init2 + + pde = u.dt2 - u.laplace + stencil = Eq(u.forward, solve(pde, u.forward)) + + opd = Operator([stencil]) + opd.apply(time_m=0, time_M=nt, dt=dt) + dnorm = norm(u) + dnorm0 = np.linalg.norm(u.data[0]) + dnorm1 = np.linalg.norm(u.data[1]) + dnorm2 = np.linalg.norm(u.data[2]) + + u.data[0, :] = init0 + u.data[1, :] = init1 + u.data[2, :] = init2 + + opx = Operator([stencil], opt='xdsl') + opx.apply(time_M=nt, dt=dt) + xnorm = norm(u) + xnorm0 = np.linalg.norm(u.data[0]) + xnorm1 = np.linalg.norm(u.data[1]) + xnorm2 = np.linalg.norm(u.data[2]) + + assert np.isclose(dnorm0, xnorm0, atol=0, rtol=rtol) + assert np.isclose(dnorm1, xnorm1, atol=0, rtol=rtol) + assert np.isclose(dnorm2, xnorm2, atol=0, rtol=rtol) + + assert np.isclose(dnorm, xnorm, atol=0, rtol=rtol) + + +@pytest.mark.parametrize('shape, rtol', [(5, 1e-5), (100, 1e-6)]) +def test_precision_II(shape, rtol): + # Define an upper bound of precision + + nt = 1 + grid = Grid(shape=(shape,)) + u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2) + + init0 = 1. + init1 = 2. + init2 = 3. + + dt = 0.1 + u.data[0, :] = init0 + u.data[1, :] = init1 + u.data[2, :] = init2 + + pde = u.dt2 - u.laplace + stencil = Eq(u.forward, solve(pde, u.forward)) + + opd = Operator([stencil]) + opd.apply(time_m=0, time_M=nt, dt=dt) + dnorm = norm(u) + dnorm0 = np.linalg.norm(u.data[0]) + dnorm1 = np.linalg.norm(u.data[1]) + dnorm2 = np.linalg.norm(u.data[2]) + + u.data[0, :] = init0 + u.data[1, :] = init1 + u.data[2, :] = init2 + + opx = Operator([stencil], opt='xdsl') + opx.apply(time_m=0, time_M=nt, dt=dt) + xnorm = norm(u) + xnorm0 = np.linalg.norm(u.data[0]) + xnorm1 = np.linalg.norm(u.data[1]) + xnorm2 = np.linalg.norm(u.data[2]) + + assert np.isclose(dnorm0, xnorm0, atol=0, rtol=rtol) + assert np.isclose(dnorm1, xnorm1, atol=0, rtol=rtol) + assert np.isclose(dnorm2, xnorm2, atol=0, rtol=rtol) + + assert np.isclose(dnorm, xnorm, atol=0, rtol=rtol)