Skip to content

Commit

Permalink
Adapt codegen.
Browse files Browse the repository at this point in the history
All time_buffers are swapped as buffer, but we add an iteration argument for the output buffer to trick the stencil dialect.
Instead of load from all non-output fields, now load from all input fields.
  • Loading branch information
PapyChacal committed Feb 25, 2024
1 parent 6fb104a commit 2ee3853
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions devito/ir/ietxdsl/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def _convert_eq(self, eq: LoweredEq):
output_time_offset = (eq.lhs.indices[step_dim] - step_dim) % eq.lhs.function.time_size
self.out_time_buffer = (output_function, output_time_offset)


# For each used function, define as many fields as its time_size
fields_types = [field_from_function(f) for (f, _) in self.time_buffers]
# Create a function with the fields as arguments
Expand All @@ -103,7 +102,7 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, output_indexed:I
# Handle Indexeds
if isinstance(node, Indexed):
space_offsets = [node.indices[d] - output_indexed.indices[d] for d in node.function.space_dimensions]
# import pdb; pdb.set_trace()

temp = self.apply_temps[(node.function, (node.indices[dim] - dim) % node.function.time_size)]
access = stencil.AccessOp.get(temp, space_offsets)
return access.res
Expand Down Expand Up @@ -166,10 +165,14 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, output_indexed:I
raise NotImplementedError(f"Unknown math: {node}", node)

def _build_step_body(self, dim: SteppingDimension, eq:LoweredEq) -> None:

input_indexeds = retrieve_indexed(eq.rhs)
input_time_buffers = OrderedSet(*((f.function, (f.indices[dim] - dim) % f.function.time_size) for f in input_indexeds))

loop_temps = {
(f, t): stencil.LoadOp.get(a).res
for (f, t), a in self.block_args.items()
if (f, t) != self.out_time_buffer
if (f, t) in input_time_buffers
}
for (f,t), a in loop_temps.items():
a.name_hint = f"{f.name}_t{t}_temp"
Expand All @@ -185,13 +188,13 @@ def _build_step_body(self, dim: SteppingDimension, eq:LoweredEq) -> None:
apply_arg.name_hint = apply_op.name_hint[:-4]+"_blk"

self.apply_temps = {k:v for k,v in zip(loop_temps.keys(), apply.region.block.args)}

with ImplicitBuilder(apply.region.block):
stencil.ReturnOp.get([self._visit_math_nodes(dim, eq.rhs, eq.lhs)])
# TODO Think about multiple outputs
stencil.StoreOp.get(
apply.res[0],
self.block_args[self.out_time_buffer],
self.out_block_arg,
stencil.IndexAttr.get(*([0] * len(shape))),
stencil.IndexAttr.get(*shape),
)
Expand All @@ -213,19 +216,26 @@ def _build_step_loop(
except:
raise ValueError("step must be int!")

iter_args = list(self.function_args.values())
iter_args = [self.function_args[self.out_time_buffer], *self.function_args.values()]
# Create the for loop
loop = scf.For(lb, arith.Addi(ub, one), step, iter_args, Block(arg_types=[builtin.IndexType(), *(a.type for a in iter_args)]))
loop.body.block.args[0].name_hint = "time"

self.block_args = {(f,t) : loop.body.block.args[1+i] for i, (f,t) in enumerate(self.time_buffers)}
self.block_args = {(f,t) : loop.body.block.args[2+i] for i, (f,t) in enumerate(self.time_buffers)}
for ((f,t), arg) in self.block_args.items():
arg.name_hint = f"{f.name}_t{t}"

self.out_block_arg = loop.body.block.args[1]
(of, ot) = self.out_time_buffer

yield_args = [self.block_args[(of, (ot + 1) % of.time_size)],
*(self.block_args[(f, (t + 1) % f.time_size)]
for (f, t) in self.block_args.keys())
]

with ImplicitBuilder(loop.body.block):
self._build_step_body(dim, eq)
# Swap buffers through scf.yield
yield_args = [self.block_args[(f, (t+1)%f.time_size)] for (f, t) in self.block_args.keys()]

scf.Yield(*yield_args)

def convert(self, eqs: Iterable[Eq]) -> builtin.ModuleOp:
Expand Down

0 comments on commit 2ee3853

Please sign in to comment.