From a01aca4629aa30c213729e66c39ffc05f3105169 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Wed, 2 Aug 2023 01:32:40 +0100 Subject: [PATCH] Rather make a wrapper func, so we have a func just taking the buffers. --- devito/ir/ietxdsl/cluster_to_ssa.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/devito/ir/ietxdsl/cluster_to_ssa.py b/devito/ir/ietxdsl/cluster_to_ssa.py index 86bb434ef41..bc65408da05 100644 --- a/devito/ir/ietxdsl/cluster_to_ssa.py +++ b/devito/ir/ietxdsl/cluster_to_ssa.py @@ -328,18 +328,21 @@ def is_float(val: SSAValue): @dataclass class WrapFunctionWithTransfers(RewritePattern): func_name: str - seen_ops: set[func.FuncOp] = field(default_factory=set) + done: bool = field(default=False) @op_type_rewrite_pattern def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): - if op.sym_name.data != self.func_name or op in self.seen_ops: + if op.sym_name.data != self.func_name or self.done: return - self.seen_ops.add(op) + self.done = True + + op.sym_name = builtin.StringAttr("gpu_kernel") print("Doing GPU STUFF") # GPU STUFF - body = op.body.block - for arg in op.body.block.args: - print("ARG") + wrapper = func.FuncOp(self.func_name, op.function_type, Region(Block([func.Return()], arg_types=op.function_type.inputs))) + body = wrapper.body.block + wrapper.body.block.insert_op_before(func.Call("gpu_kernel", body.args, []), body.last_op) + for arg in wrapper.args: shapetype = arg.type if isinstance(shapetype, stencil.FieldType): memref_type = memref.MemRefType.from_element_type_and_shape(shapetype.get_element_type(), shapetype.get_shape()) @@ -353,6 +356,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): copy_out = gpu.MemcpyOp(source=alloc, destination=incast) dealloc = gpu.DeallocOp(alloc) body.insert_ops_before([copy_out, dealloc], body.ops.last) + rewriter.insert_op_after_matched_op(wrapper) @dataclass class MakeFunctionTimed(RewritePattern): """