Skip to content

Commit

Permalink
Merge pull request #22 from xdslproject/emilien/try-fix-wave
Browse files Browse the repository at this point in the history
Reverse stencil.apply inputs and try to name accordingly.
  • Loading branch information
georgebisbas authored Aug 2, 2023
2 parents d9c4239 + afca66d commit f2604c6
Show file tree
Hide file tree
Showing 7 changed files with 356 additions and 8 deletions.
2 changes: 1 addition & 1 deletion devito/ir/ietxdsl/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,8 @@ def match_and_rewrite(self, op: iet_ssa.Stencil, rewriter: PatternRewriter, /):

for field in op.input_indices:
rewriter.insert_op_before_matched_op(load_op := stencil.LoadOp.get(field))
input_temps.append(load_op.res)
load_op.res.name_hint = field.name_hint + "_temp"
input_temps.insert(0, load_op.res)

rewriter.replace_matched_op(
[
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/ietxdsl/iet_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def get(
stencil.TempType(len(shape), typ)
] * (time_buffers - 1))

for block_arg, idx_arg in zip(block.args, time_indices):
for block_arg, idx_arg in zip(block.args, reversed(inputs)):
name = SSAValue.get(idx_arg).name_hint
if name is None:
continue
Expand Down
3 changes: 3 additions & 0 deletions devito/ir/ietxdsl/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def match_and_rewrite(self, op: iet_ssa.For, rewriter: PatternRewriter, /):
]
rewriter.insert_op_before_matched_op(subindice_vals)

subindice_vals = list(reversed(subindice_vals))
subindice_vals.append(subindice_vals.pop(0))

rewriter.replace_matched_op([
cst1 := arith.Constant.from_int_and_width(1, builtin.IndexType()),
new_ub := arith.Addi(op.ub, cst1),
Expand Down
2 changes: 1 addition & 1 deletion devito/operator/xdsl_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@

XDSL_CPU_PIPELINE = "stencil-shape-inference,convert-stencil-to-ll-mlir,printf-to-llvm"
XDSL_GPU_PIPELINE = "stencil-shape-inference,convert-stencil-to-ll-mlir{target=gpu},printf-to-llvm"
XDSL_MPI_PIPELINE = lambda decomp: f'"dmp-decompose-2d{decomp},convert-stencil-to-ll-mlir,dmp-to-mpi{{mpi_init=false}},lower-mpi,printf-to-llvm"'
XDSL_MPI_PIPELINE = lambda decomp: f'"dmp-decompose-2d{decomp},canonicalize-dmp,convert-stencil-to-ll-mlir,dmp-to-mpi{{mpi_init=false}},lower-mpi,printf-to-llvm"'


class XDSLOperator(Operator):
Expand Down
16 changes: 11 additions & 5 deletions fast/nd_nwave_devito_nodamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def plot_3dfunc(u):
t0 = 0. # Simulation starts a t=0
tn = nt # Simulation last 1 second (1000 ms)
dt = model.critical_dt # Time step from model grid spacing
print("dt is:", dt)

time_range = TimeAxis(start=t0, stop=tn, step=dt)

Expand Down Expand Up @@ -115,20 +116,25 @@ def plot_3dfunc(u):
initdata = u.data[:]

# Run more with no sources now (Not supported in xdsl)
xdslop = XDSLOperator([stencil], name='XDSLOperator')
xdslop.apply(time=time_range.num-1, dt=model.critical_dt)
op = Operator([stencil], name='DevitoOperator', opt='noop')
op.apply(time=time_range.num-1, dt=model.critical_dt)

if len(shape) == 3:
if args.plot:
plot_3dfunc(u)

print(norm(u))

devito_output = u.copy()
print("Devito norm:", norm(u))
print(f"devito output norm: {norm(devito_output)}")

# Reset initial data
u.data[:] = initdata

# Run more with no sources now (Not supported in xdsl)
xdslop = XDSLOperator([stencil])
xdslop = XDSLOperator([stencil], name='xDSLOperator')
xdslop.apply(time=time_range.num-1, dt=model.critical_dt)

print(norm(u))
xdsl_output = u.copy()
print("XDSL norm:", norm(u))
print(f"xdsl output norm: {norm(xdsl_output)}")
170 changes: 170 additions & 0 deletions fast/wave2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Based on the implementation of the Devito acoustic example implementation
# Not using Devito's source injection abstraction
import sys
import numpy as np
from devito import TimeFunction, Eq, Operator, solve, norm, XDSLOperator
from examples.seismic import RickerSource
from examples.seismic import Model, TimeAxis

from devito.tools import as_tuple

import argparse
np.set_printoptions(threshold=np.inf)


parser = argparse.ArgumentParser(description='Process arguments.')

parser.add_argument("-d", "--shape", default=(16, 16), type=int, nargs="+",
help="Number of grid points along each axis")
parser.add_argument("-so", "--space_order", default=4,
type=int, help="Space order of the simulation")
parser.add_argument("-to", "--time_order", default=2,
type=int, help="Time order of the simulation")
parser.add_argument("-nt", "--nt", default=20,
type=int, help="Simulation time in millisecond")
parser.add_argument("-bls", "--blevels", default=1, type=int, nargs="+",
help="Block levels")
parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot2D")
args = parser.parse_args()


def plot_2dfunc(u):
# Plot a 3D structured grid using pyvista

import matplotlib.pyplot as plt
import pyvista as pv
cmap = plt.colormaps["viridis"]
values = u.data[0, :, :, :]
vistagrid = pv.UniformGrid()
vistagrid.dimensions = np.array(values.shape) + 1
vistagrid.spacing = (1, 1, 1)
vistagrid.origin = (0, 0, 0) # The bottom left corner of the data set
vistagrid.cell_data["values"] = values.flatten(order="F")
vistaslices = vistagrid.slice_orthogonal()
# vistagrid.plot(show_edges=True)
vistaslices.plot(cmap=cmap)


# Define a physical size
# nx, ny, nz = args.shape
nt = args.nt

shape = (args.shape) # Number of grid point (nx, ny, nz)
spacing = as_tuple(10.0 for _ in range(len(shape))) # Grid spacing in m. The domain size is now 1km by 1km
origin = as_tuple(0.0 for _ in range(len(shape))) # What is the location of the top left corner.
# This is necessary to define
# the absolute location of the source and receivers

# Define a velocity profile. The velocity is in km/s
v = np.empty(shape, dtype=np.float32)
v[:, :] = 1

# With the velocity and model size defined, we can create the seismic model that
# encapsulates this properties. We also define the size of the absorbing layer as
# 10 grid points
so = args.space_order
to = args.time_order

model = Model(vp=v, origin=origin, shape=shape, spacing=spacing,
space_order=so, nbl=0)

# plot_velocity(model)

t0 = 0. # Simulation starts a t=0
tn = nt # Simulation last 1 second (1000 ms)
dt = model.critical_dt # Time step from model grid spacing
print("dt is:", dt)

time_range = TimeAxis(start=t0, stop=tn, step=dt)

# The source is positioned at a $20m$ depth and at the middle of the
# $x$ axis ($x_{src}=500m$),
# with a peak wavelet frequency of $10Hz$.
f0 = 0.010 # Source peak frequency is 10Hz (0.010 kHz)
src = RickerSource(name='src', grid=model.grid, f0=f0,
npoint=1, time_range=time_range)

# First, position source centrally in all dimensions, then set depth
src.coordinates.data[0, :] = np.array(model.domain_size) * .5

# We can plot the time signature to see the wavelet
# src.show()

# Define the wavefield with the size of the model and the time dimension
u = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so)

u2 = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so)

