From 83a5d67c18db5974cc828b7f3bae4883202090bd Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Sun, 25 Feb 2024 11:57:47 +0000 Subject: [PATCH] State for output time buffer. --- devito/ir/ietxdsl/cluster_to_ssa.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/devito/ir/ietxdsl/cluster_to_ssa.py b/devito/ir/ietxdsl/cluster_to_ssa.py index 5fcbb404c5..e1c08a51f7 100644 --- a/devito/ir/ietxdsl/cluster_to_ssa.py +++ b/devito/ir/ietxdsl/cluster_to_ssa.py @@ -76,6 +76,10 @@ def _convert_eq(self, eq: LoweredEq): # We identify time buffers by their function and positive time offset. # We store a list of those here to help the following steps. self.time_buffers = [(f, i) for f in functions for i in range(f.time_size)] + # Also store the time buffer used for output in this equation + output_time_offset = (eq.lhs.indices[step_dim] - step_dim) % eq.lhs.function.time_size + self.out_time_buffer = (output_function, output_time_offset) + # For each used function, define as many fields as its time_size fields_types = [field_from_function(f) for (f, _) in self.time_buffers] @@ -162,18 +166,15 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, output_indexed:I raise NotImplementedError(f"Unknown math: {node}", node) def _build_step_body(self, dim: SteppingDimension, eq:LoweredEq) -> None: - output_function = eq.lhs.function - output_time_offset = (eq.lhs.indices[dim] - dim) % eq.lhs.function.time_size - loop_temps = { (f, t): stencil.LoadOp.get(a).res for (f, t), a in self.block_args.items() - if (f, t) != (output_function, output_time_offset) + if (f, t) != self.out_time_buffer } for (f,t), a in loop_temps.items(): a.name_hint = f"{f.name}_t{t}_temp" - + output_function = self.out_time_buffer[0] shape = output_function.grid.shape_local apply = stencil.ApplyOp.get( loop_temps.values(), @@ -190,7 +191,7 @@ def _build_step_body(self, dim: SteppingDimension, eq:LoweredEq) -> None: # TODO Think about multiple outputs stencil.StoreOp.get( apply.res[0], - self.block_args[output_function, output_time_offset], + self.block_args[self.out_time_buffer], stencil.IndexAttr.get(*([0] * len(shape))), stencil.IndexAttr.get(*shape), )