diff --git a/devito/ir/ietxdsl/cluster_to_ssa.py b/devito/ir/ietxdsl/cluster_to_ssa.py index 2ffb0d1887..1cb9c347d3 100644 --- a/devito/ir/ietxdsl/cluster_to_ssa.py +++ b/devito/ir/ietxdsl/cluster_to_ssa.py @@ -272,7 +272,7 @@ def _ensure_same_type(self, *vals: SSAValue): new_vals.append(val) continue # insert an integer to float cast op - conv = arith.SIToFPOp.get(val, builtin.f32) + conv = arith.SIToFPOp(val, builtin.f32) self.block.add_op(conv) new_vals.append(conv.result) return new_vals @@ -340,7 +340,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): self.seen_ops.add(op) rewriter.insert_op_at_start([ - t0 := func.Call.get('timer_start', [], [builtin.f64]) + t0 := func.Call('timer_start', [], [builtin.f64]) ], op.body.block) ret = op.get_return_op() @@ -348,8 +348,8 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): rewriter.insert_op_before([ timers := iet_ssa.LoadSymbolic.get('timers', llvm.LLVMPointerType.typed(builtin.f64)), - t1 := func.Call.get('timer_end', [t0], [builtin.f64]), - llvm.StoreOp.get(t1, timers), + t1 := func.Call('timer_end', [t0], [builtin.f64]), + llvm.StoreOp(t1, timers), ], ret) rewriter.insert_op_after_matched_op([ diff --git a/devito/ir/ietxdsl/ietxdsl_functions.py b/devito/ir/ietxdsl/ietxdsl_functions.py index 20a4d105d8..b4a9498fb1 100644 --- a/devito/ir/ietxdsl/ietxdsl_functions.py +++ b/devito/ir/ietxdsl/ietxdsl_functions.py @@ -20,12 +20,11 @@ # XDSL specific imports from xdsl.irdl import AnyOf, Operation, SSAValue from xdsl.dialects.builtin import (ContainerOf, Float16Type, Float32Type, - Float64Type, Builtin, i32, f32) + Float64Type, i32, f32) -from xdsl.dialects.arith import Muli, Addi from devito.ir.ietxdsl import iet_ssa -from xdsl.dialects import memref, arith, builtin, llvm +from xdsl.dialects import memref, arith, builtin from xdsl.dialects.experimental import math import devito.types @@ -74,7 +73,7 @@ def print_calls(cgen, calldefs): print("Call not translated in calldefs") return - call = Call.get(call_name, C_names, C_typenames, C_typeqs, prefix, retval) + call = Call(call_name, C_names, C_typenames, C_typeqs, prefix, retval) cgen.printCall(call, True) @@ -180,10 +179,10 @@ def add_to_block(expr, arg_by_expr: dict[Any, Operation], result): # reconcile differences if isinstance(rhs.typ, builtin.IntegerType): - rhs = arith.SIToFPOp.get(rhs, lhs.typ) + rhs = arith.SIToFPOp(rhs, lhs.typ) result.append(rhs) else: - lhs = arith.SIToFPOp.get(lhs, rhs.typ) + lhs = arith.SIToFPOp(lhs, rhs.typ) result.append(lhs) @@ -426,7 +425,7 @@ def myVisit(node, block: Block, ssa_vals={}): print(f"Call {node.name} instance translated as comment") return - call = Call.get(call_name, C_names, C_typenames, C_typeqs, prefix, retval) + call = Call(call_name, C_names, C_typenames, C_typeqs, prefix, retval) block.add_ops([call]) print(f"Call {node.name} translated") diff --git a/devito/ir/ietxdsl/lowering.py b/devito/ir/ietxdsl/lowering.py index 515fe2f7a1..8ccf6b6d25 100644 --- a/devito/ir/ietxdsl/lowering.py +++ b/devito/ir/ietxdsl/lowering.py @@ -371,7 +371,7 @@ def match_and_rewrite(self, op: memref.Store, rewriter: PatternRewriter, ssa_indices=[idx], result_type=llvm.LLVMPointerType.typed(op.memref.memref.element_type) ), - store := llvm.StoreOp.get(op.value, gep), + store := llvm.StoreOp(op.value, gep), ], [], ) diff --git a/devito/mpi/distributed.py b/devito/mpi/distributed.py index 5b0c890ebe..50577659cb 100644 --- a/devito/mpi/distributed.py +++ b/devito/mpi/distributed.py @@ -196,7 +196,9 @@ def __init__(self, shape, dimensions, input_comm=None, topology=None): # mpi4py takes care of that when the object gets out of scope self._input_comm = (input_comm or MPI.COMM_WORLD).Clone() - topology = ('*', '*', 1) + if len(shape) == 3: + topology = ('*', '*', 1) + if topology is None: # `MPI.Compute_dims` sets the dimension sizes to be as close to each other # as possible, using an appropriate divisibility algorithm. Thus, in 3D: diff --git a/devito/operator/xdsl_operator.py b/devito/operator/xdsl_operator.py index 2ffe000863..9e59f03bec 100644 --- a/devito/operator/xdsl_operator.py +++ b/devito/operator/xdsl_operator.py @@ -53,9 +53,9 @@ # gpu-launch-sink-index-computations seemed to have no impact MLIR_GPU_PIPELINE = '"builtin.module(test-math-algebraic-simplification,scf-parallel-loop-tiling{parallel-loop-tile-sizes=128,1,1},func.func(gpu-map-parallel-loops),convert-parallel-loops-to-gpu,fold-memref-alias-ops,expand-strided-metadata,lower-affine,gpu-kernel-outlining,canonicalize,cse,convert-arith-to-llvm{index-bitwidth=64},finalize-memref-to-llvm{index-bitwidth=64},convert-scf-to-cf,convert-cf-to-llvm{index-bitwidth=64},canonicalize,cse,gpu.module(convert-gpu-to-nvvm,reconcile-unrealized-casts,canonicalize,gpu-to-cubin),gpu-to-llvm,canonicalize,cse)"' -XDSL_CPU_PIPELINE = "stencil-shape-inference,convert-stencil-to-ll-mlir,printf-to-llvm" -XDSL_GPU_PIPELINE = "stencil-shape-inference,convert-stencil-to-ll-mlir{target=gpu},printf-to-llvm" -XDSL_MPI_PIPELINE = lambda decomp: f'"dmp-decompose-2d{decomp},canonicalize-dmp,convert-stencil-to-ll-mlir,dmp-to-mpi{{mpi_init=false}},lower-mpi,printf-to-llvm"' +XDSL_CPU_PIPELINE = lambda nb_tiled_dims: f'"stencil-shape-inference,convert-stencil-to-ll-mlir{{tile-sizes={",".join(["64"]*nb_tiled_dims)}}},printf-to-llvm"' +XDSL_GPU_PIPELINE = '"stencil-shape-inference,convert-stencil-to-ll-mlir{target=gpu},printf-to-llvm"' +XDSL_MPI_PIPELINE = lambda decomp, nb_tiled_dims: f'"dmp-decompose-2d{decomp},canonicalize-dmp,convert-stencil-to-ll-mlir{{tile-sizes={",".join(["64"]*nb_tiled_dims)}}},dmp-to-mpi{{mpi_init=false}},lower-mpi,printf-to-llvm"' class XDSLOperator(Operator): @@ -85,7 +85,10 @@ def _make_interop_o(self): @property def mpi_shape(self) -> tuple: dist = self.functions[0].grid.distributor - return dist.topology, dist.myrank + # temporary fix: + # swap dim 0 and 1 in topology because dmp.grid is row major and not column major + + return (dist.topology[1], dist.topology[0], *dist.topology[2:]), dist.myrank def _jit_compile(self): """ @@ -114,7 +117,9 @@ def _jit_compile(self): Printer(stream=module_str).print(self._module) module_str = module_str.getvalue() - xdsl_pipeline = XDSL_CPU_PIPELINE + to_tile = len(list(filter(lambda s : str(s) in ["x", "y", "z"], self.dimensions)))-1 + + xdsl_pipeline = XDSL_CPU_PIPELINE(to_tile) mlir_pipeline = MLIR_CPU_PIPELINE if is_omp: @@ -126,7 +131,7 @@ def _jit_compile(self): # reduce the domain of the computation (as devito has already done that for us) slices = ','.join(str(x) for x in shape) decomp = f"{{strategy=2d-grid slices={slices} restrict_domain=false}}" - xdsl_pipeline = XDSL_MPI_PIPELINE(decomp) + xdsl_pipeline = XDSL_MPI_PIPELINE(decomp, to_tile) elif is_gpu: xdsl_pipeline = XDSL_GPU_PIPELINE mlir_pipeline = MLIR_GPU_PIPELINE diff --git a/fast/bench_utils.py b/fast/bench_utils.py new file mode 100644 index 0000000000..2efc9fa478 --- /dev/null +++ b/fast/bench_utils.py @@ -0,0 +1,25 @@ +import matplotlib.pyplot as plt +import numpy as np +from examples.seismic import plot_image + +__all__ = ['plot_2dfunc', 'plot_3dfunc'] + + +def plot_2dfunc(u): + # Plot a 2D image using devito's machinery + plot_image(u.data[0], cmap='seismic') + plot_image(u.data[1], cmap='seismic') + + +def plot_3dfunc(u): + # Plot a 3D structured grid using pyvista + import pyvista as pv + cmap = plt.colormaps["viridis"] + values = u.data[0, :, :, :] + vistagrid = pv.ImageData() + vistagrid.dimensions = np.array(values.shape) + 1 + vistagrid.spacing = (1, 1, 1) + vistagrid.origin = (0, 0, 0) # The bottom left corner of the data set + vistagrid.cell_data["values"] = values.flatten(order="F") + vistaslices = vistagrid.slice_orthogonal() + vistaslices.plot(cmap=cmap) diff --git a/fast/diffusion_2D_wBCs.py b/fast/diffusion_2D_wBCs.py index 0de0340254..917454487d 100644 --- a/fast/diffusion_2D_wBCs.py +++ b/fast/diffusion_2D_wBCs.py @@ -6,8 +6,8 @@ import numpy as np from devito import Grid, TimeFunction, Eq, solve, Operator, Constant, norm, XDSLOperator -from examples.seismic import plot_image from examples.cfd import init_hat +from fast.bench_utils import plot_2dfunc parser = argparse.ArgumentParser(description='Process arguments.') @@ -21,7 +21,9 @@ type=int, help="Simulation time in millisecond") parser.add_argument("-bls", "--blevels", default=2, type=int, nargs="+", help="Block levels") -parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot3D") +parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot2D") +parser.add_argument("-devito", "--devito", default=False, type=bool, help="Devito run") +parser.add_argument("-xdsl", "--xdsl", default=False, type=bool, help="xDSL run") args = parser.parse_args() # Some variable declarations @@ -41,9 +43,6 @@ grid = Grid(shape=(nx, ny), extent=(2., 2.)) u = TimeFunction(name='u', grid=grid, space_order=so) -# Reset our data field and ICs -init_hat(field=u.data[0], dx=dx, dy=dy, value=1.) - a = Constant(name='a') # Create an equation with second-order derivatives # eq = Eq(u.dt, a * u.laplace, subdomain=grid.interior) @@ -51,19 +50,21 @@ stencil = solve(eq, u.forward) eq_stencil = Eq(u.forward, stencil) -# Create boundary condition expressions -x, y = grid.dimensions -t = grid.stepping_dim +# Reset our data field and ICs +init_hat(field=u.data[0], dx=dx, dy=dy, value=1.) -initdata = u.data[:] -op = Operator([eq_stencil], name='DevitoOperator') -op.apply(time=nt, dt=dt, a=nu) +if args.devito: + op = Operator([eq_stencil], name='DevitoOperator') + op.apply(time=nt, dt=dt, a=nu) + print("Devito Field norm is:", norm(u)) -print("Devito Field norm is:", norm(u)) + if args.plot: + plot_2dfunc(u) -# Reset data and run XDSLOperator +# Reset data init_hat(field=u.data[0], dx=dx, dy=dy, value=1.) -xdslop = XDSLOperator([eq_stencil], name='XDSLOperator') -xdslop.apply(time=nt, dt=dt, a=nu) -print("XDSL Field norm is:", norm(u)) +if args.xdsl: + xdslop = XDSLOperator([eq_stencil], name='XDSLOperator') + xdslop.apply(time=nt, dt=dt, a=nu) + print("XDSL Field norm is:", norm(u)) diff --git a/fast/diffusion_3D_wBCs.py b/fast/diffusion_3D_wBCs.py index eebd13a7c5..9d0715f340 100644 --- a/fast/diffusion_3D_wBCs.py +++ b/fast/diffusion_3D_wBCs.py @@ -5,7 +5,9 @@ import argparse import numpy as np -from devito import Grid, TimeFunction, Eq, solve, Operator, Constant, norm, XDSLOperator +from devito import (Grid, TimeFunction, Eq, solve, Operator, Constant, + norm, XDSLOperator) +from fast.bench_utils import plot_3dfunc parser = argparse.ArgumentParser(description='Process arguments.') @@ -20,27 +22,10 @@ parser.add_argument("-bls", "--blevels", default=2, type=int, nargs="+", help="Block levels") parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot3D") +parser.add_argument("-devito", "--devito", default=False, type=bool, help="Devito run") +parser.add_argument("-xdsl", "--xdsl", default=False, type=bool, help="xDSL run") args = parser.parse_args() - -def plot_3dfunc(u): - # Plot a 3D structured grid using pyvista - - import matplotlib.pyplot as plt - import pyvista as pv - - cmap = plt.cm.get_cmap("viridis") - values = u.data[0, :, :, :] - vistagrid = pv.ImageData() - vistagrid.dimensions = np.array(values.shape) + 1 - vistagrid.spacing = (1, 1, 1) - vistagrid.origin = (0, 0, 0) # The bottom left corner of the data set - vistagrid.cell_data["values"] = values.flatten(order="F") - vistaslices = vistagrid.slice_orthogonal() - # vistagrid.plot(show_edges=True) - vistaslices.plot(cmap=cmap) - - # Some variable declarations nx, ny, nz = args.shape nt = args.nt @@ -59,39 +44,32 @@ def plot_3dfunc(u): grid = Grid(shape=(nx, ny, nz), extent=(2., 2., 2.)) u = TimeFunction(name='u', grid=grid, space_order=so) -# init_hat(field=u.data[0], dx=dx, dy=dy, value=2.) -u.data[:, :, :, :] = 0 -u.data[:, :, :, int(nz/2)] = 1 a = Constant(name='a') # Create an equation with second-order derivatives eq = Eq(u.dt, a * u.laplace) - stencil = solve(eq, u.forward) eq_stencil = Eq(u.forward, stencil) -# Create boundary condition expressions -x, y, z = grid.dimensions -t = grid.stepping_dim - -print(eq_stencil) # Create Operator -op = Operator([eq_stencil], name='DevitoOperator') -# Apply the operator for a number of timesteps -op.apply(time=nt, dt=dt, a=nu) -print("Devito Field norm is:", norm(u)) - -# Reset field -u.data[:, :, :, :] = 0 -u.data[:, :, :, int(nz/2)] = 1 -xdslop = XDSLOperator([eq_stencil], name='xDSLOperator') -# Apply the xdsl operator for a number of timesteps -xdslop.apply(time=nt, dt=dt, a=nu) - -if args.plot: - plot_3dfunc(u) - -print("XDSL Field norm is:", norm(u)) - -# import pdb;pdb.set_trace() +if args.devito: + u.data[:, :, :, :] = 0 + u.data[:, :, :, int(nz/2)] = 1 + op = Operator([eq_stencil], name='DevitoOperator') + # Apply the operator for a number of timesteps + op.apply(time=nt, dt=dt, a=nu) + print("Devito Field norm is:", norm(u)) + if args.plot: + plot_3dfunc(u) + +if args.xdsl: + # Reset field + u.data[:, :, :, :] = 0 + u.data[:, :, :, int(nz/2)] = 1 + xdslop = XDSLOperator([eq_stencil], name='xDSLOperator') + # Apply the xdsl operator for a number of timesteps + xdslop.apply(time=nt, dt=dt, a=nu) + print("XDSL Field norm is:", norm(u)) + if args.plot: + plot_3dfunc(u) diff --git a/fast/mfe_2D.py b/fast/mfe_2D.py new file mode 100644 index 0000000000..42e0ec213b --- /dev/null +++ b/fast/mfe_2D.py @@ -0,0 +1,73 @@ +# A 2D heat diffusion using Devito +# BC modelling included +# PyVista plotting included + +import argparse +import numpy as np + +from devito import Grid, TimeFunction, Eq, solve, Operator, Constant, norm, XDSLOperator +from examples.seismic import plot_image +from examples.cfd import init_hat + +parser = argparse.ArgumentParser(description='Process arguments.') + +parser.add_argument("-d", "--shape", default=(11, 11), type=int, nargs="+", + help="Number of grid points along each axis") +parser.add_argument("-so", "--space_order", default=2, + type=int, help="Space order of the simulation") +parser.add_argument("-to", "--time_order", default=1, + type=int, help="Time order of the simulation") +parser.add_argument("-nt", "--nt", default=40, + type=int, help="Simulation time in millisecond") +parser.add_argument("-bls", "--blevels", default=2, type=int, nargs="+", + help="Block levels") +parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot3D") +args = parser.parse_args() + +# Some variable declarations +nx, ny = args.shape +nt = args.nt +nu = .5 +dx = 1. / (nx - 1) +dy = 1. / (ny - 1) +sigma = .25 + +dt = sigma * dx * dy / nu +so = args.space_order +to = args.time_order + +print("dx %s, dy %s" % (dx, dy)) + +grid = Grid(shape=(nx, ny), extent=(2., 2.)) +u = TimeFunction(name='u', grid=grid, space_order=so) + +# Reset our data field and ICs +#init_hat(field=u.data[0], dx=dx, dy=dy, value=1.) +u.data[:, 2:3, 2:3] = 1 + +a = Constant(name='a') +# Create an equation with second-order derivatives +# eq = Eq(u.dt, a * u.laplace, subdomain=grid.interior) +eq = Eq(u.dt, a * u.laplace) +stencil = solve(eq, u.forward) +eq_stencil = Eq(u.forward, stencil) + +# Create boundary condition expressions +x, y = grid.dimensions +t = grid.stepping_dim + +initdata = u.data[:] +op = Operator([eq_stencil], name='DevitoOperator') +op.apply(time=nt, dt=dt, a=nu) +print(u.data[0, :]) +print("Devito Field norm is:", norm(u)) + +u.data[:, : , :] = 0 +u.data[:, 2:3 , 2:3] = 1 +# Reset data and run XDSLOperator +#init_hat(field=u.data[0], dx=dx, dy=dy, value=1.) +xdslop = Operator([eq_stencil], name='XDSLOperator') +xdslop.apply(time=nt, dt=dt, a=nu) +print(u.data[0, :]) + +print("XDSL Field norm is:", norm(u)) diff --git a/fast/setup_wave2d.py b/fast/setup_wave2d.py new file mode 100644 index 0000000000..5492d742fc --- /dev/null +++ b/fast/setup_wave2d.py @@ -0,0 +1,115 @@ +# Script to save initial data for the Acoustic wave execution benchmark +# Based on the implementation of the Devito acoustic example implementation +# Not using Devito's source injection abstraction +import sys +import numpy as np + +from devito import (TimeFunction, Eq, Operator, solve, norm, + XDSLOperator, configuration) +from examples.seismic import RickerSource +from examples.seismic import Model, TimeAxis, plot_image +from fast.bench_utils import plot_2dfunc +from devito.tools import as_tuple + +import argparse +np.set_printoptions(threshold=np.inf) + + +parser = argparse.ArgumentParser(description='Process arguments.') + +parser.add_argument("-d", "--shape", default=(16, 16), type=int, nargs="+", + help="Number of grid points along each axis") +parser.add_argument("-so", "--space_order", default=4, + type=int, help="Space order of the simulation") +parser.add_argument("-to", "--time_order", default=2, + type=int, help="Time order of the simulation") +parser.add_argument("-nt", "--nt", default=20, + type=int, help="Simulation time in millisecond") +parser.add_argument("-bls", "--blevels", default=1, type=int, nargs="+", + help="Block levels") +parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot2D") +parser.add_argument("-devito", "--devito", default=False, type=bool, help="Devito run") +parser.add_argument("-xdsl", "--xdsl", default=False, type=bool, help="xDSL run") +args = parser.parse_args() + + +mpiconf = configuration['mpi'] + +# Define a physical size +# nx, ny, nz = args.shape +nt = args.nt + +shape = (args.shape) # Number of grid point (nx, ny, nz) +spacing = as_tuple(10.0 for _ in range(len(shape))) # Grid spacing in m. The domain size is now 1km by 1km +origin = as_tuple(0.0 for _ in range(len(shape))) # What is the location of the top left corner. +# This is necessary to define +# the absolute location of the source and receivers + +# Define a velocity profile. The velocity is in km/s +v = np.empty(shape, dtype=np.float32) +v[:, :] = 1 + +# With the velocity and model size defined, we can create the seismic model that +# encapsulates this properties. We also define the size of the absorbing layer as +# 10 grid points +so = args.space_order +to = args.time_order + +model = Model(vp=v, origin=origin, shape=shape, spacing=spacing, + space_order=so, nbl=0) + +# plot_velocity(model) + +t0 = 0. # Simulation starts a t=0 +tn = nt # Simulation last 1 second (1000 ms) +dt = model.critical_dt # Time step from model grid spacing +print("dt is:", dt) + +time_range = TimeAxis(start=t0, stop=tn, step=dt) + +# The source is positioned at a $20m$ depth and at the middle of the +# $x$ axis ($x_{src}=500m$), +# with a peak wavelet frequency of $10Hz$. +f0 = 0.010 # Source peak frequency is 10Hz (0.010 kHz) +src = RickerSource(name='src', grid=model.grid, f0=f0, + npoint=1, time_range=time_range) + +# First, position source centrally in all dimensions, then set depth +src.coordinates.data[0, :] = np.array(model.domain_size) * .5 + +# We can plot the time signature to see the wavelet +# src.show() + +# Define the wavefield with the size of the model and the time dimension +u = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so) +# Another one to clone data +u2 = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so) +ub = TimeFunction(name="ub", grid=model.grid, time_order=to, space_order=so) + +# We can now write the PDE +# pde = model.m * u.dt2 - u.laplace + model.damp * u.dt +pde = u.dt2 - u.laplace + +stencil = Eq(u.forward, solve(pde, u.forward)) +# stencil + +# Finally we define the source injection and receiver read function to generate +# the corresponding code +# print(time_range) + +print("Init norm:", np.linalg.norm(u.data[:])) +src_term = src.inject(field=u.forward, expr=src * dt**2 / model.m) +op0 = Operator([stencil] + src_term, subs=model.spacing_map, name='SourceDevitoOperator') + +# Run with source and plot +op0.apply(time=time_range.num-1, dt=model.critical_dt) + +if len(shape) == 2: + if args.plot: + plot_2dfunc(u) + +# Save Data here +shape_str = '_'.join(str(item) for item in shape) +np.save("so%s_critical_dt%s.npy" % (so, shape_str), model.critical_dt, allow_pickle=True) +np.save("so%s_wave_dat%s.npy" % (so, shape_str), u.data[:], allow_pickle=True) +np.save("so%s_grid_extent%s.npy" % (so, shape_str), model.grid.extent, allow_pickle=True) diff --git a/fast/setup_wave3d.py b/fast/setup_wave3d.py new file mode 100644 index 0000000000..797d8257a8 --- /dev/null +++ b/fast/setup_wave3d.py @@ -0,0 +1,117 @@ +# Script to save initial data for the Acoustic wave execution benchmark +# Based on the implementation of the Devito acoustic example implementation +# Not using Devito's source injection abstraction +import sys +import numpy as np + +from devito import (TimeFunction, Eq, Operator, solve, norm, + configuration) +from examples.seismic import RickerSource +from examples.seismic import Model, TimeAxis +from fast.bench_utils import plot_3dfunc +from devito.tools import as_tuple + +import argparse +np.set_printoptions(threshold=np.inf) + + +parser = argparse.ArgumentParser(description='Process arguments.') + +parser.add_argument("-d", "--shape", default=(16, 16, 16), type=int, nargs="+", + help="Number of grid points along each axis") +parser.add_argument("-so", "--space_order", default=4, + type=int, help="Space order of the simulation") +parser.add_argument("-to", "--time_order", default=2, + type=int, help="Time order of the simulation") +parser.add_argument("-nt", "--nt", default=20, + type=int, help="Simulation time in millisecond") +parser.add_argument("-bls", "--blevels", default=1, type=int, nargs="+", + help="Block levels") +parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot2D") +parser.add_argument("-devito", "--devito", default=False, type=bool, help="Devito run") +parser.add_argument("-xdsl", "--xdsl", default=False, type=bool, help="xDSL run") +args = parser.parse_args() + + +mpiconf = configuration['mpi'] + +# Define a physical size +# nx, ny, nz = args.shape +nt = args.nt + +shape = (args.shape) # Number of grid point (nx, ny, nz) +spacing = as_tuple(10.0 for _ in range(len(shape))) # Grid spacing in m. The domain size is now 1km by 1km +origin = as_tuple(0.0 for _ in range(len(shape))) # What is the location of the top left corner. +# This is necessary to define +# the absolute location of the source and receivers + +# Define a velocity profile. The velocity is in km/s +v = np.empty(shape, dtype=np.float32) +v[:, :, :] = 1 + +# With the velocity and model size defined, we can create the seismic model that +# encapsulates this properties. We also define the size of the absorbing layer as +# 10 grid points +so = args.space_order +to = args.time_order + +model = Model(vp=v, origin=origin, shape=shape, spacing=spacing, + space_order=so, nbl=0) + +# plot_velocity(model) + +t0 = 0. # Simulation starts a t=0 +tn = nt # Simulation last 1 second (1000 ms) +dt = model.critical_dt # Time step from model grid spacing +print("dt is:", dt) + +time_range = TimeAxis(start=t0, stop=tn, step=dt) + +# The source is positioned at a $20m$ depth and at the middle of the +# $x$ axis ($x_{src}=500m$), +# with a peak wavelet frequency of $10Hz$. +f0 = 0.010 # Source peak frequency is 10Hz (0.010 kHz) +src = RickerSource(name='src', grid=model.grid, f0=f0, + npoint=1, time_range=time_range) + +# First, position source centrally in all dimensions, then set depth +src.coordinates.data[0, :] = np.array(model.domain_size) * .5 + +# We can plot the time signature to see the wavelet +# src.show() + +# Define the wavefield with the size of the model and the time dimension +u = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so) +# Another one to clone data +u2 = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so) +ub = TimeFunction(name="ub", grid=model.grid, time_order=to, space_order=so) + +# We can now write the PDE +# pde = model.m * u.dt2 - u.laplace + model.damp * u.dt +pde = u.dt2 - u.laplace + +stencil = Eq(u.forward, solve(pde, u.forward)) +# stencil + +# Finally we define the source injection and receiver read function to generate +# the corresponding code +# print(time_range) + +print("Init norm:", np.linalg.norm(u.data[:])) +src_term = src.inject(field=u.forward, expr=src * dt**2 / model.m) +op0 = Operator([stencil] + src_term, subs=model.spacing_map, name='SourceDevitoOperator') + +# Run with source and plot +op0.apply(time=time_range.num-1, dt=model.critical_dt) + +if len(shape) == 3: + if args.plot: + plot_3dfunc(u) + +# Save Data here +shape_str = '_'.join(str(item) for item in shape) +np.save("so%s_critical_dt%s.npy" % (so, shape_str), model.critical_dt, allow_pickle=True) +np.savez_compressed("so%s_wave_dat%s" % (so, shape_str), u.data[:], allow_pickle=True) + +np.savez_compressed("so%s_grid_extent%s" % (so, shape_str), model.grid.extent, + allow_pickle=True) diff --git a/fast/temp1 b/fast/temp1 deleted file mode 100644 index 46f207a419..0000000000 --- a/fast/temp1 +++ /dev/null @@ -1,201 +0,0 @@ -module { - func.func @apply_kernel(%arg0: memref<260x260xf32>, %arg1: memref<260x260xf32>) -> memref<260x260xf32> attributes {param_names = ["u_vec_0", "u_vec_1"]} { - %c28_i64 = arith.constant 28 : i64 - %c12_i64 = arith.constant 12 : i64 - %c24_i64 = arith.constant 24 : i64 - %c8_i64 = arith.constant 8 : i64 - %c20_i64 = arith.constant 20 : i64 - %c16_i64 = arith.constant 16 : i64 - %c257 = arith.constant 257 : index - %cst = arith.constant 1.000000e-01 : f32 - %cst_0 = arith.constant -2.000000e+00 : f32 - %c-2_i64 = arith.constant -2 : i64 - %cst_1 = arith.constant 0.00392156886 : f32 - %c-1_i64 = arith.constant -1 : i64 - %cst_2 = arith.constant 3.075740e-05 : f32 - %cst_3 = arith.constant 0.00999999977 : f32 - %c-1 = arith.constant -1 : index - %c64 = arith.constant 64 : index - %c738197504_i32 = arith.constant 738197504 : i32 - %c4_i64 = arith.constant 4 : i64 - %c-1_i32 = arith.constant -1 : i32 - %c4_i32 = arith.constant 4 : i32 - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %c1275069450_i32 = arith.constant 1275069450 : i32 - %c66_i32 = arith.constant 66 : i32 - %c1_i64 = arith.constant 1 : i64 - %c1140850688_i32 = arith.constant 1140850688 : i32 - %c8_i32 = arith.constant 8 : i32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = llvm.alloca %c8_i32 x i32 {alignment = 32 : i64} : (i32) -> !llvm.ptr - %1 = llvm.alloca %c1_i64 x i32 {alignment = 32 : i64} : (i64) -> !llvm.ptr - %2 = call @MPI_Comm_rank(%c1140850688_i32, %1) : (i32, !llvm.ptr) -> i32 - %3 = llvm.load %1 : !llvm.ptr - %alloc = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr = memref.extract_aligned_pointer_as_index %alloc : memref<66xf32> -> index - %4 = arith.index_cast %intptr : index to i64 - %5 = llvm.inttoptr %4 : i64 to !llvm.ptr - %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr_5 = memref.extract_aligned_pointer_as_index %alloc_4 : memref<66xf32> -> index - %6 = arith.index_cast %intptr_5 : index to i64 - %7 = llvm.inttoptr %6 : i64 to !llvm.ptr - %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr_7 = memref.extract_aligned_pointer_as_index %alloc_6 : memref<66xf32> -> index - %8 = arith.index_cast %intptr_7 : index to i64 - %9 = llvm.inttoptr %8 : i64 to !llvm.ptr - %alloc_8 = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr_9 = memref.extract_aligned_pointer_as_index %alloc_8 : memref<66xf32> -> index - %10 = arith.index_cast %intptr_9 : index to i64 - %11 = llvm.inttoptr %10 : i64 to !llvm.ptr - %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr_11 = memref.extract_aligned_pointer_as_index %alloc_10 : memref<66xf32> -> index - %12 = arith.index_cast %intptr_11 : index to i64 - %13 = llvm.inttoptr %12 : i64 to !llvm.ptr - %alloc_12 = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr_13 = memref.extract_aligned_pointer_as_index %alloc_12 : memref<66xf32> -> index - %14 = arith.index_cast %intptr_13 : index to i64 - %15 = llvm.inttoptr %14 : i64 to !llvm.ptr - %alloc_14 = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr_15 = memref.extract_aligned_pointer_as_index %alloc_14 : memref<66xf32> -> index - %16 = arith.index_cast %intptr_15 : index to i64 - %17 = llvm.inttoptr %16 : i64 to !llvm.ptr - %alloc_16 = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr_17 = memref.extract_aligned_pointer_as_index %alloc_16 : memref<66xf32> -> index - %18 = arith.index_cast %intptr_17 : index to i64 - %19 = llvm.inttoptr %18 : i64 to !llvm.ptr - %20 = arith.remui %3, %c4_i32 : i32 - %21 = arith.divui %3, %c4_i32 : i32 - %22 = arith.remui %21, %c4_i32 : i32 - %23 = arith.addi %22, %c-1_i32 : i32 - %24 = arith.cmpi sge, %23, %c0_i32 : i32 - %25 = arith.muli %23, %c4_i32 : i32 - %26 = arith.addi %20, %25 : i32 - %27 = llvm.ptrtoint %0 : !llvm.ptr to i64 - %28 = llvm.inttoptr %27 : i64 to !llvm.ptr - %29 = arith.addi %27, %c16_i64 : i64 - %30 = llvm.inttoptr %29 : i64 to !llvm.ptr - %31 = arith.addi %22, %c1_i32 : i32 - %32 = arith.cmpi slt, %31, %c4_i32 : i32 - %33 = arith.muli %31, %c4_i32 : i32 - %34 = arith.addi %20, %33 : i32 - %35 = arith.addi %27, %c4_i64 : i64 - %36 = llvm.inttoptr %35 : i64 to !llvm.ptr - %37 = arith.addi %27, %c20_i64 : i64 - %38 = llvm.inttoptr %37 : i64 to !llvm.ptr - %39 = arith.addi %20, %c-1_i32 : i32 - %40 = arith.cmpi sge, %39, %c0_i32 : i32 - %41 = arith.muli %22, %c4_i32 : i32 - %42 = arith.addi %39, %41 : i32 - %43 = arith.addi %27, %c8_i64 : i64 - %44 = llvm.inttoptr %43 : i64 to !llvm.ptr - %45 = arith.addi %27, %c24_i64 : i64 - %46 = llvm.inttoptr %45 : i64 to !llvm.ptr - %47 = arith.addi %20, %c1_i32 : i32 - %48 = arith.cmpi slt, %47, %c4_i32 : i32 - %49 = arith.addi %47, %41 : i32 - %50 = arith.addi %27, %c12_i64 : i64 - %51 = llvm.inttoptr %50 : i64 to !llvm.ptr - %52 = arith.addi %27, %c28_i64 : i64 - %53 = llvm.inttoptr %52 : i64 to !llvm.ptr - %54 = llvm.inttoptr %c1_i64 : i64 to !llvm.ptr - %55 = math.fpowi %cst_2, %c-1_i64 : f32, i64 - %56 = math.fpowi %cst_1, %c-2_i64 : f32, i64 - %57 = arith.mulf %56, %cst_0 : f32 - %58:2 = scf.for %arg2 = %c0 to %c257 step %c1 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (memref<260x260xf32>, memref<260x260xf32>) { - %subview = memref.subview %arg4[2, 2] [64, 64] [1, 1] : memref<260x260xf32> to memref<64x64xf32, strided<[260, 1], offset: 522>> - %subview_18 = memref.subview %arg3[2, 2] [66, 66] [1, 1] : memref<260x260xf32> to memref<66x66xf32, strided<[260, 1], offset: 522>> - scf.if %24 { - %subview_19 = memref.subview %subview_18[-1, 0] [66, 1] [1, 1] : memref<66x66xf32, strided<[260, 1], offset: 522>> to memref<66xf32, strided<[260], offset: 262>> - memref.copy %subview_19, %alloc : memref<66xf32, strided<[260], offset: 262>> to memref<66xf32> - %60 = func.call @MPI_Isend(%5, %c66_i32, %c1275069450_i32, %26, %c0_i32, %c1140850688_i32, %28) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - %61 = func.call @MPI_Irecv(%7, %c66_i32, %c1275069450_i32, %26, %c0_i32, %c1140850688_i32, %30) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - } else { - llvm.store %c738197504_i32, %28 : !llvm.ptr - llvm.store %c738197504_i32, %30 : !llvm.ptr - } - scf.if %32 { - %subview_19 = memref.subview %subview_18[-1, 63] [66, 1] [1, 1] : memref<66x66xf32, strided<[260, 1], offset: 522>> to memref<66xf32, strided<[260], offset: 325>> - memref.copy %subview_19, %alloc_6 : memref<66xf32, strided<[260], offset: 325>> to memref<66xf32> - %60 = func.call @MPI_Isend(%9, %c66_i32, %c1275069450_i32, %34, %c0_i32, %c1140850688_i32, %36) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - %61 = func.call @MPI_Irecv(%11, %c66_i32, %c1275069450_i32, %34, %c0_i32, %c1140850688_i32, %38) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - } else { - llvm.store %c738197504_i32, %36 : !llvm.ptr - llvm.store %c738197504_i32, %38 : !llvm.ptr - } - scf.if %40 { - %subview_19 = memref.subview %subview_18[0, -1] [1, 66] [1, 1] : memref<66x66xf32, strided<[260, 1], offset: 522>> to memref<66xf32, strided<[1], offset: 521>> - memref.copy %subview_19, %alloc_10 : memref<66xf32, strided<[1], offset: 521>> to memref<66xf32> - %60 = func.call @MPI_Isend(%13, %c66_i32, %c1275069450_i32, %42, %c0_i32, %c1140850688_i32, %44) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - %61 = func.call @MPI_Irecv(%15, %c66_i32, %c1275069450_i32, %42, %c0_i32, %c1140850688_i32, %46) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - } else { - llvm.store %c738197504_i32, %44 : !llvm.ptr - llvm.store %c738197504_i32, %46 : !llvm.ptr - } - scf.if %48 { - %subview_19 = memref.subview %subview_18[63, -1] [1, 66] [1, 1] : memref<66x66xf32, strided<[260, 1], offset: 522>> to memref<66xf32, strided<[1], offset: 16901>> - memref.copy %subview_19, %alloc_14 : memref<66xf32, strided<[1], offset: 16901>> to memref<66xf32> - %60 = func.call @MPI_Isend(%17, %c66_i32, %c1275069450_i32, %49, %c0_i32, %c1140850688_i32, %51) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - %61 = func.call @MPI_Irecv(%19, %c66_i32, %c1275069450_i32, %49, %c0_i32, %c1140850688_i32, %53) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - } else { - llvm.store %c738197504_i32, %51 : !llvm.ptr - llvm.store %c738197504_i32, %53 : !llvm.ptr - } - %59 = func.call @MPI_Waitall(%c8_i32, %0, %54) : (i32, !llvm.ptr, !llvm.ptr) -> i32 - scf.if %24 { - %subview_19 = memref.subview %subview_18[-1, -1] [66, 1] [1, 1] : memref<66x66xf32, strided<[260, 1], offset: 522>> to memref<66xf32, strided<[260], offset: 261>> - memref.copy %subview_19, %alloc_4 : memref<66xf32, strided<[260], offset: 261>> to memref<66xf32> - } - scf.if %32 { - %subview_19 = memref.subview %subview_18[-1, 64] [66, 1] [1, 1] : memref<66x66xf32, strided<[260, 1], offset: 522>> to memref<66xf32, strided<[260], offset: 326>> - memref.copy %subview_19, %alloc_8 : memref<66xf32, strided<[260], offset: 326>> to memref<66xf32> - } - scf.if %40 { - %subview_19 = memref.subview %subview_18[-1, -1] [1, 66] [1, 1] : memref<66x66xf32, strided<[260, 1], offset: 522>> to memref<66xf32, strided<[1], offset: 261>> - memref.copy %subview_19, %alloc_12 : memref<66xf32, strided<[1], offset: 261>> to memref<66xf32> - } - scf.if %48 { - %subview_19 = memref.subview %subview_18[64, -1] [1, 66] [1, 1] : memref<66x66xf32, strided<[260, 1], offset: 522>> to memref<66xf32, strided<[1], offset: 17161>> - memref.copy %subview_19, %alloc_16 : memref<66xf32, strided<[1], offset: 17161>> to memref<66xf32> - } - scf.parallel (%arg5) = (%c0) to (%c64) step (%c1) { - %60 = arith.addi %arg5, %c-1 : index - %61 = arith.addi %arg5, %c1 : index - scf.for %arg6 = %c0 to %c64 step %c1 { - %62 = memref.load %subview_18[%arg5, %arg6] : memref<66x66xf32, strided<[260, 1], offset: 522>> - %63 = memref.load %subview_18[%60, %arg6] : memref<66x66xf32, strided<[260, 1], offset: 522>> - %64 = memref.load %subview_18[%61, %arg6] : memref<66x66xf32, strided<[260, 1], offset: 522>> - %65 = arith.addi %arg6, %c-1 : index - %66 = memref.load %subview_18[%arg5, %65] : memref<66x66xf32, strided<[260, 1], offset: 522>> - %67 = arith.addi %arg6, %c1 : index - %68 = memref.load %subview_18[%arg5, %67] : memref<66x66xf32, strided<[260, 1], offset: 522>> - %69 = arith.mulf %55, %62 : f32 - %70 = arith.mulf %56, %63 : f32 - %71 = arith.mulf %56, %64 : f32 - %72 = arith.mulf %57, %62 : f32 - %73 = arith.addf %70, %71 : f32 - %74 = arith.addf %73, %72 : f32 - %75 = arith.mulf %56, %66 : f32 - %76 = arith.mulf %56, %68 : f32 - %77 = arith.addf %75, %76 : f32 - %78 = arith.addf %77, %72 : f32 - %79 = arith.addf %74, %78 : f32 - %80 = arith.mulf %79, %cst : f32 - %81 = arith.addf %69, %cst_3 : f32 - %82 = arith.addf %81, %80 : f32 - %83 = arith.mulf %82, %cst_2 : f32 - memref.store %83, %subview[%arg5, %arg6] : memref<64x64xf32, strided<[260, 1], offset: 522>> - } - scf.yield - } - scf.yield %arg4, %arg3 : memref<260x260xf32>, memref<260x260xf32> - } - return %58#0 : memref<260x260xf32> - } - func.func private @MPI_Comm_rank(i32, !llvm.ptr) -> i32 - func.func private @MPI_Isend(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - func.func private @MPI_Irecv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - func.func private @MPI_Waitall(i32, !llvm.ptr, !llvm.ptr) -> i32 -} - diff --git a/fast/temp2 b/fast/temp2 deleted file mode 100644 index 40783a347a..0000000000 --- a/fast/temp2 +++ /dev/null @@ -1,212 +0,0 @@ -#map = affine_map<()[s0] -> (s0 + 2)> -module { - func.func @apply_kernel(%arg0: memref<260x260xf32>, %arg1: memref<260x260xf32>) -> memref<260x260xf32> attributes {param_names = ["u_vec_0", "u_vec_1"]} { - %c28_i64 = arith.constant 28 : i64 - %c12_i64 = arith.constant 12 : i64 - %c24_i64 = arith.constant 24 : i64 - %c8_i64 = arith.constant 8 : i64 - %c20_i64 = arith.constant 20 : i64 - %c16_i64 = arith.constant 16 : i64 - %c257 = arith.constant 257 : index - %cst = arith.constant 1.000000e-01 : f32 - %cst_0 = arith.constant -2.000000e+00 : f32 - %c-2_i64 = arith.constant -2 : i64 - %cst_1 = arith.constant 0.00392156886 : f32 - %c-1_i64 = arith.constant -1 : i64 - %cst_2 = arith.constant 3.075740e-05 : f32 - %cst_3 = arith.constant 0.00999999977 : f32 - %c-1 = arith.constant -1 : index - %c64 = arith.constant 64 : index - %c738197504_i32 = arith.constant 738197504 : i32 - %c4_i64 = arith.constant 4 : i64 - %c-1_i32 = arith.constant -1 : i32 - %c4_i32 = arith.constant 4 : i32 - %c1_i32 = arith.constant 1 : i32 - %c0_i32 = arith.constant 0 : i32 - %c1275069450_i32 = arith.constant 1275069450 : i32 - %c66_i32 = arith.constant 66 : i32 - %c1_i64 = arith.constant 1 : i64 - %c1140850688_i32 = arith.constant 1140850688 : i32 - %c8_i32 = arith.constant 8 : i32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = llvm.alloca %c8_i32 x i32 {alignment = 32 : i64} : (i32) -> !llvm.ptr - %1 = llvm.alloca %c1_i64 x i32 {alignment = 32 : i64} : (i64) -> !llvm.ptr - %2 = call @MPI_Comm_rank(%c1140850688_i32, %1) : (i32, !llvm.ptr) -> i32 - %3 = llvm.load %1 : !llvm.ptr - %alloc = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr = memref.extract_aligned_pointer_as_index %alloc : memref<66xf32> -> index - %4 = arith.index_cast %intptr : index to i64 - %5 = llvm.inttoptr %4 : i64 to !llvm.ptr - %alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr_5 = memref.extract_aligned_pointer_as_index %alloc_4 : memref<66xf32> -> index - %6 = arith.index_cast %intptr_5 : index to i64 - %7 = llvm.inttoptr %6 : i64 to !llvm.ptr - %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr_7 = memref.extract_aligned_pointer_as_index %alloc_6 : memref<66xf32> -> index - %8 = arith.index_cast %intptr_7 : index to i64 - %9 = llvm.inttoptr %8 : i64 to !llvm.ptr - %alloc_8 = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr_9 = memref.extract_aligned_pointer_as_index %alloc_8 : memref<66xf32> -> index - %10 = arith.index_cast %intptr_9 : index to i64 - %11 = llvm.inttoptr %10 : i64 to !llvm.ptr - %alloc_10 = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr_11 = memref.extract_aligned_pointer_as_index %alloc_10 : memref<66xf32> -> index - %12 = arith.index_cast %intptr_11 : index to i64 - %13 = llvm.inttoptr %12 : i64 to !llvm.ptr - %alloc_12 = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr_13 = memref.extract_aligned_pointer_as_index %alloc_12 : memref<66xf32> -> index - %14 = arith.index_cast %intptr_13 : index to i64 - %15 = llvm.inttoptr %14 : i64 to !llvm.ptr - %alloc_14 = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr_15 = memref.extract_aligned_pointer_as_index %alloc_14 : memref<66xf32> -> index - %16 = arith.index_cast %intptr_15 : index to i64 - %17 = llvm.inttoptr %16 : i64 to !llvm.ptr - %alloc_16 = memref.alloc() {alignment = 64 : i64} : memref<66xf32> - %intptr_17 = memref.extract_aligned_pointer_as_index %alloc_16 : memref<66xf32> -> index - %18 = arith.index_cast %intptr_17 : index to i64 - %19 = llvm.inttoptr %18 : i64 to !llvm.ptr - %20 = arith.remui %3, %c4_i32 : i32 - %21 = arith.divui %3, %c4_i32 : i32 - %22 = arith.remui %21, %c4_i32 : i32 - %23 = arith.addi %22, %c-1_i32 : i32 - %24 = arith.cmpi sge, %23, %c0_i32 : i32 - %25 = arith.muli %23, %c4_i32 : i32 - %26 = arith.addi %20, %25 : i32 - %27 = llvm.ptrtoint %0 : !llvm.ptr to i64 - %28 = llvm.inttoptr %27 : i64 to !llvm.ptr - %29 = arith.addi %27, %c16_i64 : i64 - %30 = llvm.inttoptr %29 : i64 to !llvm.ptr - %31 = arith.addi %22, %c1_i32 : i32 - %32 = arith.cmpi slt, %31, %c4_i32 : i32 - %33 = arith.muli %31, %c4_i32 : i32 - %34 = arith.addi %20, %33 : i32 - %35 = arith.addi %27, %c4_i64 : i64 - %36 = llvm.inttoptr %35 : i64 to !llvm.ptr - %37 = arith.addi %27, %c20_i64 : i64 - %38 = llvm.inttoptr %37 : i64 to !llvm.ptr - %39 = arith.addi %20, %c-1_i32 : i32 - %40 = arith.cmpi sge, %39, %c0_i32 : i32 - %41 = arith.muli %22, %c4_i32 : i32 - %42 = arith.addi %39, %41 : i32 - %43 = arith.addi %27, %c8_i64 : i64 - %44 = llvm.inttoptr %43 : i64 to !llvm.ptr - %45 = arith.addi %27, %c24_i64 : i64 - %46 = llvm.inttoptr %45 : i64 to !llvm.ptr - %47 = arith.addi %20, %c1_i32 : i32 - %48 = arith.cmpi slt, %47, %c4_i32 : i32 - %49 = arith.addi %47, %41 : i32 - %50 = arith.addi %27, %c12_i64 : i64 - %51 = llvm.inttoptr %50 : i64 to !llvm.ptr - %52 = arith.addi %27, %c28_i64 : i64 - %53 = llvm.inttoptr %52 : i64 to !llvm.ptr - %54 = llvm.inttoptr %c1_i64 : i64 to !llvm.ptr - %55 = math.fpowi %cst_2, %c-1_i64 : f32, i64 - %56 = math.fpowi %cst_1, %c-2_i64 : f32, i64 - %57 = arith.mulf %56, %cst_0 : f32 - %58:2 = scf.for %arg2 = %c0 to %c257 step %c1 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (memref<260x260xf32>, memref<260x260xf32>) { - scf.if %24 { - %subview = memref.subview %arg3[1, 2] [66, 1] [1, 1] : memref<260x260xf32> to memref<66xf32, strided<[260], offset: 262>> - memref.copy %subview, %alloc : memref<66xf32, strided<[260], offset: 262>> to memref<66xf32> - %60 = func.call @MPI_Isend(%5, %c66_i32, %c1275069450_i32, %26, %c0_i32, %c1140850688_i32, %28) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - %61 = func.call @MPI_Irecv(%7, %c66_i32, %c1275069450_i32, %26, %c0_i32, %c1140850688_i32, %30) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - } else { - llvm.store %c738197504_i32, %28 : !llvm.ptr - llvm.store %c738197504_i32, %30 : !llvm.ptr - } - scf.if %32 { - %subview = memref.subview %arg3[1, 65] [66, 1] [1, 1] : memref<260x260xf32> to memref<66xf32, strided<[260], offset: 325>> - memref.copy %subview, %alloc_6 : memref<66xf32, strided<[260], offset: 325>> to memref<66xf32> - %60 = func.call @MPI_Isend(%9, %c66_i32, %c1275069450_i32, %34, %c0_i32, %c1140850688_i32, %36) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - %61 = func.call @MPI_Irecv(%11, %c66_i32, %c1275069450_i32, %34, %c0_i32, %c1140850688_i32, %38) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - } else { - llvm.store %c738197504_i32, %36 : !llvm.ptr - llvm.store %c738197504_i32, %38 : !llvm.ptr - } - scf.if %40 { - %subview = memref.subview %arg3[2, 1] [1, 66] [1, 1] : memref<260x260xf32> to memref<66xf32, strided<[1], offset: 521>> - memref.copy %subview, %alloc_10 : memref<66xf32, strided<[1], offset: 521>> to memref<66xf32> - %60 = func.call @MPI_Isend(%13, %c66_i32, %c1275069450_i32, %42, %c0_i32, %c1140850688_i32, %44) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - %61 = func.call @MPI_Irecv(%15, %c66_i32, %c1275069450_i32, %42, %c0_i32, %c1140850688_i32, %46) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - } else { - llvm.store %c738197504_i32, %44 : !llvm.ptr - llvm.store %c738197504_i32, %46 : !llvm.ptr - } - scf.if %48 { - %subview = memref.subview %arg3[65, 1] [1, 66] [1, 1] : memref<260x260xf32> to memref<66xf32, strided<[1], offset: 16901>> - memref.copy %subview, %alloc_14 : memref<66xf32, strided<[1], offset: 16901>> to memref<66xf32> - %60 = func.call @MPI_Isend(%17, %c66_i32, %c1275069450_i32, %49, %c0_i32, %c1140850688_i32, %51) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - %61 = func.call @MPI_Irecv(%19, %c66_i32, %c1275069450_i32, %49, %c0_i32, %c1140850688_i32, %53) : (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - } else { - llvm.store %c738197504_i32, %51 : !llvm.ptr - llvm.store %c738197504_i32, %53 : !llvm.ptr - } - %59 = func.call @MPI_Waitall(%c8_i32, %0, %54) : (i32, !llvm.ptr, !llvm.ptr) -> i32 - scf.if %24 { - %subview = memref.subview %arg3[1, 1] [66, 1] [1, 1] : memref<260x260xf32> to memref<66xf32, strided<[260], offset: 261>> - memref.copy %subview, %alloc_4 : memref<66xf32, strided<[260], offset: 261>> to memref<66xf32> - } - scf.if %32 { - %subview = memref.subview %arg3[1, 66] [66, 1] [1, 1] : memref<260x260xf32> to memref<66xf32, strided<[260], offset: 326>> - memref.copy %subview, %alloc_8 : memref<66xf32, strided<[260], offset: 326>> to memref<66xf32> - } - scf.if %40 { - %subview = memref.subview %arg3[1, 1] [1, 66] [1, 1] : memref<260x260xf32> to memref<66xf32, strided<[1], offset: 261>> - memref.copy %subview, %alloc_12 : memref<66xf32, strided<[1], offset: 261>> to memref<66xf32> - } - scf.if %48 { - %subview = memref.subview %arg3[66, 1] [1, 66] [1, 1] : memref<260x260xf32> to memref<66xf32, strided<[1], offset: 17161>> - memref.copy %subview, %alloc_16 : memref<66xf32, strided<[1], offset: 17161>> to memref<66xf32> - } - scf.parallel (%arg5) = (%c0) to (%c64) step (%c1) { - %60 = arith.addi %arg5, %c-1 : index - %61 = arith.addi %arg5, %c1 : index - scf.for %arg6 = %c0 to %c64 step %c1 { - %62 = affine.apply #map()[%arg5] - %63 = affine.apply #map()[%arg6] - %64 = memref.load %arg3[%62, %63] : memref<260x260xf32> - %65 = affine.apply #map()[%60] - %66 = affine.apply #map()[%arg6] - %67 = memref.load %arg3[%65, %66] : memref<260x260xf32> - %68 = affine.apply #map()[%61] - %69 = affine.apply #map()[%arg6] - %70 = memref.load %arg3[%68, %69] : memref<260x260xf32> - %71 = arith.addi %arg6, %c-1 : index - %72 = affine.apply #map()[%arg5] - %73 = affine.apply #map()[%71] - %74 = memref.load %arg3[%72, %73] : memref<260x260xf32> - %75 = arith.addi %arg6, %c1 : index - %76 = affine.apply #map()[%arg5] - %77 = affine.apply #map()[%75] - %78 = memref.load %arg3[%76, %77] : memref<260x260xf32> - %79 = arith.mulf %55, %64 : f32 - %80 = arith.mulf %56, %67 : f32 - %81 = arith.mulf %56, %70 : f32 - %82 = arith.mulf %57, %64 : f32 - %83 = arith.addf %80, %81 : f32 - %84 = arith.addf %83, %82 : f32 - %85 = arith.mulf %56, %74 : f32 - %86 = arith.mulf %56, %78 : f32 - %87 = arith.addf %85, %86 : f32 - %88 = arith.addf %87, %82 : f32 - %89 = arith.addf %84, %88 : f32 - %90 = arith.mulf %89, %cst : f32 - %91 = arith.addf %79, %cst_3 : f32 - %92 = arith.addf %91, %90 : f32 - %93 = arith.mulf %92, %cst_2 : f32 - %94 = affine.apply #map()[%arg5] - %95 = affine.apply #map()[%arg6] - memref.store %93, %arg4[%94, %95] : memref<260x260xf32> - } - scf.yield - } - scf.yield %arg4, %arg3 : memref<260x260xf32>, memref<260x260xf32> - } - return %58#0 : memref<260x260xf32> - } - func.func private @MPI_Comm_rank(i32, !llvm.ptr) -> i32 - func.func private @MPI_Isend(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - func.func private @MPI_Irecv(!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32 - func.func private @MPI_Waitall(i32, !llvm.ptr, !llvm.ptr) -> i32 -} - diff --git a/fast/wave2d.py b/fast/wave2d.py index 409bbfb33b..2a7855a6e9 100644 --- a/fast/wave2d.py +++ b/fast/wave2d.py @@ -2,10 +2,12 @@ # Not using Devito's source injection abstraction import sys import numpy as np -from devito import TimeFunction, Eq, Operator, solve, norm, XDSLOperator -from examples.seismic import RickerSource -from examples.seismic import Model, TimeAxis +from devito import (TimeFunction, Eq, Operator, solve, norm, + XDSLOperator, configuration) +from examples.seismic import RickerSource +from examples.seismic import Model, TimeAxis, plot_image +from fast.bench_utils import plot_2dfunc from devito.tools import as_tuple import argparse @@ -25,25 +27,12 @@ parser.add_argument("-bls", "--blevels", default=1, type=int, nargs="+", help="Block levels") parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot2D") +parser.add_argument("-devito", "--devito", default=False, type=bool, help="Devito run") +parser.add_argument("-xdsl", "--xdsl", default=False, type=bool, help="xDSL run") args = parser.parse_args() -def plot_2dfunc(u): - # Plot a 3D structured grid using pyvista - - import matplotlib.pyplot as plt - import pyvista as pv - cmap = plt.colormaps["viridis"] - values = u.data[0, :, :, :] - vistagrid = pv.UniformGrid() - vistagrid.dimensions = np.array(values.shape) + 1 - vistagrid.spacing = (1, 1, 1) - vistagrid.origin = (0, 0, 0) # The bottom left corner of the data set - vistagrid.cell_data["values"] = values.flatten(order="F") - vistaslices = vistagrid.slice_orthogonal() - # vistagrid.plot(show_edges=True) - vistaslices.plot(cmap=cmap) - +mpiconf = configuration['mpi'] # Define a physical size # nx, ny, nz = args.shape @@ -92,27 +81,26 @@ def plot_2dfunc(u): # Define the wavefield with the size of the model and the time dimension u = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so) - +# Another one to clone data u2 = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so) +ub = TimeFunction(name="ub", grid=model.grid, time_order=to, space_order=so) # We can now write the PDE # pde = model.m * u.dt2 - u.laplace + model.damp * u.dt # import pdb;pdb.set_trace() pde = u.dt2 - u.laplace -# The PDE representation is as on paper -pde - stencil = Eq(u.forward, solve(pde, u.forward)) -stencil +# stencil # Finally we define the source injection and receiver read function to generate # the corresponding code -print(time_range) +# print(time_range) -print("Init norm:", norm(u)) +print("Init norm:", np.linalg.norm(u.data[:])) src_term = src.inject(field=u.forward, expr=src * dt**2 / model.m) op0 = Operator([stencil] + src_term, subs=model.spacing_map, name='SourceDevitoOperator') + # Run with source and plot op0.apply(time=time_range.num-1, dt=model.critical_dt) @@ -120,51 +108,52 @@ def plot_2dfunc(u): if args.plot: plot_2dfunc(u) -print("Init Devito linalg norm 0 :", np.linalg.norm(u.data[0])) -print("Init Devito linalg norm 1 :", np.linalg.norm(u.data[1])) -print("Init Devito linalg norm 2 :", np.linalg.norm(u.data[2])) +import pdb;pdb.set_trace() +# print("Init Devito linalg norm 0 :", np.linalg.norm(u.data[0])) +# print("Init Devito linalg norm 1 :", np.linalg.norm(u.data[1])) +# print("Init Devito linalg norm 2 :", np.linalg.norm(u.data[2])) +# print("Norm of initial data:", norm(u)) -print("Norm of initial data:", norm(u)) -# import pdb;pdb.set_trace() +configuration['mpi'] = 0 u2.data[:] = u.data[:] +configuration['mpi'] = mpiconf -# Run more with no sources now (Not supported in xdsl) -op1 = Operator([stencil], name='DevitoOperator') -op1.apply(time=time_range.num-1, dt=model.critical_dt) +if args.devito: + # Run more with no sources now (Not supported in xdsl) + op1 = Operator([stencil], name='DevitoOperator') + op1.apply(time=time_range.num-1, dt=model.critical_dt) -if len(shape) == 2: - if args.plot: - plot_3dfunc(u) - -#devito_output = u.data[:] -print("After Operator 1: Devito norm:", norm(u)) -print("Devito linalg norm 0:", np.linalg.norm(u.data[0])) -print("Devito linalg norm 1:", np.linalg.norm(u.data[1])) -print("Devito linalg norm 2:", np.linalg.norm(u.data[2])) - -# import pdb;pdb.set_trace() - - -# Reset initial data -u.data[:] = u2.data[:] -#v[:, ..., :] = 1 + configuration['mpi'] = 0 + ub.data[:] = u.data[:] + configuration['mpi'] = mpiconf + if len(shape) == 2 and args.plot: + plot_2dfunc(u) -print("Reinitialise data: Devito norm:", norm(u)) -print("Init XDSL linalg norm:", np.linalg.norm(u.data[0])) -print("Init XDSL linalg norm:", np.linalg.norm(u.data[1])) -print("Init XDSL linalg norm:", np.linalg.norm(u.data[2])) + # print("After Operator 1: Devito norm:", norm(u)) + # print("Devito linalg norm 0:", np.linalg.norm(u.data[0])) + # print("Devito linalg norm 1:", np.linalg.norm(u.data[1])) + # print("Devito linalg norm 2:", np.linalg.norm(u.data[2])) -# Run more with no sources now (Not supported in xdsl) -xdslop = XDSLOperator([stencil], name='xDSLOperator') -xdslop.apply(time=time_range.num-1, dt=model.critical_dt) -xdsl_output = u.copy() -print("XDSL norm:", norm(u)) -print(f"xdsl output norm: {norm(xdsl_output)}") +if args.xdsl: + # Reset initial data + configuration['mpi'] = 0 + u.data[:] = u2.data[:] + configuration['mpi'] = mpiconf + # v[:, ..., :] = 1 + # print("Reinitialise data: Devito norm:", norm(u)) + # print("XDSL init linalg norm:", np.linalg.norm(u.data[0])) + # print("XDSL init linalg norm:", np.linalg.norm(u.data[1])) + # print("XDSL init linalg norm:", np.linalg.norm(u.data[2])) -print("XDSL output linalg norm:", np.linalg.norm(u.data[0])) -print("XDSL output linalg norm:", np.linalg.norm(u.data[1])) -print("XDSL output linalg norm:", np.linalg.norm(u.data[2])) + # Run more with no sources now (Not supported in xdsl) + xdslop = Operator([stencil], name='xDSLOperator') + xdslop.apply(time=time_range.num-1, dt=model.critical_dt) + if len(shape) == 2 and args.plot: + plot_2dfunc(u) + print("XDSL output norm 0:", np.linalg.norm(u.data[0]), "vs:", np.linalg.norm(ub.data[0])) + print("XDSL output norm 1:", np.linalg.norm(u.data[1]), "vs:", np.linalg.norm(ub.data[1])) + print("XDSL output norm 2:", np.linalg.norm(u.data[2]), "vs:", np.linalg.norm(ub.data[2])) diff --git a/fast/wave2d_b.py b/fast/wave2d_b.py new file mode 100644 index 0000000000..411d61ef14 --- /dev/null +++ b/fast/wave2d_b.py @@ -0,0 +1,129 @@ +# Based on the implementation of the Devito acoustic example implementation +# Not using Devito's source injection abstraction +import sys +import numpy as np + +from devito import (TimeFunction, Eq, Operator, solve, norm, + XDSLOperator, configuration, Grid) +from examples.seismic import RickerSource +from examples.seismic import Model, TimeAxis, plot_image +from fast.bench_utils import plot_2dfunc +from devito.tools import as_tuple + +import argparse +np.set_printoptions(threshold=np.inf) + + +parser = argparse.ArgumentParser(description='Process arguments.') + +parser.add_argument("-d", "--shape", default=(16, 16), type=int, nargs="+", + help="Number of grid points along each axis") +parser.add_argument("-so", "--space_order", default=4, + type=int, help="Space order of the simulation") +parser.add_argument("-to", "--time_order", default=2, + type=int, help="Time order of the simulation") +parser.add_argument("-nt", "--nt", default=20, + type=int, help="Simulation time in millisecond") +parser.add_argument("-bls", "--blevels", default=1, type=int, nargs="+", + help="Block levels") +parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot2D") +parser.add_argument("-devito", "--devito", default=False, type=bool, help="Devito run") +parser.add_argument("-xdsl", "--xdsl", default=False, type=bool, help="xDSL run") +args = parser.parse_args() + + +mpiconf = configuration['mpi'] + +# Define a physical size +# nx, ny, nz = args.shape +nt = args.nt +so = args.space_order + +shape = (args.shape) # Number of grid point (nx, ny, nz) +shape_str = '_'.join(str(item) for item in shape) +spacing = as_tuple(10.0 for _ in range(len(shape))) # Grid spacing in m. The domain size is now 1km by 1km +origin = as_tuple(0.0 for _ in range(len(shape))) # What is the location of the top left corner. +domain_size = tuple((d-1) * s for d, s in zip(shape, spacing)) +extent = np.load("so%s_grid_extent%s.npy" % (so, shape_str), allow_pickle=True) + +grid = Grid(shape=shape, extent=as_tuple(extent)) + +# With the velocity and model size defined, we can create the seismic model that +# encapsulates this properties. We also define the size of the absorbing layer as +# 10 grid points +so = args.space_order +to = args.time_order + +t0 = 0. # Simulation starts a t=0 +tn = nt # Simulation last 1 second (1000 ms) + +# Define the wavefield with the size of the model and the time dimension +u = TimeFunction(name="u", grid=grid, time_order=to, space_order=so) +# Another one to clone data +u2 = TimeFunction(name="u", grid=grid, time_order=to, space_order=so) +ub = TimeFunction(name="ub", grid=grid, time_order=to, space_order=so) + +# We can now write the PDE +# pde = model.m * u.dt2 - u.laplace + model.damp * u.dt +# import pdb;pdb.set_trace() +pde = u.dt2 - u.laplace + +stencil = Eq(u.forward, solve(pde, u.forward)) + +# print("Init Devito linalg norm 0 :", np.linalg.norm(u.data[0])) +# print("Init Devito linalg norm 1 :", np.linalg.norm(u.data[1])) +# print("Init Devito linalg norm 2 :", np.linalg.norm(u.data[2])) +# print("Norm of initial data:", norm(u)) + +configuration['mpi'] = 0 +u2.data[:] = u.data[:] +configuration['mpi'] = mpiconf + +u.data[:] = np.load("so%s_wave_dat%s.npy" % (so, shape_str), allow_pickle=True) +dt = np.load("so%s_critical_dt%s.npy" % (so, shape_str), allow_pickle=True) + +# np.save("critical_dt%s.npy" % shape_str, model.critical_dt, allow_pickle=True) +# np.save("wave_dat%s.npy" % shape_str, u.data[:], allow_pickle=True) + +if len(shape) == 2 and args.plot: + plot_2dfunc(u) + +print("Init norm:", np.linalg.norm(u.data[:])) + +if args.devito: + # Run more with no sources now (Not supported in xdsl) + # op1 = Operator([stencil], name='DevitoOperator', subs=grid.spacing_map) + op1 = Operator([stencil], name='DevitoOperator') + op1.apply(time=nt, dt=dt) + + configuration['mpi'] = 0 + ub.data[:] = u.data[:] + configuration['mpi'] = mpiconf + + if len(shape) == 2 and args.plot: + plot_2dfunc(u) + + print("Devito norm:", norm(u)) + # print("Devito linalg norm 0:", np.linalg.norm(u.data[0])) + # print("Devito linalg norm 1:", np.linalg.norm(u.data[1])) + # print("Devito linalg norm 2:", np.linalg.norm(u.data[2])) + + +if args.xdsl: + # print("Reinitialise data: Devito norm:", norm(u)) + # print("XDSL init linalg norm:", np.linalg.norm(u.data[0])) + # print("XDSL init linalg norm:", np.linalg.norm(u.data[1])) + # print("XDSL init linalg norm:", np.linalg.norm(u.data[2])) + + # Run more with no sources now (Not supported in xdsl) + xdslop = Operator([stencil], name='xDSLOperator') + xdslop.apply(time=nt, dt=dt) + + if len(shape) == 2 and args.plot: + plot_2dfunc(u) + + print("XDSL norm:", norm(u)) + + # print("XDSL output norm 0:", np.linalg.norm(u.data[0]), "vs:", np.linalg.norm(ub.data[0])) + # print("XDSL output norm 1:", np.linalg.norm(u.data[1]), "vs:", np.linalg.norm(ub.data[1])) + # print("XDSL output norm 2:", np.linalg.norm(u.data[2]), "vs:", np.linalg.norm(ub.data[2])) diff --git a/fast/wave3d.py b/fast/wave3d.py index abf70d3cdc..86b0d55afe 100644 --- a/fast/wave3d.py +++ b/fast/wave3d.py @@ -2,9 +2,11 @@ # Not using Devito's source injection abstraction import sys import numpy as np -from devito import TimeFunction, Eq, Operator, solve, norm, XDSLOperator +from devito import (TimeFunction, Eq, Operator, solve, norm, + XDSLOperator, configuration) from examples.seismic import RickerSource from examples.seismic import Model, TimeAxis +from fast.bench_utils import plot_3dfunc from devito.tools import as_tuple @@ -25,24 +27,12 @@ parser.add_argument("-bls", "--blevels", default=2, type=int, nargs="+", help="Block levels") parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot3D") +parser.add_argument("-devito", "--devito", default=False, type=bool, help="Devito run") +parser.add_argument("-xdsl", "--xdsl", default=False, type=bool, help="xDSL run") args = parser.parse_args() -def plot_3dfunc(u): - # Plot a 3D structured grid using pyvista - - import matplotlib.pyplot as plt - import pyvista as pv - cmap = plt.colormaps["viridis"] - values = u.data[0, :, :, :] - vistagrid = pv.UniformGrid() - vistagrid.dimensions = np.array(values.shape) + 1 - vistagrid.spacing = (1, 1, 1) - vistagrid.origin = (0, 0, 0) # The bottom left corner of the data set - vistagrid.cell_data["values"] = values.flatten(order="F") - vistaslices = vistagrid.slice_orthogonal() - # vistagrid.plot(show_edges=True) - vistaslices.plot(cmap=cmap) +mpiconf = configuration['mpi'] # Define a physical size @@ -87,83 +77,72 @@ def plot_3dfunc(u): # First, position source centrally in all dimensions, then set depth src.coordinates.data[0, :] = np.array(model.domain_size) * .5 -# We can plot the time signature to see the wavelet -# src.show() - # Define the wavefield with the size of the model and the time dimension u = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so) - +# Another one to clone data u2 = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so) +ub = TimeFunction(name="ub", grid=model.grid, time_order=to, space_order=so) + # We can now write the PDE # pde = model.m * u.dt2 - u.laplace + model.damp * u.dt # import pdb;pdb.set_trace() pde = u.dt2 - u.laplace -# The PDE representation is as on paper -pde - stencil = Eq(u.forward, solve(pde, u.forward)) -stencil - -# Finally we define the source injection and receiver read function to generate -# the corresponding code -print(time_range) -print("Init norm:", norm(u)) src_term = src.inject(field=u.forward, expr=src * dt**2 / model.m) op0 = Operator([stencil] + src_term, subs=model.spacing_map, name='SourceDevitoOperator') # Run with source and plot op0.apply(time=time_range.num-1, dt=model.critical_dt) + if len(shape) == 3: if args.plot: plot_3dfunc(u) -print("Init linalg norm 0 :", np.linalg.norm(u.data[0])) -print("Init linalg norm 1 :", np.linalg.norm(u.data[1])) -print("Init linalg norm 2 :", np.linalg.norm(u.data[2])) +# devito_norm = norm(u) +# print("Init linalg norm 0 (inlined) :", norm(u)) +# print("Init linalg norm 0 :", np.linalg.norm(u.data[0])) +# print("Init linalg norm 1 :", np.linalg.norm(u.data[1])) +# print("Init linalg norm 2 :", np.linalg.norm(u.data[2])) +# print("Norm of initial data:", np.linalg.norm(u.data[:])) -print("Norm of initial data:", norm(u)) -import pdb;pdb.set_trace() +configuration['mpi'] = 0 u2.data[:] = u.data[:] +configuration['mpi'] = mpiconf # Run more with no sources now (Not supported in xdsl) op1 = Operator([stencil], name='DevitoOperator') op1.apply(time=time_range.num-1, dt=model.critical_dt) +configuration['mpi'] = 0 +ub.data[:] = u.data[:] +configuration['mpi'] = mpiconf + if len(shape) == 3: if args.plot: plot_3dfunc(u) -#devito_output = u.data[:] -print("After Operator 1: Devito norm:", norm(u)) -print("Devito linalg norm 0:", np.linalg.norm(u.data[0])) -print("Devito linalg norm 1:", np.linalg.norm(u.data[1])) -print("Devito linalg norm 2:", np.linalg.norm(u.data[2])) - -import pdb;pdb.set_trace() - +# print("After Operator 1: Devito norm:", np.linalg.norm(u.data[:])) +#print("Devito norm 0:", np.linalg.norm(u.data[0])) +#print("Devito norm 1:", np.linalg.norm(u.data[1])) +#print("Devito norm 2:", np.linalg.norm(u.data[2])) # Reset initial data +configuration['mpi'] = 0 u.data[:] = u2.data[:] -#v[:, ..., :] = 1 - +configuration['mpi'] = mpiconf -print("Reinitialise data: Devito norm:", norm(u)) -print("Init XDSL linalg norm:", np.linalg.norm(u.data[0])) -print("Init XDSL linalg norm:", np.linalg.norm(u.data[1])) -print("Init XDSL linalg norm:", np.linalg.norm(u.data[2])) +# print("Reinitialise data for XDSL:", np.linalg.norm(u.data[:])) +# print("Init XDSL linalg norm 0:", np.linalg.norm(u.data[0])) +# print("Init XDSL linalg norm 1:", np.linalg.norm(u.data[1])) +# print("Init XDSL linalg norm 2:", np.linalg.norm(u.data[2])) # Run more with no sources now (Not supported in xdsl) -xdslop = XDSLOperator([stencil], name='xDSLOperator') +xdslop = XDSLOperator([stencil], name='XDSLOperator') xdslop.apply(time=time_range.num-1, dt=model.critical_dt) -xdsl_output = u.copy() -print("XDSL norm:", norm(u)) -print(f"xdsl output norm: {norm(xdsl_output)}") -import pdb;pdb.set_trace() - -print("XDSL output linalg norm:", np.linalg.norm(u.data[0])) -print("XDSL output linalg norm:", np.linalg.norm(u.data[1])) -print("XDSL output linalg norm:", np.linalg.norm(u.data[2])) +print("XDSL output norm 0:", np.linalg.norm(u.data[0]), "vs:", np.linalg.norm(ub.data[0])) +print("XDSL output norm 1:", np.linalg.norm(u.data[1]), "vs:", np.linalg.norm(ub.data[1])) +print("XDSL output norm 2:", np.linalg.norm(u.data[2]), "vs:", np.linalg.norm(ub.data[2])) diff --git a/fast/wave3d_b.py b/fast/wave3d_b.py new file mode 100644 index 0000000000..3e3081f597 --- /dev/null +++ b/fast/wave3d_b.py @@ -0,0 +1,128 @@ +# Based on the implementation of the Devito acoustic example implementation +# Not using Devito's source injection abstraction +import sys +import numpy as np + +from devito import (TimeFunction, Eq, Operator, solve, norm, + XDSLOperator, configuration, Grid) +from examples.seismic import RickerSource +from examples.seismic import Model, TimeAxis, plot_image +from fast.bench_utils import plot_3dfunc +from devito.tools import as_tuple + +import argparse +np.set_printoptions(threshold=np.inf) + + +parser = argparse.ArgumentParser(description='Process arguments.') + +parser.add_argument("-d", "--shape", default=(16, 16, 16), type=int, nargs="+", + help="Number of grid points along each axis") +parser.add_argument("-so", "--space_order", default=4, + type=int, help="Space order of the simulation") +parser.add_argument("-to", "--time_order", default=2, + type=int, help="Time order of the simulation") +parser.add_argument("-nt", "--nt", default=20, + type=int, help="Simulation time in millisecond") +parser.add_argument("-bls", "--blevels", default=1, type=int, nargs="+", + help="Block levels") +parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot2D") +parser.add_argument("-devito", "--devito", default=False, type=bool, help="Devito run") +parser.add_argument("-xdsl", "--xdsl", default=False, type=bool, help="xDSL run") +args = parser.parse_args() + + +mpiconf = configuration['mpi'] + +# Define a physical size +# nx, ny, nz = args.shape +nt = args.nt +so = args.space_order + +shape = (args.shape) # Number of grid point (nx, ny, nz) +shape_str = '_'.join(str(item) for item in shape) +spacing = as_tuple(10.0 for _ in range(len(shape))) # Grid spacing in m. The domain size is now 1km by 1km +origin = as_tuple(0.0 for _ in range(len(shape))) # What is the location of the top left corner. +domain_size = tuple((d-1) * s for d, s in zip(shape, spacing)) +extent = np.load("so%s_grid_extent%s.npz" % (so, shape_str))['arr_0'] +grid = Grid(shape=shape, extent=as_tuple(extent)) + +# With the velocity and model size defined, we can create the seismic model that +# encapsulates this properties. We also define the size of the absorbing layer as +# 10 grid points +so = args.space_order +to = args.time_order + +t0 = 0. # Simulation starts a t=0 +tn = nt # Simulation last 1 second (1000 ms) + +# Define the wavefield with the size of the model and the time dimension +u = TimeFunction(name="u", grid=grid, time_order=to, space_order=so) +# Another one to clone data +u2 = TimeFunction(name="u", grid=grid, time_order=to, space_order=so) +ub = TimeFunction(name="ub", grid=grid, time_order=to, space_order=so) + +# We can now write the PDE +# pde = model.m * u.dt2 - u.laplace + model.damp * u.dt +# import pdb;pdb.set_trace() +pde = u.dt2 - u.laplace + +stencil = Eq(u.forward, solve(pde, u.forward)) + +# print("Init Devito linalg norm 0 :", np.linalg.norm(u.data[0])) +# print("Init Devito linalg norm 1 :", np.linalg.norm(u.data[1])) +# print("Init Devito linalg norm 2 :", np.linalg.norm(u.data[2])) +# print("Norm of initial data:", norm(u)) + +configuration['mpi'] = 0 +u2.data[:] = u.data[:] +configuration['mpi'] = mpiconf + +u.data[:] = np.load("so%s_wave_dat%s.npz" % (so, shape_str), allow_pickle=True)['arr_0'] +dt = np.load("so%s_critical_dt%s.npy" % (so, shape_str), allow_pickle=True) + +# np.save("critical_dt%s.npy" % shape_str, model.critical_dt, allow_pickle=True) +# np.save("wave_dat%s.npy" % shape_str, u.data[:], allow_pickle=True) + +if len(shape) == 3 and args.plot: + plot_3dfunc(u) + +print("Init norm:", np.linalg.norm(u.data[:])) +# print("Init linalg norm:", np.linalg.norm(u.data[0])) +# print("Init linalg norm:", np.linalg.norm(u.data[1])) +# print("Init linalg norm:", np.linalg.norm(u.data[2])) + + +if args.devito: + # Run more with no sources now (Not supported in xdsl) + # op1 = Operator([stencil], name='DevitoOperator', subs=grid.spacing_map) + op1 = Operator([stencil], name='DevitoOperator') + op1.apply(time=nt, dt=dt) + + configuration['mpi'] = 0 + ub.data[:] = u.data[:] + configuration['mpi'] = mpiconf + + if len(shape) == 3 and args.plot: + plot_3dfunc(u) + + print("Devito norm:", norm(u)) + # print("Devito linalg norm 0:", np.linalg.norm(u.data[0])) + # print("Devito linalg norm 1:", np.linalg.norm(u.data[1])) + # print("Devito linalg norm 2:", np.linalg.norm(u.data[2])) + + +if args.xdsl: + + # Run more with no sources now (Not supported in xdsl) + xdslop = XDSLOperator([stencil], name='xDSLOperator') + xdslop.apply(time=nt, dt=dt) + + if len(shape) == 3 and args.plot: + plot_3dfunc(u) + + print("XDSL norm:", norm(u)) + + # print("XDSL output norm 0:", np.linalg.norm(u.data[0])) + # print("XDSL output norm 1:", np.linalg.norm(u.data[1])) + # print("XDSL output norm 2:", np.linalg.norm(u.data[2])) diff --git a/fast/wave_dat2.npy b/fast/wave_dat2.npy new file mode 100644 index 0000000000..22e536f687 Binary files /dev/null and b/fast/wave_dat2.npy differ