Skip to content

Commit

Permalink
mpi: Init effort for serial modelling on wave operator
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Aug 2, 2023
1 parent f2604c6 commit bacf1af
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 29 deletions.
8 changes: 4 additions & 4 deletions devito/ir/ietxdsl/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def _ensure_same_type(self, *vals: SSAValue):
new_vals.append(val)
continue
# insert an integer to float cast op
conv = arith.SIToFPOp.get(val, builtin.f32)
conv = arith.SIToFPOp(val, builtin.f32)
self.block.add_op(conv)
new_vals.append(conv.result)
return new_vals
Expand Down Expand Up @@ -340,16 +340,16 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
self.seen_ops.add(op)

rewriter.insert_op_at_start([
t0 := func.Call.get('timer_start', [], [builtin.f64])
t0 := func.Call('timer_start', [], [builtin.f64])
], op.body.block)

ret = op.get_return_op()
assert ret is not None

rewriter.insert_op_before([
timers := iet_ssa.LoadSymbolic.get('timers', llvm.LLVMPointerType.typed(builtin.f64)),
t1 := func.Call.get('timer_end', [t0], [builtin.f64]),
llvm.StoreOp.get(t1, timers),
t1 := func.Call('timer_end', [t0], [builtin.f64]),
llvm.StoreOp(t1, timers),
], ret)

rewriter.insert_op_after_matched_op([
Expand Down
13 changes: 6 additions & 7 deletions devito/ir/ietxdsl/ietxdsl_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
# XDSL specific imports
from xdsl.irdl import AnyOf, Operation, SSAValue
from xdsl.dialects.builtin import (ContainerOf, Float16Type, Float32Type,
Float64Type, Builtin, i32, f32)
Float64Type, i32, f32)

from xdsl.dialects.arith import Muli, Addi
from devito.ir.ietxdsl import iet_ssa

from xdsl.dialects import memref, arith, builtin, llvm
from xdsl.dialects import memref, arith, builtin
from xdsl.dialects.experimental import math

import devito.types
Expand Down Expand Up @@ -74,7 +73,7 @@ def print_calls(cgen, calldefs):
print("Call not translated in calldefs")
return

call = Call.get(call_name, C_names, C_typenames, C_typeqs, prefix, retval)
call = Call(call_name, C_names, C_typenames, C_typeqs, prefix, retval)

cgen.printCall(call, True)

Expand Down Expand Up @@ -180,10 +179,10 @@ def add_to_block(expr, arg_by_expr: dict[Any, Operation], result):
# reconcile differences

if isinstance(rhs.typ, builtin.IntegerType):
rhs = arith.SIToFPOp.get(rhs, lhs.typ)
rhs = arith.SIToFPOp(rhs, lhs.typ)
result.append(rhs)
else:
lhs = arith.SIToFPOp.get(lhs, rhs.typ)
lhs = arith.SIToFPOp(lhs, rhs.typ)
result.append(lhs)


Expand Down Expand Up @@ -426,7 +425,7 @@ def myVisit(node, block: Block, ssa_vals={}):
print(f"Call {node.name} instance translated as comment")
return

call = Call.get(call_name, C_names, C_typenames, C_typeqs, prefix, retval)
call = Call(call_name, C_names, C_typenames, C_typeqs, prefix, retval)
block.add_ops([call])

print(f"Call {node.name} translated")
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/ietxdsl/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def match_and_rewrite(self, op: memref.Store, rewriter: PatternRewriter,
ssa_indices=[idx],
result_type=llvm.LLVMPointerType.typed(op.memref.memref.element_type)
),
store := llvm.StoreOp.get(op.value, gep),
store := llvm.StoreOp(op.value, gep),
],
[],
)
Expand Down
4 changes: 3 additions & 1 deletion devito/mpi/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,9 @@ def __init__(self, shape, dimensions, input_comm=None, topology=None):
# mpi4py takes care of that when the object gets out of scope
self._input_comm = (input_comm or MPI.COMM_WORLD).Clone()

topology = ('*', '*', 1)
if len(shape) == 3:
topology = ('*', '*', 1)

if topology is None:
# `MPI.Compute_dims` sets the dimension sizes to be as close to each other
# as possible, using an appropriate divisibility algorithm. Thus, in 3D:
Expand Down
20 changes: 14 additions & 6 deletions fast/wave2d.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# Based on the implementation of the Devito acoustic example implementation
# Not using Devito's source injection abstraction
import sys

