Skip to content

Commit

Permalink
Rather make a wrapper func, so we have a func just taking the buffers.
Browse files Browse the repository at this point in the history
  • Loading branch information
PapyChacal committed Aug 2, 2023
1 parent 4713098 commit a01aca4
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions devito/ir/ietxdsl/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,18 +328,21 @@ def is_float(val: SSAValue):
@dataclass
class WrapFunctionWithTransfers(RewritePattern):
func_name: str
seen_ops: set[func.FuncOp] = field(default_factory=set)
done: bool = field(default=False)

@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
if op.sym_name.data != self.func_name or op in self.seen_ops:
if op.sym_name.data != self.func_name or self.done:
return
self.seen_ops.add(op)
self.done = True

op.sym_name = builtin.StringAttr("gpu_kernel")
print("Doing GPU STUFF")
# GPU STUFF
body = op.body.block
for arg in op.body.block.args:
print("ARG")
wrapper = func.FuncOp(self.func_name, op.function_type, Region(Block([func.Return()], arg_types=op.function_type.inputs)))
body = wrapper.body.block
wrapper.body.block.insert_op_before(func.Call("gpu_kernel", body.args, []), body.last_op)
for arg in wrapper.args:
shapetype = arg.type
if isinstance(shapetype, stencil.FieldType):
memref_type = memref.MemRefType.from_element_type_and_shape(shapetype.get_element_type(), shapetype.get_shape())
Expand All @@ -353,6 +356,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
copy_out = gpu.MemcpyOp(source=alloc, destination=incast)
dealloc = gpu.DeallocOp(alloc)
body.insert_ops_before([copy_out, dealloc], body.ops.last)
rewriter.insert_op_after_matched_op(wrapper)
@dataclass
class MakeFunctionTimed(RewritePattern):
"""
Expand Down

0 comments on commit a01aca4

Please sign in to comment.