Skip to content

Commit

Permalink
Merge pull request #86 from xdslproject/elasting_testing_II
Browse files Browse the repository at this point in the history
tests: Add elastic parametric
  • Loading branch information
georgebisbas authored May 16, 2024
2 parents 419fbca + f58567a commit b842563
Showing 1 changed file with 132 additions and 2 deletions.
134 changes: 132 additions & 2 deletions tests/test_xdsl_base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import numpy as np
import pytest

from devito import Grid, TimeFunction, Eq, Operator, solve, norm, Function
from devito.types import Symbol, Array
from devito import (Grid, TensorTimeFunction, VectorTimeFunction, div, grad, diag, solve,
Operator, Eq, Constant, norm, SpaceDimension)
from devito.types import Symbol, Array, Function, TimeFunction

from xdsl.dialects.scf import For, Yield
from xdsl.dialects.arith import Addi
from xdsl.dialects.func import Call, Return
from xdsl.dialects.stencil import FieldType, ApplyOp, LoadOp, StoreOp
from xdsl.dialects.llvm import LLVMPointerType

from examples.seismic.source import RickerSource, TimeAxis


def test_xdsl_I():
# Define a simple Devito Operator
Expand Down Expand Up @@ -789,3 +792,130 @@ def test_forward_assignment(self):
op.apply(time_M=1)

assert np.isclose(norm(u), 5.6584, rtol=0.001)

@pytest.mark.xfail(reason="stencil.return operation does not verify for i64")
def test_function(self):
grid = Grid(shape=(5, 5))
x, y = grid.dimensions

f = Function(name="f", grid=grid)

eqns = [Eq(f, 2)]

op = Operator(eqns, opt='xdsl')
op.apply()

assert np.all(f.data == 4)


class TestElastic():

@pytest.mark.parametrize('shape', [(101, 101), (201, 201), (301, 301)])
@pytest.mark.parametrize('so', [2, 4, 8])
@pytest.mark.parametrize('nt', [10, 20, 50, 100])
def test_elastic_2D(self, shape, so, nt):

# Initial grid: km x km, with spacing
extent = (1500., 1500.)
shape = shape
x = SpaceDimension(name='x', spacing=Constant(name='h_x', value=extent[0]/(shape[0]-1))) # noqa
z = SpaceDimension(name='z', spacing=Constant(name='h_z', value=extent[1]/(shape[1]-1))) # noqa
grid = Grid(extent=extent, shape=shape, dimensions=(x, z))

# To be checked again in the future
# class DGaussSource(WaveletSource):

# def wavelet(self, f0, t):
# a = 0.004
# return -2.*a*(t - 1/f0) * np.exp(-a * (t - 1/f0)**2)

# Timestep size from Eq. 7 with V_p=6000. and dx=100
t0, tn = 0., nt
dt = (10. / np.sqrt(2.)) / 6.
time_range = TimeAxis(start=t0, stop=tn, step=dt)

src = RickerSource(name='src', grid=grid, f0=0.01, time_range=time_range)
src.coordinates.data[:] = [250., 250.]

# Now we create the velocity and pressure fields
v = VectorTimeFunction(name='v', grid=grid, space_order=so, time_order=1)
tau = TensorTimeFunction(name='t', grid=grid, space_order=so, time_order=1)

# We need some initial conditions
V_p = 2.0
V_s = 1.0
density = 1.8

# The source injection term
src_xx = src.inject(field=tau.forward[0, 0], expr=src)
src_zz = src.inject(field=tau.forward[1, 1], expr=src)

# Thorbecke's parameter notation
cp2 = V_p*V_p
cs2 = V_s*V_s
ro = 1/density

mu = cs2*density
l = (cp2*density - 2*mu)

# First order elastic wave equation
pde_v = v.dt - ro * div(tau)
pde_tau = (tau.dt - l * diag(div(v.forward)) - mu * (grad(v.forward) +
grad(v.forward).transpose(inner=False)))

# Time update
u_v = Eq(v.forward, solve(pde_v, v.forward))
u_t = Eq(tau.forward, solve(pde_tau, tau.forward))

# Inject sources. We use it to preinject data
# Up to here, let's only use Devito
op = Operator([u_v] + [u_t] + src_xx + src_zz)
op(dt=dt)

op = Operator([u_v] + [u_t], opt='xdsl')
op(dt=dt, time_M=nt)

xdsl_norm_v0 = norm(v[0])
xdsl_norm_v1 = norm(v[1])
xdsl_norm_tau0 = norm(tau[0])
xdsl_norm_tau1 = norm(tau[1])
xdsl_norm_tau2 = norm(tau[2])
xdsl_norm_tau3 = norm(tau[3])

# Reinitialize the fields to zero

v[0].data[:] = 0
v[1].data[:] = 0

tau[0].data[:] = 0
tau[1].data[:] = 0
tau[2].data[:] = 0
tau[3].data[:] = 0

# Inject sources. We use it to preinject data
op = Operator([u_v] + [u_t] + src_xx + src_zz)
op(dt=dt)

op = Operator([u_v] + [u_t], opt='advanced')
op(dt=dt, time_M=nt)

dv_norm_v0 = norm(v[0])
dv_norm_v1 = norm(v[1])
dv_norm_tau0 = norm(tau[0])
dv_norm_tau1 = norm(tau[1])
dv_norm_tau2 = norm(tau[2])
dv_norm_tau3 = norm(tau[3])

assert np.isclose(xdsl_norm_v0, dv_norm_v0, rtol=1e-04)
assert np.isclose(xdsl_norm_v1, dv_norm_v1, rtol=1e-04)
assert np.isclose(xdsl_norm_tau0, dv_norm_tau0, rtol=1e-04)
assert np.isclose(xdsl_norm_tau1, dv_norm_tau1, rtol=1e-04)
assert np.isclose(xdsl_norm_tau2, dv_norm_tau2, rtol=1e-04)
assert np.isclose(xdsl_norm_tau3, dv_norm_tau3, rtol=1e-04)

assert not np.isclose(xdsl_norm_v0, 0.0, rtol=1e-04)
assert not np.isclose(xdsl_norm_v1, 0.0, rtol=1e-04)
assert not np.isclose(xdsl_norm_tau0, 0.0, rtol=1e-04)
assert not np.isclose(xdsl_norm_tau1, 0.0, rtol=1e-04)
assert not np.isclose(xdsl_norm_tau2, 0.0, rtol=1e-04)
assert not np.isclose(xdsl_norm_tau3, 0.0, rtol=1e-04)

0 comments on commit b842563

Please sign in to comment.