Skip to content

Commit

Permalink
Merge pull request #24 from xdslproject/bench_edits-2
Browse files Browse the repository at this point in the history
mpi: Init effort for serial modelling on wave operator
  • Loading branch information
georgebisbas authored Aug 9, 2023
2 parents f2604c6 + 6cfe569 commit 80c0d31
Show file tree
Hide file tree
Showing 18 changed files with 742 additions and 615 deletions.
8 changes: 4 additions & 4 deletions devito/ir/ietxdsl/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def _ensure_same_type(self, *vals: SSAValue):
new_vals.append(val)
continue
# insert an integer to float cast op
conv = arith.SIToFPOp.get(val, builtin.f32)
conv = arith.SIToFPOp(val, builtin.f32)
self.block.add_op(conv)
new_vals.append(conv.result)
return new_vals
Expand Down Expand Up @@ -340,16 +340,16 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
self.seen_ops.add(op)

rewriter.insert_op_at_start([
t0 := func.Call.get('timer_start', [], [builtin.f64])
t0 := func.Call('timer_start', [], [builtin.f64])
], op.body.block)

ret = op.get_return_op()
assert ret is not None

rewriter.insert_op_before([
timers := iet_ssa.LoadSymbolic.get('timers', llvm.LLVMPointerType.typed(builtin.f64)),
t1 := func.Call.get('timer_end', [t0], [builtin.f64]),
llvm.StoreOp.get(t1, timers),
t1 := func.Call('timer_end', [t0], [builtin.f64]),
llvm.StoreOp(t1, timers),
], ret)

rewriter.insert_op_after_matched_op([
Expand Down
13 changes: 6 additions & 7 deletions devito/ir/ietxdsl/ietxdsl_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
# XDSL specific imports
from xdsl.irdl import AnyOf, Operation, SSAValue
from xdsl.dialects.builtin import (ContainerOf, Float16Type, Float32Type,
Float64Type, Builtin, i32, f32)
Float64Type, i32, f32)

from xdsl.dialects.arith import Muli, Addi
from devito.ir.ietxdsl import iet_ssa

from xdsl.dialects import memref, arith, builtin, llvm
from xdsl.dialects import memref, arith, builtin
from xdsl.dialects.experimental import math

import devito.types
Expand Down Expand Up @@ -74,7 +73,7 @@ def print_calls(cgen, calldefs):
print("Call not translated in calldefs")
return

call = Call.get(call_name, C_names, C_typenames, C_typeqs, prefix, retval)
call = Call(call_name, C_names, C_typenames, C_typeqs, prefix, retval)

cgen.printCall(call, True)

Expand Down Expand Up @@ -180,10 +179,10 @@ def add_to_block(expr, arg_by_expr: dict[Any, Operation], result):
# reconcile differences

if isinstance(rhs.typ, builtin.IntegerType):
rhs = arith.SIToFPOp.get(rhs, lhs.typ)
rhs = arith.SIToFPOp(rhs, lhs.typ)
result.append(rhs)
else:
lhs = arith.SIToFPOp.get(lhs, rhs.typ)
lhs = arith.SIToFPOp(lhs, rhs.typ)
result.append(lhs)


Expand Down Expand Up @@ -426,7 +425,7 @@ def myVisit(node, block: Block, ssa_vals={}):
print(f"Call {node.name} instance translated as comment")
return

call = Call.get(call_name, C_names, C_typenames, C_typeqs, prefix, retval)
call = Call(call_name, C_names, C_typenames, C_typeqs, prefix, retval)
block.add_ops([call])

print(f"Call {node.name} translated")
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/ietxdsl/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def match_and_rewrite(self, op: memref.Store, rewriter: PatternRewriter,
ssa_indices=[idx],
result_type=llvm.LLVMPointerType.typed(op.memref.memref.element_type)
),
store := llvm.StoreOp.get(op.value, gep),
store := llvm.StoreOp(op.value, gep),
],
[],
)
Expand Down
4 changes: 3 additions & 1 deletion devito/mpi/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,9 @@ def __init__(self, shape, dimensions, input_comm=None, topology=None):
# mpi4py takes care of that when the object gets out of scope
self._input_comm = (input_comm or MPI.COMM_WORLD).Clone()

topology = ('*', '*', 1)
if len(shape) == 3:
topology = ('*', '*', 1)

