From a49422fb44547baa1ef5482080efa3535dd44fe4 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Sun, 25 Feb 2024 11:53:54 +0000 Subject: [PATCH] Reuse state to simplify signatures, also more Pythonic constructs. --- devito/ir/ietxdsl/cluster_to_ssa.py | 112 +++++++++++----------------- 1 file changed, 44 insertions(+), 68 deletions(-) diff --git a/devito/ir/ietxdsl/cluster_to_ssa.py b/devito/ir/ietxdsl/cluster_to_ssa.py index bf37850a95..5fcbb404c5 100644 --- a/devito/ir/ietxdsl/cluster_to_ssa.py +++ b/devito/ir/ietxdsl/cluster_to_ssa.py @@ -63,7 +63,7 @@ def _convert_eq(self, eq: LoweredEq): # an equation may look like this: # u[x+1,y+1,z] = (u[x,y,z+1] + u[x+2,y+2,z+1]) / 2 if isinstance(eq.lhs, Symbol): - return func.FuncOp.external(eq.lhs.name, [], [builtin.i32]) + return xdsl_func.FuncOp.external(eq.lhs.name, [], [builtin.i32]) assert isinstance(eq.lhs, Indexed) # u(t, x, y) @@ -73,30 +73,34 @@ def _convert_eq(self, eq: LoweredEq): # Get all functions used in the equation functions = OrderedSet(*(f.function for f in retrieve_function_carriers(eq))) + # We identify time buffers by their function and positive time offset. + # We store a list of those here to help the following steps. + self.time_buffers = [(f, i) for f in functions for i in range(f.time_size)] # For each used function, define as many fields as its time_size - fields = [[field_from_function(f)]*f.time_size for f in functions] - # Flatten the list for function signature - flat_fields = [f for ff in fields for f in ff] + fields_types = [field_from_function(f) for (f, _) in self.time_buffers] # Create a function with the fields as arguments - f = func.FuncOp("apply_kernel", ([*flat_fields], [])) + xdsl_func = func.FuncOp("apply_kernel", (fields_types, [])) # Define nice argument names to try and stay sane while debugging - arg_names = [f"{f.name}_vec_{i}" for f in functions for i in range(f.time_size)] - for i, arg_name in enumerate(arg_names): - f.body.block.args[i].name_hint = arg_name - - with ImplicitBuilder(f.body.block): - self._build_step_loop(step_dim, functions, eq) + # And store in self.function_args a mapping from time_buffers to their + # corresponding function arguments. + self.function_args = {} + for i, (f, t) in enumerate(self.time_buffers): + xdsl_func.body.block.args[i].name_hint = f"{f.name}_vec_{t}" + self.function_args[(f,t)] = xdsl_func.body.block.args[i] + + with ImplicitBuilder(xdsl_func.body.block): + self._build_step_loop(step_dim, eq) # func wants a return func.Return() - - return f - def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, temps: dict[tuple[DiscreteFunction, int], SSAValue], output_indexed:Indexed) -> SSAValue: + + return xdsl_func + def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, output_indexed:Indexed) -> SSAValue: # Handle Indexeds if isinstance(node, Indexed): space_offsets = [node.indices[d] - output_indexed.indices[d] for d in node.function.space_dimensions] # import pdb; pdb.set_trace() - temp = temps[(node.function, (node.indices[dim] - dim) % node.function.time_size)] + temp = self.apply_temps[(node.function, (node.indices[dim] - dim) % node.function.time_size)] access = stencil.AccessOp.get(temp, space_offsets) return access.res # Handle Integers @@ -113,7 +117,7 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, temps: dict[tupl return symb.result # Handle Add Mul elif isinstance(node, (Add, Mul)): - args = [self._visit_math_nodes(dim, arg, temps, output_indexed) for arg in node.args] + args = [self._visit_math_nodes(dim, arg, output_indexed) for arg in node.args] # add casts when necessary # get first element out, store the rest in args # this makes the reduction easier @@ -131,7 +135,7 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, temps: dict[tupl return carry # Handle Pow elif isinstance(node, Pow): - args = [self._visit_math_nodes(dim, arg, temps, output_indexed) for arg in node.args] + args = [self._visit_math_nodes(dim, arg, output_indexed) for arg in node.args] assert len(args) == 2, "can't pow with != 2 args!" base, ex = args if is_int(base): @@ -157,47 +161,36 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, temps: dict[tupl else: raise NotImplementedError(f"Unknown math: {node}", node) - - def _build_step_body(self, dim: SteppingDimension, functions:Iterable[DiscreteFunction], eq:LoweredEq) -> None: - iter_args = ImplicitBuilder.get().insertion_point.block.args[1:] - + def _build_step_body(self, dim: SteppingDimension, eq:LoweredEq) -> None: output_function = eq.lhs.function output_time_offset = (eq.lhs.indices[dim] - dim) % eq.lhs.function.time_size - function_fields : dict[tuple[DiscreteFunction, int], SSAValue] = {} - handled_args = 0 - for f in functions: - for i in range(f.time_size): - function_fields[(f, i)] = iter_args[handled_args + i] - handled_args += f.time_size - - function_temps : dict[tuple[DiscreteFunction, int], SSAValue] = {} - for (function, time_offset), field in function_fields.items(): - # import pdb; pdb.set_trace() - if (function, time_offset) == (output_function, output_time_offset): - continue - load = stencil.LoadOp.get(field) - load.res.name_hint = f"{function.name}_t{time_offset}_temp" + loop_temps = { + (f, t): stencil.LoadOp.get(a).res + for (f, t), a in self.block_args.items() + if (f, t) != (output_function, output_time_offset) + } + for (f,t), a in loop_temps.items(): + a.name_hint = f"{f.name}_t{t}_temp" - function_temps[(function, time_offset)] = load.res shape = output_function.grid.shape_local apply = stencil.ApplyOp.get( - function_temps.values(), - Block(arg_types=[a.type for a in function_temps.values()]), + loop_temps.values(), + Block(arg_types=[a.type for a in loop_temps.values()]), result_types=[stencil.TempType(len(shape), element_type=dtypes_to_xdsltypes[output_function.dtype])] ) for apply_arg, apply_op in zip(apply.region.block.args, apply.operands): apply_arg.name_hint = apply_op.name_hint[:-4]+"_blk" - apply_temps = {k:v for k,v in zip(function_temps.keys(), apply.region.block.args)} + self.apply_temps = {k:v for k,v in zip(loop_temps.keys(), apply.region.block.args)} with ImplicitBuilder(apply.region.block): - stencil.ReturnOp.get([self._visit_math_nodes(dim, eq.rhs, apply_temps, eq.lhs)]) + stencil.ReturnOp.get([self._visit_math_nodes(dim, eq.rhs, eq.lhs)]) # TODO Think about multiple outputs stencil.StoreOp.get( apply.res[0], - function_fields[output_function, output_time_offset], + self.block_args[output_function, output_time_offset], stencil.IndexAttr.get(*([0] * len(shape))), stencil.IndexAttr.get(*shape), ) @@ -205,7 +198,6 @@ def _build_step_body(self, dim: SteppingDimension, functions:Iterable[DiscreteFu def _build_step_loop( self, dim: SteppingDimension, - functions: Iterable[DiscreteFunction], eq: LoweredEq, ) -> scf.For: # Bounds and step boilerpalte @@ -220,34 +212,19 @@ def _build_step_loop( except: raise ValueError("step must be int!") - # Get the function arguments for iteration arguments - func_args = ImplicitBuilder.get().insertion_point.block.args - + iter_args = list(self.function_args.values()) # Create the for loop - loop = scf.For(lb, arith.Addi(ub, one), step, func_args, Block(arg_types=[builtin.IndexType()] + [a.type for a in func_args])) + loop = scf.For(lb, arith.Addi(ub, one), step, iter_args, Block(arg_types=[builtin.IndexType(), *(a.type for a in iter_args)])) loop.body.block.args[0].name_hint = "time" - handled_args = 0 - for f in functions: - for i in range(f.time_size): - loop.body.block.args[1+handled_args+i].name_hint = f"{f.name}_t{i}" - handled_args += f.time_size - with ImplicitBuilder(loop.body.block): - self._build_step_body(dim, functions, eq) - - # Compute the yield order to implement the buffer swap in MLIR - yield_args = [] - # Skip the induction variable - handled_args = 1 - # For each Devito function - for f in functions: - # Get their corresponding iteration fields - fargs = loop.body.block.args[handled_args:handled_args + f.time_size] - # Yield them swapped - yield_args += fargs[1:] - yield_args.append(fargs[0]) - - handled_args += f.time_size + self.block_args = {(f,t) : loop.body.block.args[1+i] for i, (f,t) in enumerate(self.time_buffers)} + for ((f,t), arg) in self.block_args.items(): + arg.name_hint = f"{f.name}_t{t}" + + with ImplicitBuilder(loop.body.block): + self._build_step_body(dim, eq) + # Swap buffers through scf.yield + yield_args = [self.block_args[(f, (t+1)%f.time_size)] for (f, t) in self.block_args.keys()] scf.Yield(*yield_args) def convert(self, eqs: Iterable[Eq]) -> builtin.ModuleOp: @@ -256,7 +233,6 @@ def convert(self, eqs: Iterable[Eq]) -> builtin.ModuleOp: Region([Block([self._convert_eq(eq) for eq in eqs])]) ) - def _ensure_same_type(self, *vals: SSAValue): if all(isinstance(val.type, builtin.IntegerAttr) for val in vals): return vals