Skip to content

Commit

Permalink
!!!
Browse files Browse the repository at this point in the history
  • Loading branch information
PapyChacal committed Aug 19, 2024
1 parent ecd3900 commit 06dc480
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 21 deletions.
30 changes: 13 additions & 17 deletions devito/ir/xdsl_iet/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr,
SSAargs = (self._visit_math_nodes(dim, arg, output_indexed)
for arg in node.args)
return reduce(lambda x, y : arith.AndI(x, y).result, SSAargs)

# Trigonometric functions
elif isinstance(node, sin):
assert len(node.args) == 1, "Expected single argument for sin."
Expand All @@ -298,13 +298,13 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr,
assert len(node.args) == 1, "Expected single argument for cos."
return math.CosOp(self._visit_math_nodes(dim, node.args[0],
output_indexed)).result

elif isinstance(node, tan):
assert len(node.args) == 1, "Expected single argument for TanOp."

return math.TanOp(self._visit_math_nodes(dim, node.args[0],
output_indexed)).result

elif isinstance(node, Relational):
if isinstance(node, GreaterThan):
mnemonic = "sge"
Expand Down Expand Up @@ -391,12 +391,10 @@ def build_stencil_step(self, dim: SteppingDimension, eq: LoweredEq) -> None:
apply.res[0],
self.function_values[self.out_time_buffer],
stencil.StencilBoundsAttr(zip(lb, ub)),
stencil.TempType(len(shape),
element_type=dtype_to_xdsltype(write_function.dtype))
)

store.temp_with_halo.name_hint = f"{write_function.name}_t{self.out_time_buffer[1]}_temp" # noqa
self.temps[self.out_time_buffer] = store.temp_with_halo
load = stencil.LoadOp.get(self.function_values[self.out_time_buffer])
load.res.name_hint = f"{write_function.name}_t{self.out_time_buffer[1]}_temp" # noqa
self.temps[self.out_time_buffer] = load.res

def build_generic_step_expression(self, dim: SteppingDimension, eq: LoweredEq):
# Sources
Expand Down Expand Up @@ -439,7 +437,6 @@ def build_condition(self, dim: SteppingDimension, eq: BooleanFunction):
self.build_generic_step_expression(dim, eq)
scf.Yield()


def build_time_loop(
self, eqs: list[Any], step_dim: SteppingDimension, **kwargs
):
Expand All @@ -450,7 +447,7 @@ def build_time_loop(
ub = iet_ssa.LoadSymbolic.get(
step_dim.symbolic_max._C_name, IndexType()
)

one = arith.Constant.from_int_and_width(1, IndexType())

# Devito iterates from time_m to time_M *inclusive*, MLIR only takes
Expand Down Expand Up @@ -497,7 +494,7 @@ def build_time_loop(
for i, (f, t) in enumerate(self.time_buffers)
}
self.function_values |= self.block_args

# Name the block argument for debugging
for (f, t), arg in self.block_args.items():
arg.name_hint = f"{f.name}_t{t}"
Expand All @@ -513,8 +510,7 @@ def build_time_loop(

def lower_devito_Eqs(self, eqs: list[Any], **kwargs):
# Lower devito Equations to xDSL



for eq in eqs:
lowered = self.operator._lower_exprs(as_tuple(eq), **kwargs)
if isinstance(eq, Eq):
Expand Down Expand Up @@ -546,7 +542,7 @@ def _lower_injection(self, eqs: list[LoweredEq]):
lb = arith.Constant.from_int_and_width(int(lower), IndexType())
else:
raise NotImplementedError(f"Lower bound of type {type(lower)} not supported")

try:
name = interval.dim.symbolic_min.name
except:
Expand Down Expand Up @@ -633,7 +629,7 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> ModuleOp:
# Instantiate the module.
self.function_values: dict[tuple[Function, int], SSAValue] = {}
self.symbol_values: dict[str, SSAValue] = {}

module = ModuleOp(Region([block := Block([])]))
with ImplicitBuilder(block):
# Get all functions used in the equations
Expand All @@ -647,7 +643,7 @@ def convert(self, eqs: Iterable[Eq], **kwargs) -> ModuleOp:
functions.add(f.function)

elif isinstance(eq, Injection):

functions.add(eq.field.function)
for f in retrieve_functions(eq.expr):
if isinstance(f, PointSource):
Expand Down
17 changes: 14 additions & 3 deletions devito/xdsl_core/xdsl_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ def _jit_compile(self):
# Run the first pipeline, mostly xDSL-centric
xdsl_args = [source_name,
"--allow-unregistered-dialect",
"--disable-verify",
"-p",
xdsl_pipeline[1:-1],]
# We use the Python API to run xDSL rather than a subprocess
Expand Down Expand Up @@ -597,7 +598,10 @@ def generate_MLIR_OPENMP_PIPELINE(kwargs):

def generate_XDSL_CPU_PIPELINE(nb_tiled_dims):
passes = [
"stencil-shape-inference",
"canonicalize",
"cse",
"shape-inference",
"stencil-bufferize",
"convert-stencil-to-ll-mlir",
f"scf-parallel-loop-tiling{{{generate_tiling_arg(nb_tiled_dims)}}}",
"printf-to-llvm",
Expand All @@ -609,7 +613,10 @@ def generate_XDSL_CPU_PIPELINE(nb_tiled_dims):

def generate_XDSL_CPU_noop_PIPELINE():
passes = [
"stencil-shape-inference",
"canonicalize",
"cse",
"shape-inference",
"stencil-bufferize",
"convert-stencil-to-ll-mlir",
"printf-to-llvm"
]
Expand All @@ -619,11 +626,15 @@ def generate_XDSL_CPU_noop_PIPELINE():

def generate_XDSL_MPI_PIPELINE(decomp, nb_tiled_dims):
passes = [
"canonicalize",
"cse",
f"distribute-stencil{decomp}",
"shape-inference",
"canonicalize-dmp",
"stencil-bufferize",
"dmp-to-mpi{mpi_init=false}",
"convert-stencil-to-ll-mlir",
f"scf-parallel-loop-tiling{{{generate_tiling_arg(nb_tiled_dims)}}}",
"dmp-to-mpi{mpi_init=false}",
"lower-mpi",
"printf-to-llvm",
"canonicalize"
Expand Down
3 changes: 2 additions & 1 deletion tests/test_xdsl_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def test_xdsl_III():
assert isinstance(scffor_ops[0], LoadOp)
assert isinstance(scffor_ops[1], ApplyOp)
assert isinstance(scffor_ops[2], StoreOp)
assert isinstance(scffor_ops[3], Yield)
assert isinstance(scffor_ops[3], LoadOp)
assert isinstance(scffor_ops[4], Yield)

assert type(ops[7] == Call)
assert type(ops[8] == StoreOp)
Expand Down

0 comments on commit 06dc480

Please sign in to comment.