Skip to content

Commit

Permalink
State for output time buffer.
Browse files Browse the repository at this point in the history
  • Loading branch information
PapyChacal committed Feb 25, 2024
1 parent a49422f commit 83a5d67
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions devito/ir/ietxdsl/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def _convert_eq(self, eq: LoweredEq):
# We identify time buffers by their function and positive time offset.
# We store a list of those here to help the following steps.
self.time_buffers = [(f, i) for f in functions for i in range(f.time_size)]
# Also store the time buffer used for output in this equation
output_time_offset = (eq.lhs.indices[step_dim] - step_dim) % eq.lhs.function.time_size
self.out_time_buffer = (output_function, output_time_offset)


# For each used function, define as many fields as its time_size
fields_types = [field_from_function(f) for (f, _) in self.time_buffers]
Expand Down Expand Up @@ -162,18 +166,15 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, output_indexed:I
raise NotImplementedError(f"Unknown math: {node}", node)

def _build_step_body(self, dim: SteppingDimension, eq:LoweredEq) -> None:
output_function = eq.lhs.function
output_time_offset = (eq.lhs.indices[dim] - dim) % eq.lhs.function.time_size

loop_temps = {
(f, t): stencil.LoadOp.get(a).res
for (f, t), a in self.block_args.items()
if (f, t) != (output_function, output_time_offset)
if (f, t) != self.out_time_buffer
}
for (f,t), a in loop_temps.items():
a.name_hint = f"{f.name}_t{t}_temp"


output_function = self.out_time_buffer[0]
shape = output_function.grid.shape_local
apply = stencil.ApplyOp.get(
loop_temps.values(),
Expand All @@ -190,7 +191,7 @@ def _build_step_body(self, dim: SteppingDimension, eq:LoweredEq) -> None:
# TODO Think about multiple outputs
stencil.StoreOp.get(
apply.res[0],
self.block_args[output_function, output_time_offset],
self.block_args[self.out_time_buffer],
stencil.IndexAttr.get(*([0] * len(shape))),
stencil.IndexAttr.get(*shape),
)
Expand Down

0 comments on commit 83a5d67

Please sign in to comment.