From bacf1afa49934adef7745280ef5adb52026faf46 Mon Sep 17 00:00:00 2001 From: George Bisbas Date: Wed, 2 Aug 2023 18:55:51 +0300 Subject: [PATCH] mpi: Init effort for serial modelling on wave operator --- devito/ir/ietxdsl/cluster_to_ssa.py | 8 ++++---- devito/ir/ietxdsl/ietxdsl_functions.py | 13 ++++++------- devito/ir/ietxdsl/lowering.py | 2 +- devito/mpi/distributed.py | 4 +++- fast/wave2d.py | 20 ++++++++++++++------ fast/wave3d.py | 19 +++++++++---------- 6 files changed, 37 insertions(+), 29 deletions(-) diff --git a/devito/ir/ietxdsl/cluster_to_ssa.py b/devito/ir/ietxdsl/cluster_to_ssa.py index 2ffb0d1887..1cb9c347d3 100644 --- a/devito/ir/ietxdsl/cluster_to_ssa.py +++ b/devito/ir/ietxdsl/cluster_to_ssa.py @@ -272,7 +272,7 @@ def _ensure_same_type(self, *vals: SSAValue): new_vals.append(val) continue # insert an integer to float cast op - conv = arith.SIToFPOp.get(val, builtin.f32) + conv = arith.SIToFPOp(val, builtin.f32) self.block.add_op(conv) new_vals.append(conv.result) return new_vals @@ -340,7 +340,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): self.seen_ops.add(op) rewriter.insert_op_at_start([ - t0 := func.Call.get('timer_start', [], [builtin.f64]) + t0 := func.Call('timer_start', [], [builtin.f64]) ], op.body.block) ret = op.get_return_op() @@ -348,8 +348,8 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): rewriter.insert_op_before([ timers := iet_ssa.LoadSymbolic.get('timers', llvm.LLVMPointerType.typed(builtin.f64)), - t1 := func.Call.get('timer_end', [t0], [builtin.f64]), - llvm.StoreOp.get(t1, timers), + t1 := func.Call('timer_end', [t0], [builtin.f64]), + llvm.StoreOp(t1, timers), ], ret) rewriter.insert_op_after_matched_op([ diff --git a/devito/ir/ietxdsl/ietxdsl_functions.py b/devito/ir/ietxdsl/ietxdsl_functions.py index 20a4d105d8..b4a9498fb1 100644 --- a/devito/ir/ietxdsl/ietxdsl_functions.py +++ b/devito/ir/ietxdsl/ietxdsl_functions.py @@ -20,12 +20,11 @@ # XDSL specific imports from xdsl.irdl import AnyOf, Operation, SSAValue from xdsl.dialects.builtin import (ContainerOf, Float16Type, Float32Type, - Float64Type, Builtin, i32, f32) + Float64Type, i32, f32) -from xdsl.dialects.arith import Muli, Addi from devito.ir.ietxdsl import iet_ssa -from xdsl.dialects import memref, arith, builtin, llvm +from xdsl.dialects import memref, arith, builtin from xdsl.dialects.experimental import math import devito.types @@ -74,7 +73,7 @@ def print_calls(cgen, calldefs): print("Call not translated in calldefs") return - call = Call.get(call_name, C_names, C_typenames, C_typeqs, prefix, retval) + call = Call(call_name, C_names, C_typenames, C_typeqs, prefix, retval) cgen.printCall(call, True) @@ -180,10 +179,10 @@ def add_to_block(expr, arg_by_expr: dict[Any, Operation], result): # reconcile differences if isinstance(rhs.typ, builtin.IntegerType): - rhs = arith.SIToFPOp.get(rhs, lhs.typ) + rhs = arith.SIToFPOp(rhs, lhs.typ) result.append(rhs) else: - lhs = arith.SIToFPOp.get(lhs, rhs.typ) + lhs = arith.SIToFPOp(lhs, rhs.typ) result.append(lhs) @@ -426,7 +425,7 @@ def myVisit(node, block: Block, ssa_vals={}): print(f"Call {node.name} instance translated as comment") return - call = Call.get(call_name, C_names, C_typenames, C_typeqs, prefix, retval) + call = Call(call_name, C_names, C_typenames, C_typeqs, prefix, retval) block.add_ops([call]) print(f"Call {node.name} translated") diff --git a/devito/ir/ietxdsl/lowering.py b/devito/ir/ietxdsl/lowering.py index 515fe2f7a1..8ccf6b6d25 100644 --- a/devito/ir/ietxdsl/lowering.py +++ b/devito/ir/ietxdsl/lowering.py @@ -371,7 +371,7 @@ def match_and_rewrite(self, op: memref.Store, rewriter: PatternRewriter, ssa_indices=[idx], result_type=llvm.LLVMPointerType.typed(op.memref.memref.element_type) ), - store := llvm.StoreOp.get(op.value, gep), + store := llvm.StoreOp(op.value, gep), ], [], ) diff --git a/devito/mpi/distributed.py b/devito/mpi/distributed.py index 5b0c890ebe..50577659cb 100644 --- a/devito/mpi/distributed.py +++ b/devito/mpi/distributed.py @@ -196,7 +196,9 @@ def __init__(self, shape, dimensions, input_comm=None, topology=None): # mpi4py takes care of that when the object gets out of scope self._input_comm = (input_comm or MPI.COMM_WORLD).Clone() - topology = ('*', '*', 1) + if len(shape) == 3: + topology = ('*', '*', 1) + if topology is None: # `MPI.Compute_dims` sets the dimension sizes to be as close to each other # as possible, using an appropriate divisibility algorithm. Thus, in 3D: diff --git a/fast/wave2d.py b/fast/wave2d.py index 409bbfb33b..4d5a7eed5b 100644 --- a/fast/wave2d.py +++ b/fast/wave2d.py @@ -1,8 +1,11 @@ # Based on the implementation of the Devito acoustic example implementation # Not using Devito's source injection abstraction import sys + import numpy as np -from devito import TimeFunction, Eq, Operator, solve, norm, XDSLOperator + +from devito import (TimeFunction, Eq, Operator, solve, norm, + XDSLOperator, configuration) from examples.seismic import RickerSource from examples.seismic import Model, TimeAxis @@ -92,7 +95,7 @@ def plot_2dfunc(u): # Define the wavefield with the size of the model and the time dimension u = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so) - +# Another one to clone data u2 = TimeFunction(name="u", grid=model.grid, time_order=to, space_order=so) # We can now write the PDE @@ -101,18 +104,19 @@ def plot_2dfunc(u): pde = u.dt2 - u.laplace # The PDE representation is as on paper -pde +# pde stencil = Eq(u.forward, solve(pde, u.forward)) -stencil +# stencil # Finally we define the source injection and receiver read function to generate # the corresponding code print(time_range) -print("Init norm:", norm(u)) +print("Init norm:", np.linalg.norm(u.data[:])) src_term = src.inject(field=u.forward, expr=src * dt**2 / model.m) op0 = Operator([stencil] + src_term, subs=model.spacing_map, name='SourceDevitoOperator') + # Run with source and plot op0.apply(time=time_range.num-1, dt=model.critical_dt) @@ -125,8 +129,10 @@ def plot_2dfunc(u): print("Init Devito linalg norm 2 :", np.linalg.norm(u.data[2])) print("Norm of initial data:", norm(u)) -# import pdb;pdb.set_trace() + +configuration['mpi'] = 0 u2.data[:] = u.data[:] +configuration['mpi'] = 'basic' # Run more with no sources now (Not supported in xdsl) op1 = Operator([stencil], name='DevitoOperator') @@ -146,7 +152,9 @@ def plot_2dfunc(u): # Reset initial data +configuration['mpi'] = 0 u.data[:] = u2.data[:] +configuration['mpi'] = 'basic' #v[:, ..., :] = 1 diff --git a/fast/wave3d.py b/fast/wave3d.py index abf70d3cdc..1497d3bea2 100644 --- a/fast/wave3d.py +++ b/fast/wave3d.py @@ -2,7 +2,8 @@ # Not using Devito's source injection abstraction import sys import numpy as np -from devito import TimeFunction, Eq, Operator, solve, norm, XDSLOperator +from devito import (TimeFunction, Eq, Operator, solve, norm, + XDSLOperator, configuration) from examples.seismic import RickerSource from examples.seismic import Model, TimeAxis @@ -124,9 +125,10 @@ def plot_3dfunc(u): print("Init linalg norm 1 :", np.linalg.norm(u.data[1])) print("Init linalg norm 2 :", np.linalg.norm(u.data[2])) -print("Norm of initial data:", norm(u)) -import pdb;pdb.set_trace() +print("Norm of initial data:", np.linalg.norm(u.data[:])) +configuration['mpi'] = 0 u2.data[:] = u.data[:] +configuration['mpi'] = 'basic' # Run more with no sources now (Not supported in xdsl) op1 = Operator([stencil], name='DevitoOperator') @@ -142,15 +144,13 @@ def plot_3dfunc(u): print("Devito linalg norm 1:", np.linalg.norm(u.data[1])) print("Devito linalg norm 2:", np.linalg.norm(u.data[2])) -import pdb;pdb.set_trace() - - # Reset initial data +configuration['mpi'] = 0 u.data[:] = u2.data[:] +configuration['mpi'] = 'basic' #v[:, ..., :] = 1 - -print("Reinitialise data: Devito norm:", norm(u)) +print("Reinitialise data: Devito norm:", np.linalg.norm(u.data[:])) print("Init XDSL linalg norm:", np.linalg.norm(u.data[0])) print("Init XDSL linalg norm:", np.linalg.norm(u.data[1])) print("Init XDSL linalg norm:", np.linalg.norm(u.data[2])) @@ -160,10 +160,9 @@ def plot_3dfunc(u): xdslop.apply(time=time_range.num-1, dt=model.critical_dt) xdsl_output = u.copy() + print("XDSL norm:", norm(u)) print(f"xdsl output norm: {norm(xdsl_output)}") -import pdb;pdb.set_trace() - print("XDSL output linalg norm:", np.linalg.norm(u.data[0])) print("XDSL output linalg norm:", np.linalg.norm(u.data[1])) print("XDSL output linalg norm:", np.linalg.norm(u.data[2]))