diff --git a/devito/core/__init__.py b/devito/core/__init__.py index ac3772abe61..2fae24864f8 100644 --- a/devito/core/__init__.py +++ b/devito/core/__init__.py @@ -14,6 +14,10 @@ DeviceCustomOmpOperator, DeviceCustomAccOperator) from devito.operator.registry import operator_registry +# Import XDSL Operators +from devito.xdsl_core.xdsl_cpu import XdslnoopOperator, XdslAdvOperator +from devito.xdsl_core.xdsl_gpu import XdslAdvDeviceOperator + # Register CPU Operators operator_registry.add(Cpu64CustomOperator, Cpu64, 'custom', 'C') operator_registry.add(Cpu64CustomOperator, Cpu64, 'custom', 'openmp') @@ -54,3 +58,11 @@ operator_registry.add(DeviceFsgOmpOperator, Device, 'advanced-fsg', 'C') operator_registry.add(DeviceFsgOmpOperator, Device, 'advanced-fsg', 'openmp') operator_registry.add(DeviceFsgAccOperator, Device, 'advanced-fsg', 'openacc') + +# Register XDSL Operators +operator_registry.add(XdslnoopOperator, Cpu64, 'xdsl-noop', 'C') +operator_registry.add(XdslAdvOperator, Cpu64, 'xdsl-noop', 'openmp') + +operator_registry.add(XdslAdvOperator, Cpu64, 'xdsl', 'C') +operator_registry.add(XdslAdvOperator, Cpu64, 'xdsl', 'openmp') +operator_registry.add(XdslAdvDeviceOperator, Device, 'xdsl', 'openacc') diff --git a/devito/ir/xdsl_iet/cluster_to_ssa.py b/devito/ir/xdsl_iet/cluster_to_ssa.py index dd5e7c771eb..a28b13dd85b 100644 --- a/devito/ir/xdsl_iet/cluster_to_ssa.py +++ b/devito/ir/xdsl_iet/cluster_to_ssa.py @@ -1,14 +1,14 @@ from functools import reduce -import numpy as np # ------------- General imports -------------# from typing import Any, Iterable from dataclasses import dataclass, field -from sympy import Add, And, Expr, Float, GreaterThan, Indexed, Integer, LessThan, Mod, Mul, Number, Pow, StrictGreaterThan, StrictLessThan, Symbol, floor +from sympy import (Add, And, Expr, Float, GreaterThan, Indexed, Integer, LessThan, + Number, Pow, StrictGreaterThan, StrictLessThan, Symbol, floor, + Mul) from sympy.core.relational import Relational from sympy.logic.boolalg import BooleanFunction -from devito.ir.equations.equation import OpInc from devito.operations.interpolators import Injection from devito.operator.operator import Operator from devito.symbolics.search import retrieve_dimensions, retrieve_functions @@ -21,8 +21,9 @@ from devito.types.equation import Eq # ------------- xdsl imports -------------# -from xdsl.dialects import (arith, builtin, func, memref, scf, - stencil, gpu) +from xdsl.dialects import arith, func, memref, scf, stencil, gpu, builtin +from xdsl.dialects.builtin import (ModuleOp, UnrealizedConversionCastOp, StringAttr, + IndexType) from xdsl.dialects.experimental import math from xdsl.ir import Block, Operation, OpResult, Region, SSAValue from xdsl.pattern_rewriter import ( @@ -30,8 +31,7 @@ PatternRewriter, PatternRewriteWalker, RewritePattern, - op_type_rewrite_pattern, - InsertPoint + op_type_rewrite_pattern ) from xdsl.builder import ImplicitBuilder from xdsl.transforms.experimental.convert_stencil_to_ll_mlir import StencilToMemRefType @@ -40,14 +40,13 @@ from devito import Grid, SteppingDimension from devito.ir.equations import LoweredEq from devito.symbolics import retrieve_function_carriers -from devito.tools.data_structures import OrderedSet -from devito.types.dense import DiscreteFunction, Function, TimeFunction -from devito.types.equation import Eq -from devito.types.mlir_types import dtype_to_xdsltype +from devito.types.mlir_types import dtype_to_xdsltype, ptr_of, f32 # ------------- devito-xdsl SSA imports -------------# from devito.ir.xdsl_iet import iet_ssa -from devito.ir.xdsl_iet.utils import is_int, is_float +from devito.ir.xdsl_iet.utils import is_int, is_float, dtypes_to_xdsltypes + +from examples.seismic import PointSource # flake8: noqa @@ -84,6 +83,7 @@ def setup_memref_args(functions): return args + class ExtractDevitoStencilConversion: """ Lower Devito equations to the stencil dialect @@ -96,12 +96,6 @@ class ExtractDevitoStencilConversion: symbol_values: dict[str, SSAValue] time_offs: int - def __init__(self): - self.temps = dict() - self.symbol_values = dict() - - time_offs: int - def __init__(self, operator: type[Operator]): self.temps = dict() self.operator = operator @@ -137,13 +131,6 @@ def convert_symbol_eq(self, symbol: Symbol, rhs: LoweredEq, **kwargs): self.symbol_values[symbol.name] = self._visit_math_nodes(None, rhs, None) self.symbol_values[symbol.name].name_hint = symbol.name - def convert_symbol_eq(self, symbol: Symbol, rhs: LoweredEq, **kwargs): - """ - Convert a symbol equation to xDSL. - """ - self.symbol_values[symbol.name] = self._visit_math_nodes(None, rhs, None) - self.symbol_values[symbol.name].name_hint = symbol.name - def _convert_eq(self, eq: LoweredEq, **kwargs): """ # Docs here Need rewriting @@ -208,25 +195,27 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, # If we are in a stencil (encoded by having the output_indexed passed), we # compute the relative space offsets and make it a stencil offset if output_indexed is not None: - space_offsets = [node.indices[d] - output_indexed.indices[d] for d in node.function.space_dimensions] + space_offsets = ([node.indices[d] - output_indexed.indices[d] + for d in node.function.space_dimensions]) temp = self.function_values[(node.function, time_offset)] access = stencil.AccessOp.get(temp, space_offsets) return access.res # Otherwise, generate a load op else: temp = self.function_values[(node.function, time_offset)] - memtemp = builtin.UnrealizedConversionCastOp.get(temp, StencilToMemRefType(temp.type)).results[0] + memreftype = StencilToMemRefType(temp.type) + memtemp = UnrealizedConversionCastOp.get(temp, memreftype).results[0] memtemp.name_hint = temp.name_hint + "_mem" indices = node.indices if isinstance(node.function, TimeFunction): indices = indices[1:] - ssa_indices = [self._visit_math_nodes(dim, i, output_indexed) for i in node.indices] + ssa_indices = ([self._visit_math_nodes(dim, i, output_indexed) + for i in node.indices]) for i, ssa_i in enumerate(ssa_indices): if isinstance(ssa_i.type, builtin.IntegerType): - ssa_indices[i] = arith.IndexCastOp(ssa_i, builtin.IndexType()) + ssa_indices[i] = arith.IndexCastOp(ssa_i, IndexType()) return memref.Load.get(memtemp, ssa_indices).res - import pdb; pdb.set_trace() # Handle Integers elif isinstance(node, Integer): cst = arith.Constant.from_int_and_width(int(node), builtin.i64) @@ -287,13 +276,16 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, # Handle Mod elif isinstance(node, INT): assert len(node.args) == 1, "Expected single argument for integer cast." - return arith.FPToSIOp(self._visit_math_nodes(dim, node.args[0], output_indexed), builtin.i64).result + return arith.FPToSIOp(self._visit_math_nodes(dim, node.args[0], + output_indexed), builtin.i64).result elif isinstance(node, floor): assert len(node.args) == 1, "Expected single argument for floor." - return math.FloorOp(self._visit_math_nodes(dim, node.args[0], output_indexed)).result + op = self._visit_math_nodes(dim, node.args[0], output_indexed) + return math.FloorOp(op).result elif isinstance(node, And): - SSAargs = (self._visit_math_nodes(dim, arg, output_indexed) for arg in node.args) - return reduce(lambda x,y : arith.AndI(x,y).result, SSAargs) + SSAargs = (self._visit_math_nodes(dim, arg, output_indexed) + for arg in node.args) + return reduce(lambda x, y : arith.AndI(x, y).result, SSAargs) elif isinstance(node, Relational): if isinstance(node, GreaterThan): mnemonic = "sge" @@ -318,7 +310,7 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, else: raise NotImplementedError(f"Unknown math:{type(node)} {node}", node) - def build_stencil_step(self, dim: SteppingDimension, eq:LoweredEq) -> None: + def build_stencil_step(self, dim: SteppingDimension, eq: LoweredEq) -> None: """ Builds the body of the step function for a given dimension and equation. @@ -338,7 +330,6 @@ def build_stencil_step(self, dim: SteppingDimension, eq:LoweredEq) -> None: elif isinstance(f.function, Function): time_offset = 0 else: - import pdb;pdb.set_trace() raise NotImplementedError(f"reading function of type {type(f.function)} not supported") read_functions.add((f.function, time_offset)) @@ -368,7 +359,7 @@ def build_stencil_step(self, dim: SteppingDimension, eq:LoweredEq) -> None: assert "temp" in apply_op.name_hint apply_arg.name_hint = apply_op.name_hint.replace("temp", "blk") - self.apply_temps = {k:v for k,v in zip(read_functions, apply.region.block.args)} + self.apply_temps = {k: v for k, v in zip(read_functions, apply.region.block.args)} # Update the function values with the new temps self.function_values |= self.apply_temps @@ -393,7 +384,7 @@ def build_generic_step_expression(self, dim: SteppingDimension, eq: LoweredEq): # Sources value = self._visit_math_nodes(dim, eq.rhs, None) temp = self.function_values[self.out_time_buffer] - memtemp = builtin.UnrealizedConversionCastOp.get([temp], [StencilToMemRefType(temp.type)]).results[0] + memtemp = UnrealizedConversionCastOp.get([temp], [StencilToMemRefType(temp.type)]).results[0] memtemp.name_hint = temp.name_hint + "_mem" indices = eq.lhs.indices if isinstance(eq.lhs.function, TimeFunction): @@ -401,13 +392,17 @@ def build_generic_step_expression(self, dim: SteppingDimension, eq: LoweredEq): ssa_indices = [self._visit_math_nodes(dim, i, None) for i in indices] for i, ssa_i in enumerate(ssa_indices): if isinstance(ssa_i.type, builtin.IntegerType): - ssa_indices[i] = arith.IndexCastOp(ssa_i, builtin.IndexType()) + ssa_indices[i] = arith.IndexCastOp(ssa_i, IndexType()) match eq.operation: case None: memref.Store.get(value, memtemp, ssa_indices) - case OpInc: - memref.AtomicRMWOp(operands=[value, memtemp, ssa_indices], result_types=[value.type], properties={"kind" : builtin.IntegerAttr(0, builtin.i64)}) + case OpInc: # noqa + # Maybe rename + attr = builtin.IntegerAttr(0, builtin.i64) + memref.AtomicRMWOp(operands=[value, memtemp, ssa_indices], + result_types=[value.type], + properties={"kind": attr}) def build_condition(self, dim: SteppingDimension, eq: BooleanFunction): return self._visit_math_nodes(dim, eq, None) @@ -429,12 +424,12 @@ def build_time_loop( ): # Bounds and step boilerpalte lb = iet_ssa.LoadSymbolic.get( - step_dim.symbolic_min._C_name, builtin.IndexType() + step_dim.symbolic_min._C_name, IndexType() ) ub = iet_ssa.LoadSymbolic.get( - step_dim.symbolic_max._C_name, builtin.IndexType() + step_dim.symbolic_max._C_name, IndexType() ) - one = arith.Constant.from_int_and_width(1, builtin.IndexType()) + one = arith.Constant.from_int_and_width(1, IndexType()) # Devito iterates from time_m to time_M *inclusive*, MLIR only takes # exclusive upper bounds, so we increment here. ub = arith.Addi(ub, one) @@ -442,7 +437,7 @@ def build_time_loop( # Take the exact time_step from Devito try: step = arith.Constant.from_int_and_width( - int(step_dim.symbolic_incr), builtin.IndexType() + int(step_dim.symbolic_incr), IndexType() ) step.result.name_hint = "step" @@ -462,7 +457,7 @@ def build_time_loop( ub, step, iter_args, - Block(arg_types=[builtin.IndexType(), *(a.type for a in iter_args)]), + Block(arg_types=[IndexType(), *(a.type for a in iter_args)]), ) # Name the 'time' step iterator @@ -519,30 +514,45 @@ def _lower_injection(self, eqs: list[LoweredEq]): for interval in ispace[1:]: lower = interval.symbolic_min if isinstance(lower, Scalar): - lb = iet_ssa.LoadSymbolic.get(lower._C_name, builtin.IndexType()) + lb = iet_ssa.LoadSymbolic.get(lower._C_name, IndexType()) elif isinstance(lower, (Number, int)): - lb = arith.Constant.from_int_and_width(int(lower), builtin.IndexType()) + lb = arith.Constant.from_int_and_width(int(lower), IndexType()) else: raise NotImplementedError(f"Lower bound of type {type(lower)} not supported") - lb.result.name_hint = f"{interval.dim.name}_m" + + try: + name = interval.dim.symbolic_min.name + except: + assert interval.dim.symbolic_min.is_integer + name = f"{interval.dim.name}_M" + + lb.result.name_hint = name upper = interval.symbolic_max if isinstance(upper, Scalar): - ub = iet_ssa.LoadSymbolic.get(upper._C_name, builtin.IndexType()) + ub = iet_ssa.LoadSymbolic.get(upper._C_name, IndexType()) elif isinstance(upper, (Number, int)): - ub = arith.Constant.from_int_and_width(int(upper), builtin.IndexType()) + ub = arith.Constant.from_int_and_width(int(upper), IndexType()) else: raise NotImplementedError( f"Upper bound of type {type(upper)} not supported" ) - ub.result.name_hint = f"{interval.dim.name}_M" + + try: + name = interval.dim.symbolic_max.name + except: + assert interval.dim.symbolic_max.is_integer + name = f"{interval.dim.name}_M" + + ub.result.name_hint = name + lbs.append(lb) ubs.append(ub) - steps = [arith.Constant.from_int_and_width(1, builtin.IndexType()).result]*len(ubs) + steps = [arith.Constant.from_int_and_width(1, IndexType()).result]*len(ubs) ubs = [arith.Addi(ub, steps[0]) for ub in ubs] - with ImplicitBuilder(scf.ParallelOp(lbs, ubs, steps, [pblock := Block(arg_types=[builtin.IndexType()]*len(ubs))]).body): + with ImplicitBuilder(scf.ParallelOp(lbs, ubs, steps, [pblock := Block(arg_types=[IndexType()]*len(ubs))]).body): for arg, interval in zip(pblock.args, ispace[1:], strict=True): arg.name_hint = interval.dim.name self.symbol_values[interval.dim.name] = arg @@ -551,7 +561,7 @@ def _lower_injection(self, eqs: list[LoweredEq]): scf.Yield() # raise NotImplementedError("Injections not supported yet") - def convert(self, eqs: Iterable[Eq], **kwargs) -> builtin.ModuleOp: + def convert(self, eqs: Iterable[Eq], **kwargs) -> ModuleOp: """ This converts a Devito Operator, represented here by a list of LoweredEqs, to an xDSL module defining a function implementing it. @@ -568,7 +578,8 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> builtin.ModuleOp: their time sizes. Their sizes are deduced from the Grid. 2. Create a time iteration loop, swapping buffers to implement time buffering. - NB: This needs to be converted to a Cluster conversion soon, which will be more sound. + NB: This needs to be converted to a Cluster conversion soon, + which will be more sound. ```mlir func.func @apply_kernel(%u_vec_0 : !stencil.field<[-1,4]xf32>, %u_vec_1 : !stencil.field<[-1,4]xf32>) { @@ -591,10 +602,12 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> builtin.ModuleOp: Those represents runtime values not yet known that will be JIT-compiled when calling the operator. """ + # Instantiate the module. - self.function_values : dict[tuple[Function, int], SSAValue] = {} - self.symbol_values : dict[str, SSAValue] = {} - module = builtin.ModuleOp(Region([block := Block([])])) + self.function_values: dict[tuple[Function, int], SSAValue] = {} + self.symbol_values: dict[str, SSAValue] = {} + + module = ModuleOp(Region([block := Block([])])) with ImplicitBuilder(block): # Get all functions used in the equations functions = OrderedSet() @@ -617,8 +630,8 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> builtin.ModuleOp: else: raise NotImplementedError(f"Expression {eq} of type {type(eq)} not supported") - self.time_buffers : list[TimeFunction] = [] - self.functions : list[Function] = [] + self.time_buffers: list[TimeFunction] = [] + self.functions: list[Function] = [] for f in functions: match f: case TimeFunction(): @@ -660,12 +673,13 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> builtin.ModuleOp: with ImplicitBuilder(xdsl_func.body.block): # Get the stepping dimension, if there is any in the whole input - time_functions = [f for (f,_) in self.time_buffers] + time_functions = [f for (f, _) in self.time_buffers] dimensions = { d for f in (self.functions + time_functions) for d in f.dimensions } - step_dim = next((d for d in dimensions if isinstance(d, SteppingDimension)), None) + step_dim = next((d for d in dimensions if + isinstance(d, SteppingDimension)), None) if step_dim is not None: self.build_time_loop(eqs, step_dim, **kwargs) else: @@ -679,7 +693,7 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> builtin.ModuleOp: def _ensure_same_type(self, *vals: SSAValue): if all(isinstance(val.type, builtin.IntegerAttr) for val in vals): return vals - if all(isinstance(val.type, builtin.IndexType) for val in vals): + if all(isinstance(val.type, IndexType) for val in vals): # Sources return vals if all(is_float(val) for val in vals): @@ -693,7 +707,7 @@ def _ensure_same_type(self, *vals: SSAValue): if cast_to_floats and is_float(val): processed.append(val) continue - if (not cast_to_floats) and isinstance(val.type, builtin.IndexType): + if (not cast_to_floats) and isinstance(val.type, IndexType): processed.append(val) continue # if the val is the result of a arith.constant with no uses, @@ -709,17 +723,17 @@ def _ensure_same_type(self, *vals: SSAValue): float(val.op.value.value.data), builtin.f32 ) else: - val.type = builtin.IndexType() - val.op.value.type = builtin.IndexType() + val.type = IndexType() + val.op.value.type = IndexType() processed.append(val) continue # insert a cast op if cast_to_floats: - if val.type == builtin.IndexType(): + if val.type == IndexType(): val = arith.IndexCastOp(val, builtin.i64).result conv = arith.SIToFPOp(val, builtin.f32) else: - conv = arith.IndexCastOp(val, builtin.IndexType()) + conv = arith.IndexCastOp(val, IndexType()) processed.append(conv.result) return processed @@ -747,7 +761,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): return self.done = True - op.sym_name = builtin.StringAttr("gpu_kernel") + op.sym_name = StringAttr("gpu_kernel") print("Doing GPU STUFF") # GPU STUFF wrapper = func.FuncOp(self.func_name, op.function_type, Region(Block([func.Return()], arg_types=op.function_type.inputs))) @@ -755,12 +769,12 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): wrapper.body.block.insert_op_before(func.Call("gpu_kernel", body.args, []), body.last_op) for arg in wrapper.args: shapetype = arg.type - if isinstance(shapetype, stencil.FieldType): + if isinstance(shapetype, stencil.FieldType): memref_type = memref.MemRefType.from_element_type_and_shape(shapetype.get_element_type(), shapetype.get_shape()) alloc = gpu.AllocOp(memref.MemRefType.from_element_type_and_shape(shapetype.get_element_type(), shapetype.get_shape())) - outcast = builtin.UnrealizedConversionCastOp.get(alloc, shapetype) + outcast = UnrealizedConversionCastOp.get(alloc, shapetype) arg.replace_by(outcast.results[0]) - incast = builtin.UnrealizedConversionCastOp.get(arg, memref_type) + incast = UnrealizedConversionCastOp.get(arg, memref_type) copy = gpu.MemcpyOp(source=incast, destination=alloc) body.insert_ops_before([alloc, outcast, incast, copy], body.ops.first) @@ -770,49 +784,6 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): rewriter.insert_op_after_matched_op(wrapper) -class TimerRewritePattern(RewritePattern): - """ - Base class for time benchmarking related rewrite patterns - """ - pass - - -@dataclass -class MakeFunctionTimed(TimerRewritePattern): - """ - Populate the section0 devito timer with the total runtime of the function - """ - func_name: str - seen_ops: set[func.Func] = field(default_factory=set) - - @op_type_rewrite_pattern - def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): - if op.sym_name.data != self.func_name or op in self.seen_ops: - return - - # only apply once - self.seen_ops.add(op) - - # Insert timer start and end calls - rewriter.insert_op([ - t0 := func.Call('timer_start', [], [builtin.f64]) - ], InsertPoint.at_start(op.body.block)) - - ret = op.get_return_op() - assert ret is not None - - rewriter.insert_op_before([ - timers := iet_ssa.LoadSymbolic.get('timers', llvm.LLVMPointerType.opaque()), - t1 := func.Call('timer_end', [t0], [builtin.f64]), - llvm.StoreOp(t1, timers), - ], ret) - - rewriter.insert_op([ - func.FuncOp.external('timer_start', [], [builtin.f64]), - func.FuncOp.external('timer_end', [builtin.f64], [builtin.f64]), - ], InsertPoint.after(rewriter.current_operation)) - - def get_containing_func(op: Operation) -> func.FuncOp | None: while op is not None and not isinstance(op, func.FuncOp): op = op.parent_op() @@ -880,7 +851,7 @@ def match_and_rewrite(self, op: iet_ssa.LoadSymbolic, rewriter: PatternRewriter, parent.update_function_type() -def finalize_module_with_globals(module: builtin.ModuleOp, known_symbols: dict[str, Any], +def finalize_module_with_globals(module: ModuleOp, known_symbols: dict[str, Any], gpu_boilerplate): """ This function finalizes a module by replacing all symbolic constants with their diff --git a/devito/xdsl_core/__init__.py b/devito/xdsl_core/__init__.py index cc4df2c8f0d..3f0f0bffc85 100644 --- a/devito/xdsl_core/__init__.py +++ b/devito/xdsl_core/__init__.py @@ -1,14 +1,4 @@ -from devito.arch import Cpu64, Device +from .xdsl_cpu import * +from .xdsl_gpu import * -from devito.xdsl_core.xdsl_cpu import XdslnoopOperator, XdslAdvOperator - -from devito.xdsl_core.xdsl_gpu import XdslAdvDeviceOperator -from devito.operator.registry import operator_registry - -# Register XDSL Operators -operator_registry.add(XdslnoopOperator, Cpu64, 'xdsl-noop', 'C') -operator_registry.add(XdslAdvOperator, Cpu64, 'xdsl-noop', 'openmp') - -operator_registry.add(XdslAdvOperator, Cpu64, 'xdsl', 'C') -operator_registry.add(XdslAdvOperator, Cpu64, 'xdsl', 'openmp') -operator_registry.add(XdslAdvDeviceOperator, Device, 'xdsl', 'openacc') +# flake8: noqa \ No newline at end of file diff --git a/devito/xdsl_core/xdsl_cpu.py b/devito/xdsl_core/xdsl_cpu.py index 2411c9966ea..24e5d38ce56 100644 --- a/devito/xdsl_core/xdsl_cpu.py +++ b/devito/xdsl_core/xdsl_cpu.py @@ -22,7 +22,8 @@ from xdsl.xdsl_opt_main import xDSLOptMain from devito.ir.xdsl_iet.cluster_to_ssa import (ExtractDevitoStencilConversion, - finalize_module_with_globals) # noqa + finalize_module_with_globals, + setup_memref_args) # noqa from devito.ir.xdsl_iet.profiling import apply_timers from devito.passes.iet import CTarget, OmpTarget diff --git a/devito/xdsl_core/xdsl_gpu.py b/devito/xdsl_core/xdsl_gpu.py index b489ba4173a..87a1bc9b8ad 100644 --- a/devito/xdsl_core/xdsl_gpu.py +++ b/devito/xdsl_core/xdsl_gpu.py @@ -19,6 +19,9 @@ from devito.xdsl_core.utils import generate_pipeline, generate_mlir_pipeline +__all__ = ['XdslAdvDeviceOperator'] + + class XdslAdvDeviceOperator(XdslAdvOperator): _Target = DeviceOmpTarget