Skip to content

Commit

Permalink
bench: Generalize benchmarking scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Aug 4, 2023
1 parent ae3d586 commit 9a973cc
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 51 deletions.
8 changes: 8 additions & 0 deletions fast/bench_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from examples.seismic import plot_image

__all__ = ['plot_2dfunc']


def plot_2dfunc(u):
# Plot a 3D structured grid using pyvista
plot_image(u.data[0], cmap='seismic')
5 changes: 4 additions & 1 deletion fast/diffusion_2D_wBCs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import numpy as np

from devito import Grid, TimeFunction, Eq, solve, Operator, Constant, norm, XDSLOperator
from examples.seismic import plot_image
from examples.cfd import init_hat
from fast.bench_utils import plot_2dfunc

parser = argparse.ArgumentParser(description='Process arguments.')

Expand Down Expand Up @@ -58,6 +58,9 @@
op.apply(time=nt, dt=dt, a=nu)
print("Devito Field norm is:", norm(u))

if args.plot:
plot_2dfunc(u)

# Reset data
init_hat(field=u.data[0], dx=dx, dy=dy, value=1.)

Expand Down
90 changes: 40 additions & 50 deletions fast/wave2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from devito import (TimeFunction, Eq, Operator, solve, norm,
XDSLOperator, configuration)
from examples.seismic import RickerSource
from examples.seismic import Model, TimeAxis

from examples.seismic import Model, TimeAxis, plot_image
from fast.bench_utils import plot_2dfunc
from devito.tools import as_tuple

import argparse
Expand All @@ -27,26 +27,12 @@
parser.add_argument("-bls", "--blevels", default=1, type=int, nargs="+",
help="Block levels")
parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot2D")
parser.add_argument("-mode", "--mode", default='devito', type=str, help="Operator mode")
parser.add_argument("-devito", "--devito", default=False, type=bool, help="Devito run")
parser.add_argument("-xdsl", "--xdsl", default=False, type=bool, help="xDSL run")
args = parser.parse_args()


def plot_2dfunc(u):
# Plot a 3D structured grid using pyvista

import matplotlib.pyplot as plt
import pyvista as pv
cmap = plt.colormaps["viridis"]
values = u.data[0, :, :, :]
vistagrid = pv.UniformGrid()
vistagrid.dimensions = np.array(values.shape) + 1
vistagrid.spacing = (1, 1, 1)
vistagrid.origin = (0, 0, 0) # The bottom left corner of the data set
vistagrid.cell_data["values"] = values.flatten(order="F")
vistaslices = vistagrid.slice_orthogonal()
# vistagrid.plot(show_edges=True)
vistaslices.plot(cmap=cmap)

mpiconf = configuration['mpi']

# Define a physical size
# nx, ny, nz = args.shape
Expand Down Expand Up @@ -122,47 +108,51 @@ def plot_2dfunc(u):
if args.plot:
plot_2dfunc(u)

#print("Init Devito linalg norm 0 :", np.linalg.norm(u.data[0]))
#print("Init Devito linalg norm 1 :", np.linalg.norm(u.data[1]))
#print("Init Devito linalg norm 2 :", np.linalg.norm(u.data[2]))
# print("Init Devito linalg norm 0 :", np.linalg.norm(u.data[0]))
# print("Init Devito linalg norm 1 :", np.linalg.norm(u.data[1]))
# print("Init Devito linalg norm 2 :", np.linalg.norm(u.data[2]))
# print("Norm of initial data:", norm(u))

configuration['mpi'] = 0
u2.data[:] = u.data[:]
configuration['mpi'] = 'basic'
configuration['mpi'] = mpiconf

# Run more with no sources now (Not supported in xdsl)
op1 = Operator([stencil], name='DevitoOperator')
op1.apply(time=time_range.num-1, dt=model.critical_dt)

configuration['mpi'] = 0
ub.data[:] = u.data[:]
configuration['mpi'] = 'basic'
if args.devito:
# Run more with no sources now (Not supported in xdsl)
op1 = Operator([stencil], name='DevitoOperator')
op1.apply(time=time_range.num-1, dt=model.critical_dt)

#print("After Operator 1: Devito norm:", norm(u))
#print("Devito linalg norm 0:", np.linalg.norm(u.data[0]))
#print("Devito linalg norm 1:", np.linalg.norm(u.data[1]))
#print("Devito linalg norm 2:", np.linalg.norm(u.data[2]))
configuration['mpi'] = 0
ub.data[:] = u.data[:]
configuration['mpi'] = mpiconf

# import pdb;pdb.set_trace()
if len(shape) == 2 and args.plot:
plot_2dfunc(u)

# print("After Operator 1: Devito norm:", norm(u))
# print("Devito linalg norm 0:", np.linalg.norm(u.data[0]))
# print("Devito linalg norm 1:", np.linalg.norm(u.data[1]))
# print("Devito linalg norm 2:", np.linalg.norm(u.data[2]))

# Reset initial data
configuration['mpi'] = 0
u.data[:] = u2.data[:]
configuration['mpi'] = 'basic'
#v[:, ..., :] = 1

if args.xdsl:
# Reset initial data
configuration['mpi'] = 0
u.data[:] = u2.data[:]
configuration['mpi'] = mpiconf
# v[:, ..., :] = 1
# print("Reinitialise data: Devito norm:", norm(u))
# print("XDSL init linalg norm:", np.linalg.norm(u.data[0]))
# print("XDSL init linalg norm:", np.linalg.norm(u.data[1]))
# print("XDSL init linalg norm:", np.linalg.norm(u.data[2]))

#print("Reinitialise data: Devito norm:", norm(u))
#print("Init XDSL linalg norm:", np.linalg.norm(u.data[0]))
#print("Init XDSL linalg norm:", np.linalg.norm(u.data[1]))
#print("Init XDSL linalg norm:", np.linalg.norm(u.data[2]))
# Run more with no sources now (Not supported in xdsl)
xdslop = Operator([stencil], name='xDSLOperator')
xdslop.apply(time=time_range.num-1, dt=model.critical_dt)

# Run more with no sources now (Not supported in xdsl)
xdslop = Operator([stencil], name='xDSLOperator')
xdslop.apply(time=time_range.num-1, dt=model.critical_dt)
if len(shape) == 2 and args.plot:
plot_2dfunc(u)

print("XDSL output norm 0:", np.linalg.norm(u.data[0]), "vs:", np.linalg.norm(ub.data[0]))
print("XDSL output norm 1:", np.linalg.norm(u.data[1]), "vs:", np.linalg.norm(ub.data[1]))
print("XDSL output norm 2:", np.linalg.norm(u.data[2]), "vs:", np.linalg.norm(ub.data[2]))
print("XDSL output norm 0:", np.linalg.norm(u.data[0]), "vs:", np.linalg.norm(ub.data[0]))
print("XDSL output norm 1:", np.linalg.norm(u.data[1]), "vs:", np.linalg.norm(ub.data[1]))
print("XDSL output norm 2:", np.linalg.norm(u.data[2]), "vs:", np.linalg.norm(ub.data[2]))

0 comments on commit 9a973cc

Please sign in to comment.