From 06dc480b521ecfdce89d30a290bb1e92ea220c10 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Mon, 19 Aug 2024 12:43:27 +0100 Subject: [PATCH] !!! --- devito/ir/xdsl_iet/cluster_to_ssa.py | 30 ++++++++++++---------------- devito/xdsl_core/xdsl_cpu.py | 17 +++++++++++++--- tests/test_xdsl_base.py | 3 ++- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/devito/ir/xdsl_iet/cluster_to_ssa.py b/devito/ir/xdsl_iet/cluster_to_ssa.py index 37713b5322..48d2489d12 100644 --- a/devito/ir/xdsl_iet/cluster_to_ssa.py +++ b/devito/ir/xdsl_iet/cluster_to_ssa.py @@ -287,7 +287,7 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, SSAargs = (self._visit_math_nodes(dim, arg, output_indexed) for arg in node.args) return reduce(lambda x, y : arith.AndI(x, y).result, SSAargs) - + # Trigonometric functions elif isinstance(node, sin): assert len(node.args) == 1, "Expected single argument for sin." @@ -298,13 +298,13 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, assert len(node.args) == 1, "Expected single argument for cos." return math.CosOp(self._visit_math_nodes(dim, node.args[0], output_indexed)).result - + elif isinstance(node, tan): assert len(node.args) == 1, "Expected single argument for TanOp." - + return math.TanOp(self._visit_math_nodes(dim, node.args[0], output_indexed)).result - + elif isinstance(node, Relational): if isinstance(node, GreaterThan): mnemonic = "sge" @@ -391,12 +391,10 @@ def build_stencil_step(self, dim: SteppingDimension, eq: LoweredEq) -> None: apply.res[0], self.function_values[self.out_time_buffer], stencil.StencilBoundsAttr(zip(lb, ub)), - stencil.TempType(len(shape), - element_type=dtype_to_xdsltype(write_function.dtype)) ) - - store.temp_with_halo.name_hint = f"{write_function.name}_t{self.out_time_buffer[1]}_temp" # noqa - self.temps[self.out_time_buffer] = store.temp_with_halo + load = stencil.LoadOp.get(self.function_values[self.out_time_buffer]) + load.res.name_hint = f"{write_function.name}_t{self.out_time_buffer[1]}_temp" # noqa + self.temps[self.out_time_buffer] = load.res def build_generic_step_expression(self, dim: SteppingDimension, eq: LoweredEq): # Sources @@ -439,7 +437,6 @@ def build_condition(self, dim: SteppingDimension, eq: BooleanFunction): self.build_generic_step_expression(dim, eq) scf.Yield() - def build_time_loop( self, eqs: list[Any], step_dim: SteppingDimension, **kwargs ): @@ -450,7 +447,7 @@ def build_time_loop( ub = iet_ssa.LoadSymbolic.get( step_dim.symbolic_max._C_name, IndexType() ) - + one = arith.Constant.from_int_and_width(1, IndexType()) # Devito iterates from time_m to time_M *inclusive*, MLIR only takes @@ -497,7 +494,7 @@ def build_time_loop( for i, (f, t) in enumerate(self.time_buffers) } self.function_values |= self.block_args - + # Name the block argument for debugging for (f, t), arg in self.block_args.items(): arg.name_hint = f"{f.name}_t{t}" @@ -513,8 +510,7 @@ def build_time_loop( def lower_devito_Eqs(self, eqs: list[Any], **kwargs): # Lower devito Equations to xDSL - - + for eq in eqs: lowered = self.operator._lower_exprs(as_tuple(eq), **kwargs) if isinstance(eq, Eq): @@ -546,7 +542,7 @@ def _lower_injection(self, eqs: list[LoweredEq]): lb = arith.Constant.from_int_and_width(int(lower), IndexType()) else: raise NotImplementedError(f"Lower bound of type {type(lower)} not supported") - + try: name = interval.dim.symbolic_min.name except: @@ -633,7 +629,7 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> ModuleOp: # Instantiate the module. self.function_values: dict[tuple[Function, int], SSAValue] = {} self.symbol_values: dict[str, SSAValue] = {} - + module = ModuleOp(Region([block := Block([])])) with ImplicitBuilder(block): # Get all functions used in the equations @@ -647,7 +643,7 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> ModuleOp: functions.add(f.function) elif isinstance(eq, Injection): - + functions.add(eq.field.function) for f in retrieve_functions(eq.expr): if isinstance(f, PointSource): diff --git a/devito/xdsl_core/xdsl_cpu.py b/devito/xdsl_core/xdsl_cpu.py index cd86fc0d9c..e56d6a9a5d 100644 --- a/devito/xdsl_core/xdsl_cpu.py +++ b/devito/xdsl_core/xdsl_cpu.py @@ -460,6 +460,7 @@ def _jit_compile(self): # Run the first pipeline, mostly xDSL-centric xdsl_args = [source_name, "--allow-unregistered-dialect", + "--disable-verify", "-p", xdsl_pipeline[1:-1],] # We use the Python API to run xDSL rather than a subprocess @@ -597,7 +598,10 @@ def generate_MLIR_OPENMP_PIPELINE(kwargs): def generate_XDSL_CPU_PIPELINE(nb_tiled_dims): passes = [ - "stencil-shape-inference", + "canonicalize", + "cse", + "shape-inference", + "stencil-bufferize", "convert-stencil-to-ll-mlir", f"scf-parallel-loop-tiling{{{generate_tiling_arg(nb_tiled_dims)}}}", "printf-to-llvm", @@ -609,7 +613,10 @@ def generate_XDSL_CPU_PIPELINE(nb_tiled_dims): def generate_XDSL_CPU_noop_PIPELINE(): passes = [ - "stencil-shape-inference", + "canonicalize", + "cse", + "shape-inference", + "stencil-bufferize", "convert-stencil-to-ll-mlir", "printf-to-llvm" ] @@ -619,11 +626,15 @@ def generate_XDSL_CPU_noop_PIPELINE(): def generate_XDSL_MPI_PIPELINE(decomp, nb_tiled_dims): passes = [ + "canonicalize", + "cse", f"distribute-stencil{decomp}", + "shape-inference", "canonicalize-dmp", + "stencil-bufferize", + "dmp-to-mpi{mpi_init=false}", "convert-stencil-to-ll-mlir", f"scf-parallel-loop-tiling{{{generate_tiling_arg(nb_tiled_dims)}}}", - "dmp-to-mpi{mpi_init=false}", "lower-mpi", "printf-to-llvm", "canonicalize" diff --git a/tests/test_xdsl_base.py b/tests/test_xdsl_base.py index 39aa97828b..72ddfcb16a 100644 --- a/tests/test_xdsl_base.py +++ b/tests/test_xdsl_base.py @@ -73,7 +73,8 @@ def test_xdsl_III(): assert isinstance(scffor_ops[0], LoadOp) assert isinstance(scffor_ops[1], ApplyOp) assert isinstance(scffor_ops[2], StoreOp) - assert isinstance(scffor_ops[3], Yield) + assert isinstance(scffor_ops[3], LoadOp) + assert isinstance(scffor_ops[4], Yield) assert type(ops[7] == Call) assert type(ops[8] == StoreOp)