if topology is None:
# `MPI.Compute_dims` sets the dimension sizes to be as close to each other
# as possible, using an appropriate divisibility algorithm. Thus, in 3D:
Expand Down
17 changes: 11 additions & 6 deletions devito/operator/xdsl_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@
# gpu-launch-sink-index-computations seemed to have no impact
MLIR_GPU_PIPELINE = '"builtin.module(test-math-algebraic-simplification,scf-parallel-loop-tiling{parallel-loop-tile-sizes=128,1,1},func.func(gpu-map-parallel-loops),convert-parallel-loops-to-gpu,fold-memref-alias-ops,expand-strided-metadata,lower-affine,gpu-kernel-outlining,canonicalize,cse,convert-arith-to-llvm{index-bitwidth=64},finalize-memref-to-llvm{index-bitwidth=64},convert-scf-to-cf,convert-cf-to-llvm{index-bitwidth=64},canonicalize,cse,gpu.module(convert-gpu-to-nvvm,reconcile-unrealized-casts,canonicalize,gpu-to-cubin),gpu-to-llvm,canonicalize,cse)"'

XDSL_CPU_PIPELINE = "stencil-shape-inference,convert-stencil-to-ll-mlir,printf-to-llvm"
XDSL_GPU_PIPELINE = "stencil-shape-inference,convert-stencil-to-ll-mlir{target=gpu},printf-to-llvm"
XDSL_MPI_PIPELINE = lambda decomp: f'"dmp-decompose-2d{decomp},canonicalize-dmp,convert-stencil-to-ll-mlir,dmp-to-mpi{{mpi_init=false}},lower-mpi,printf-to-llvm"'
XDSL_CPU_PIPELINE = lambda nb_tiled_dims: f'"stencil-shape-inference,convert-stencil-to-ll-mlir{{tile-sizes={",".join(["64"]*nb_tiled_dims)}}},printf-to-llvm"'
XDSL_GPU_PIPELINE = '"stencil-shape-inference,convert-stencil-to-ll-mlir{target=gpu},printf-to-llvm"'
XDSL_MPI_PIPELINE = lambda decomp, nb_tiled_dims: f'"dmp-decompose-2d{decomp},canonicalize-dmp,convert-stencil-to-ll-mlir{{tile-sizes={",".join(["64"]*nb_tiled_dims)}}},dmp-to-mpi{{mpi_init=false}},lower-mpi,printf-to-llvm"'


class XDSLOperator(Operator):
Expand Down Expand Up @@ -85,7 +85,10 @@ def _make_interop_o(self):
@property
def mpi_shape(self) -> tuple:
    """Return the MPI process topology and the rank of this process.

    Both values are read from the distributor attached to the grid of the
    first function of this operator.

    Returns
    -------
    tuple
        ``(topology, rank)`` where ``topology`` is the distributor topology
        with its first two entries swapped and ``rank`` is ``dist.myrank``.
    """
    dist = self.functions[0].grid.distributor
    topology = dist.topology
    # Temporary fix: swap dim 0 and 1 of the topology because dmp.grid is
    # row major and not column major. Any remaining entries keep their order.
    return (topology[1], topology[0], *topology[2:]), dist.myrank