# We can now write the PDE
# pde = model.m * u.dt2 - u.laplace + model.damp * u.dt
# import pdb;pdb.set_trace()
pde = u.dt2 - u.laplace

# The PDE representation is as on paper
pde

stencil = Eq(u.forward, solve(pde, u.forward))
stencil

# Finally we define the source injection and receiver read function to generate
# the corresponding code
print(time_range)

print("Init norm:", norm(u))
src_term = src.inject(field=u.forward, expr=src * dt**2 / model.m)
op0 = Operator([stencil] + src_term, subs=model.spacing_map, name='SourceDevitoOperator')
# Run with source and plot
op0.apply(time=time_range.num-1, dt=model.critical_dt)

if len(shape) == 2:
if args.plot:
plot_2dfunc(u)

print("Init Devito linalg norm 0 :", np.linalg.norm(u.data[0]))
print("Init Devito linalg norm 1 :", np.linalg.norm(u.data[1]))
print("Init Devito linalg norm 2 :", np.linalg.norm(u.data[2]))

print("Norm of initial data:", norm(u))
# import pdb;pdb.set_trace()
u2.data[:] = u.data[:]

# Run more with no sources now (Not supported in xdsl)
op1 = Operator([stencil], name='DevitoOperator')
op1.apply(time=time_range.num-1, dt=model.critical_dt)

if len(shape) == 2:
if args.plot:
plot_3dfunc(u)

#devito_output = u.data[:]
print("After Operator 1: Devito norm:", norm(u))
print("Devito linalg norm 0:", np.linalg.norm(u.data[0]))
print("Devito linalg norm 1:", np.linalg.norm(u.data[1]))
print("Devito linalg norm 2:", np.linalg.norm(u.data[2]))

# import pdb;pdb.set_trace()


# Reset initial data
u.data[:] = u2.data[:]
#v[:, ..., :] = 1


print("Reinitialise data: Devito norm:", norm(u))
print("Init XDSL linalg norm:", np.linalg.norm(u.data[0]))
print("Init XDSL linalg norm:", np.linalg.norm(u.data[1]))
print("Init XDSL linalg norm:", np.linalg.norm(u.data[2]))

# Run more with no sources now (Not supported in xdsl)
xdslop = XDSLOperator([stencil], name='xDSLOperator')
xdslop.apply(time=time_range.num-1, dt=model.critical_dt)

xdsl_output = u.copy()
print("XDSL norm:", norm(u))
print(f"xdsl output norm: {norm(xdsl_output)}")

print("XDSL output linalg norm:", np.linalg.norm(u.data[0]))
print("XDSL output linalg norm:", np.linalg.norm(u.data[1]))
print("XDSL output linalg norm:", np.linalg.norm(u.data[2]))


Loading

0 comments on commit f2604c6

Please sign in to comment.