import numpy as np
from devito import TimeFunction, Eq, Operator, solve, norm, XDSLOperator

from devito import (TimeFunction, Eq, Operator, solve, norm,
XDSLOperator, configuration)
from examples.seismic import RickerSource
from examples.seismic import Model, TimeAxis

Expand Down Expand Up @@ -92,7 +95,7 @@ def plot_2dfunc(u):

# Define the wavefield with the size of the model and the time dimension
u = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so)

# Another one to clone data
u2 = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so)

# We can now write the PDE
Expand All @@ -101,18 +104,19 @@ def plot_2dfunc(u):
pde = u.dt2 - u.laplace

# The PDE representation is as on paper
pde
# pde

stencil = Eq(u.forward, solve(pde, u.forward))
stencil
# stencil

# Finally we define the source injection and receiver read function to generate
# the corresponding code
print(time_range)

print("Init norm:", norm(u))
print("Init norm:", np.linalg.norm(u.data[:]))
src_term = src.inject(field=u.forward, expr=src * dt**2 / model.m)
op0 = Operator([stencil] + src_term, subs=model.spacing_map, name='SourceDevitoOperator')

# Run with source and plot
op0.apply(time=time_range.num-1, dt=model.critical_dt)

Expand All @@ -125,8 +129,10 @@ def plot_2dfunc(u):
print("Init Devito linalg norm 2 :", np.linalg.norm(u.data[2]))

print("Norm of initial data:", norm(u))
# import pdb;pdb.set_trace()

configuration['mpi'] = 0
u2.data[:] = u.data[:]
configuration['mpi'] = 'basic'

# Run more with no sources now (Not supported in xdsl)
op1 = Operator([stencil], name='DevitoOperator')
Expand All @@ -146,7 +152,9 @@ def plot_2dfunc(u):


# Reset initial data
configuration['mpi'] = 0
u.data[:] = u2.data[:]
configuration['mpi'] = 'basic'
#v[:, ..., :] = 1


Expand Down
19 changes: 9 additions & 10 deletions fast/wave3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# Not using Devito's source injection abstraction
import sys
import numpy as np
from devito import TimeFunction, Eq, Operator, solve, norm, XDSLOperator
from devito import (TimeFunction, Eq, Operator, solve, norm,
XDSLOperator, configuration)
from examples.seismic import RickerSource
from examples.seismic import Model, TimeAxis

Expand Down Expand Up @@ -124,9 +125,10 @@ def plot_3dfunc(u):
print("Init linalg norm 1 :", np.linalg.norm(u.data[1]))
print("Init linalg norm 2 :", np.linalg.norm(u.data[2]))

print("Norm of initial data:", norm(u))
import pdb;pdb.set_trace()
print("Norm of initial data:", np.linalg.norm(u.data[:]))
configuration['mpi'] = 0
u2.data[:] = u.data[:]
configuration['mpi'] = 'basic'

# Run more with no sources now (Not supported in xdsl)
op1 = Operator([stencil], name='DevitoOperator')
Expand All @@ -142,15 +144,13 @@ def plot_3dfunc(u):
print("Devito linalg norm 1:", np.linalg.norm(u.data[1]))
print("Devito linalg norm 2:", np.linalg.norm(u.data[2]))

import pdb;pdb.set_trace()


# Reset initial data
configuration['mpi'] = 0
u.data[:] = u2.data[:]
configuration['mpi'] = 'basic'
#v[:, ..., :] = 1


print("Reinitialise data: Devito norm:", norm(u))
print("Reinitialise data: Devito norm:", np.linalg.norm(u.data[:]))
print("Init XDSL linalg norm:", np.linalg.norm(u.data[0]))
print("Init XDSL linalg norm:", np.linalg.norm(u.data[1]))
print("Init XDSL linalg norm:", np.linalg.norm(u.data[2]))
Expand All @@ -160,10 +160,9 @@ def plot_3dfunc(u):
xdslop.apply(time=time_range.num-1, dt=model.critical_dt)

xdsl_output = u.copy()

print("XDSL norm:", norm(u))
print(f"xdsl output norm: {norm(xdsl_output)}")
import pdb;pdb.set_trace()

print("XDSL output linalg norm:", np.linalg.norm(u.data[0]))
print("XDSL output linalg norm:", np.linalg.norm(u.data[1]))
print("XDSL output linalg norm:", np.linalg.norm(u.data[2]))

0 comments on commit bacf1af

Please sign in to comment.