def _jit_compile(self):
"""
Expand Down Expand Up @@ -114,7 +117,9 @@ def _jit_compile(self):
Printer(stream=module_str).print(self._module)
module_str = module_str.getvalue()

xdsl_pipeline = XDSL_CPU_PIPELINE
to_tile = len(list(filter(lambda s : str(s) in ["x", "y", "z"], self.dimensions)))-1

xdsl_pipeline = XDSL_CPU_PIPELINE(to_tile)
mlir_pipeline = MLIR_CPU_PIPELINE

if is_omp:
Expand All @@ -126,7 +131,7 @@ def _jit_compile(self):
# reduce the domain of the computation (as devito has already done that for us)
slices = ','.join(str(x) for x in shape)
decomp = f"{{strategy=2d-grid slices={slices} restrict_domain=false}}"
xdsl_pipeline = XDSL_MPI_PIPELINE(decomp)
xdsl_pipeline = XDSL_MPI_PIPELINE(decomp, to_tile)
elif is_gpu:
xdsl_pipeline = XDSL_GPU_PIPELINE
mlir_pipeline = MLIR_GPU_PIPELINE
Expand Down
25 changes: 25 additions & 0 deletions fast/bench_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import matplotlib.pyplot as plt
import numpy as np
from examples.seismic import plot_image

__all__ = ['plot_2dfunc', 'plot_3dfunc']


def plot_2dfunc(u):
    """Plot the first two leading slices of ``u.data`` as 2D images.

    Each of ``u.data[0]`` and ``u.data[1]`` is handed, in that order, to
    devito's ``plot_image`` helper with the ``'seismic'`` colormap.

    Parameters
    ----------
    u : object exposing a ``data`` array
        Presumably a 2D devito ``TimeFunction`` whose leading axis indexes
        the time buffers -- TODO confirm against callers.
    """
    for index in (0, 1):
        plot_image(u.data[index], cmap='seismic')


def plot_3dfunc(u):
    """Plot orthogonal slices of a 3D field on a pyvista structured grid.

    The values of ``u.data[0, :, :, :]`` are attached as cell data to a
    ``pv.ImageData`` grid with unit spacing and origin ``(0, 0, 0)``; the
    orthogonal slices of that grid are then plotted with the ``viridis``
    colormap.

    Parameters
    ----------
    u : object exposing a 4D ``data`` array
        Presumably a 3D devito ``TimeFunction`` whose leading axis indexes
        the time buffers -- TODO confirm against callers.
    """
    # pyvista is imported lazily, so it is only required when plotting in 3D
    import pyvista as pv

    colormap = plt.colormaps["viridis"]
    field = u.data[0, :, :, :]

    mesh = pv.ImageData()
    # The values are stored per cell, hence one more grid point than values
    # along every axis
    mesh.dimensions = np.array(field.shape) + 1
    mesh.spacing = (1, 1, 1)
    mesh.origin = (0, 0, 0)  # The bottom left corner of the data set
    # Flatten in Fortran order, i.e. with the first axis varying fastest
    mesh.cell_data["values"] = field.flatten(order="F")

    mesh.slice_orthogonal().plot(cmap=colormap)
33 changes: 17 additions & 16 deletions fast/diffusion_2D_wBCs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import numpy as np

from devito import Grid, TimeFunction, Eq, solve, Operator, Constant, norm, XDSLOperator
from examples.seismic import plot_image
from examples.cfd import init_hat
from fast.bench_utils import plot_2dfunc

parser = argparse.ArgumentParser(description='Process arguments.')

Expand All @@ -21,7 +21,9 @@
type=int, help="Simulation time in millisecond")
parser.add_argument("-bls", "--blevels", default=2, type=int, nargs="+",
help="Block levels")
parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot3D")
parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot2D")
parser.add_argument("-devito", "--devito", default=False, type=bool, help="Devito run")
parser.add_argument("-xdsl", "--xdsl", default=False, type=bool, help="xDSL run")
args = parser.parse_args()

# Some variable declarations
Expand All @@ -41,29 +43,28 @@
grid = Grid(shape=(nx, ny), extent=(2., 2.))
u = TimeFunction(name='u', grid=grid, space_order=so)

# Reset our data field and ICs
init_hat(field=u.data[0], dx=dx, dy=dy, value=1.)

a = Constant(name='a')
# Create an equation with second-order derivatives
# eq = Eq(u.dt, a * u.laplace, subdomain=grid.interior)
eq = Eq(u.dt, a * u.laplace)
stencil = solve(eq, u.forward)
eq_stencil = Eq(u.forward, stencil)

# Create boundary condition expressions
x, y = grid.dimensions
t = grid.stepping_dim
# Reset our data field and ICs
init_hat(field=u.data[0], dx=dx, dy=dy, value=1.)

initdata = u.data[:]
op = Operator([eq_stencil], name='DevitoOperator')
op.apply(time=nt, dt=dt, a=nu)
if args.devito:
op = Operator([eq_stencil], name='DevitoOperator')
op.apply(time=nt, dt=dt, a=nu)
print("Devito Field norm is:", norm(u))

print("Devito Field norm is:", norm(u))
if args.plot:
plot_2dfunc(u)

# Reset data and run XDSLOperator
# Reset data
init_hat(field=u.data[0], dx=dx, dy=dy, value=1.)
xdslop = XDSLOperator([eq_stencil], name='XDSLOperator')
xdslop.apply(time=nt, dt=dt, a=nu)

print("XDSL Field norm is:", norm(u))
if args.xdsl:
xdslop = XDSLOperator([eq_stencil], name='XDSLOperator')
xdslop.apply(time=nt, dt=dt, a=nu)
print("XDSL Field norm is:", norm(u))
72 changes: 25 additions & 47 deletions fast/diffusion_3D_wBCs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import argparse
import numpy as np

from devito import Grid, TimeFunction, Eq, solve, Operator, Constant, norm, XDSLOperator
from devito import (Grid, TimeFunction, Eq, solve, Operator, Constant,
norm, XDSLOperator)
from fast.bench_utils import plot_3dfunc

parser = argparse.ArgumentParser(description='Process arguments.')

Expand All @@ -20,27 +22,10 @@
parser.add_argument("-bls", "--blevels", default=2, type=int, nargs="+",
help="Block levels")
parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot3D")
parser.add_argument("-devito", "--devito", default=False, type=bool, help="Devito run")
parser.add_argument("-xdsl", "--xdsl", default=False, type=bool, help="xDSL run")
args = parser.parse_args()


def plot_3dfunc(u):
    """Plot orthogonal slices of a 3D field on a pyvista structured grid.

    The values of ``u.data[0, :, :, :]`` are attached as cell data to a
    ``pv.ImageData`` grid with unit spacing and origin ``(0, 0, 0)``; the
    orthogonal slices of that grid are then plotted with the ``viridis``
    colormap.

    Parameters
    ----------
    u : object exposing a 4D ``data`` array
        Presumably a 3D devito ``TimeFunction`` whose leading axis indexes
        the time buffers -- TODO confirm against callers.
    """
    # Plotting dependencies are imported lazily, so they are only required
    # when a plot is actually requested
    import matplotlib.pyplot as plt
    import pyvista as pv

    # `plt.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
    # the colormap registry is the supported way to look a colormap up.
    cmap = plt.colormaps["viridis"]
    values = u.data[0, :, :, :]

    vistagrid = pv.ImageData()
    # The values are stored per cell, hence one more grid point than values
    # along every axis
    vistagrid.dimensions = np.array(values.shape) + 1
    vistagrid.spacing = (1, 1, 1)
    vistagrid.origin = (0, 0, 0)  # The bottom left corner of the data set
    # Flatten in Fortran order, i.e. with the first axis varying fastest
    vistagrid.cell_data["values"] = values.flatten(order="F")

    vistaslices = vistagrid.slice_orthogonal()
    vistaslices.plot(cmap=cmap)


# Some variable declarations
nx, ny, nz = args.shape
nt = args.nt
Expand All @@ -59,39 +44,32 @@ def plot_3dfunc(u):

grid = Grid(shape=(nx, ny, nz), extent=(2., 2., 2.))
u = TimeFunction(name='u', grid=grid, space_order=so)
# init_hat(field=u.data[0], dx=dx, dy=dy, value=2.)
u.data[:, :, :, :] = 0
u.data[:, :, :, int(nz/2)] = 1

a = Constant(name='a')
# Create an equation with second-order derivatives
eq = Eq(u.dt, a * u.laplace)

stencil = solve(eq, u.forward)
eq_stencil = Eq(u.forward, stencil)

# Create boundary condition expressions
x, y, z = grid.dimensions
t = grid.stepping_dim

print(eq_stencil)

# Create Operator
op = Operator([eq_stencil], name='DevitoOperator')
# Apply the operator for a number of timesteps
op.apply(time=nt, dt=dt, a=nu)
print("Devito Field norm is:", norm(u))

# Reset field
u.data[:, :, :, :] = 0
u.data[:, :, :, int(nz/2)] = 1
xdslop = XDSLOperator([eq_stencil], name='xDSLOperator')
# Apply the xdsl operator for a number of timesteps
xdslop.apply(time=nt, dt=dt, a=nu)

if args.plot:
plot_3dfunc(u)

print("XDSL Field norm is:", norm(u))

# import pdb;pdb.set_trace()
if args.devito:
u.data[:, :, :, :] = 0
u.data[:, :, :, int(nz/2)] = 1
op = Operator([eq_stencil], name='DevitoOperator')
# Apply the operator for a number of timesteps
op.apply(time=nt, dt=dt, a=nu)
print("Devito Field norm is:", norm(u))
if args.plot:
plot_3dfunc(u)

if args.xdsl:
# Reset field
u.data[:, :, :, :] = 0
u.data[:, :, :, int(nz/2)] = 1
xdslop = XDSLOperator([eq_stencil], name='xDSLOperator')
# Apply the xdsl operator for a number of timesteps
xdslop.apply(time=nt, dt=dt, a=nu)
print("XDSL Field norm is:", norm(u))
if args.plot:
plot_3dfunc(u)
Loading

0 comments on commit 80c0d31

Please sign in to comment.