Skip to content

Commit

Permalink
Merge pull request #41 from xdslproject/compiler_more
Browse files Browse the repository at this point in the history
compiler: more cleanup
  • Loading branch information
georgebisbas authored Dec 5, 2023
2 parents 68e9eaf + 97faa07 commit 107aefd
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 1,194 deletions.
29 changes: 29 additions & 0 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,32 @@ class Cpu64FsgCOperator(Cpu64FsgOperator):

class Cpu64FsgOmpOperator(Cpu64FsgOperator):
_Target = OmpTarget


# -----------XDSL
# This is a collection of xDSL optimization pipelines
# Ideally they should follow the same type of subclassing as the rest of
# the Devito Operators


# MLIR pass pipeline for the CPU target.
# The value is a `builtin.module(...)` pass list; note that the double quotes
# are part of the string itself (the literal is single-quoted around them).
# NOTE(review): presumably the embedded quotes exist so the value can be pasted
# into a shell command line as one argument -- confirm against the caller.
MLIR_CPU_PIPELINE = '"builtin.module(canonicalize, cse, loop-invariant-code-motion, canonicalize, cse, loop-invariant-code-motion, cse, canonicalize, fold-memref-alias-ops, expand-strided-metadata, loop-invariant-code-motion, lower-affine, convert-scf-to-cf, convert-math-to-llvm, convert-func-to-llvm{use-bare-ptr-memref-call-conv}, finalize-memref-to-llvm, canonicalize, cse)"' # noqa

# MLIR pass pipeline for the OpenMP target. Compared with the CPU pipeline it
# additionally lists `convert-scf-to-openmp`, `convert-openmp-to-llvm` and
# `reconcile-unrealized-casts`, and lists `finalize-memref-to-llvm` twice.
MLIR_OPENMP_PIPELINE = '"builtin.module(canonicalize, cse, loop-invariant-code-motion, canonicalize, cse, loop-invariant-code-motion,cse,canonicalize,fold-memref-alias-ops,expand-strided-metadata, loop-invariant-code-motion,lower-affine,finalize-memref-to-llvm,loop-invariant-code-motion,canonicalize,cse,convert-scf-to-openmp,finalize-memref-to-llvm,convert-scf-to-cf,convert-func-to-llvm{use-bare-ptr-memref-call-conv},convert-openmp-to-llvm,convert-math-to-llvm,reconcile-unrealized-casts,canonicalize,cse)"' # noqa
# gpu-launch-sink-index-computations seemed to have no impact
# MLIR pass pipeline for the GPU target, built on demand: call it with
# `block_sizes`, which is interpolated verbatim as the value of
# `parallel-loop-tile-sizes`. Inside the f-string, `{{` / `}}` are escapes that
# emit literal braces around pass options; only `{block_sizes}` is substituted.
# NOTE(review): the expected format of `block_sizes` is not visible here
# (presumably a comma-separated list of integers) -- confirm against the caller.
MLIR_GPU_PIPELINE = lambda block_sizes: f'"builtin.module(test-math-algebraic-simplification,scf-parallel-loop-tiling{{parallel-loop-tile-sizes={block_sizes}}},func.func(gpu-map-parallel-loops),convert-parallel-loops-to-gpu,lower-affine, canonicalize,cse, fold-memref-alias-ops, gpu-launch-sink-index-computations, gpu-kernel-outlining, canonicalize{{region-simplify}},cse,fold-memref-alias-ops,expand-strided-metadata,lower-affine,canonicalize,cse,func.func(gpu-async-region),canonicalize,cse,convert-arith-to-llvm{{index-bitwidth=64}},convert-scf-to-cf,convert-cf-to-llvm{{index-bitwidth=64}},canonicalize,cse,convert-func-to-llvm{{use-bare-ptr-memref-call-conv}},gpu.module(convert-gpu-to-nvvm,reconcile-unrealized-casts,canonicalize,gpu-to-cubin),gpu-to-llvm,canonicalize,cse)"' # noqa

# xDSL pass pipeline for the CPU target, built on demand from the number of
# tiled dimensions. The triple braces `{{{ ... }}}` expand to a literal `{`,
# the result of `generate_tiling_arg(nb_tiled_dims)`, and a literal `}`; the
# helper is looked up when the lambda is called, not when it is defined.
# The double quotes are part of the returned string.
XDSL_CPU_PIPELINE = lambda nb_tiled_dims: f'"stencil-shape-inference,convert-stencil-to-ll-mlir{{{generate_tiling_arg(nb_tiled_dims)}}},printf-to-llvm"' # noqa

# xDSL pass pipeline for the GPU target: a plain (non-f) string, so the braces
# in `{target=gpu}` are literal.
# NOTE(review): unlike the other pipelines in this group, this value does not
# carry embedded double quotes -- confirm that the caller expects this.
XDSL_GPU_PIPELINE = "stencil-shape-inference,convert-stencil-to-ll-mlir{target=gpu},reconcile-unrealized-casts,printf-to-llvm" # noqa

# xDSL pass pipeline for the MPI target, built on demand. `{decomp}` is a
# single-brace f-string substitution placed directly after `dmp-decompose`, so
# no braces are added around it here.
# NOTE(review): presumably `decomp` already carries its own `{...}` option
# block -- confirm against the caller.
# The tiling argument is wrapped in literal braces via `{{{ ... }}}`, and
# `{{mpi_init=false}}` emits the literal option block `{mpi_init=false}`.
XDSL_MPI_PIPELINE = lambda decomp, nb_tiled_dims: f'"dmp-decompose{decomp},canonicalize-dmp,convert-stencil-to-ll-mlir{{{generate_tiling_arg(nb_tiled_dims)}}},dmp-to-mpi{{mpi_init=false}},lower-mpi,printf-to-llvm"' # noqa


def generate_tiling_arg(nb_tiled_dims: int, tile_size: int = 64) -> str:
    """
    Generate the tile-sizes arg for the convert-stencil-to-ll-mlir pass.

    Parameters
    ----------
    nb_tiled_dims : int
        Number of dimensions to tile.
    tile_size : int, optional
        Tile extent repeated once per tiled dimension. Defaults to 64.

    Returns
    -------
    str
        ``"tile-sizes=<s>,<s>,..."`` with one entry per tiled dimension, or
        an empty string when there are no dimensions to tile.
    """
    # A non-positive count must not yield a dangling "tile-sizes=" with no
    # values attached, so it is treated the same as "no tiling requested".
    if nb_tiled_dims <= 0:
        return ''
    return "tile-sizes=" + ",".join([str(tile_size)] * nb_tiled_dims)
7 changes: 5 additions & 2 deletions devito/ir/ietxdsl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
from devito.ir.ietxdsl.lowering import LowerIetForToScfFor, LowerIetForToScfParallel, DropIetComments, iet_to_standard_mlir # noqa
from devito.ir.ietxdsl.cluster_to_ssa import finalize_module_with_globals, convert_devito_stencil_to_xdsl_stencil # noqa
from devito.ir.ietxdsl.lowering import (LowerIetForToScfFor, LowerIetForToScfParallel)
from devito.ir.ietxdsl.cluster_to_ssa import (finalize_module_with_globals,
convert_devito_stencil_to_xdsl_stencil)

# flake8: noqa
130 changes: 60 additions & 70 deletions devito/ir/ietxdsl/cluster_to_ssa.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
# ------------- General imports -------------#

from typing import Any
from dataclasses import dataclass, field
from sympy import Add, Expr, Float, Indexed, Integer, Mod, Mul, Pow, Symbol

# ------------- xdsl imports -------------#
from xdsl.dialects import arith, builtin, func, memref, scf, stencil, gpu
from xdsl.dialects import (arith, builtin, func, memref, scf,
stencil, gpu, llvm)
from xdsl.dialects.experimental import math
from xdsl.ir import Block, Operation, OpResult, Region, SSAValue
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)

# ------------- devito imports -------------#
from devito import Grid, SteppingDimension
from devito.ir.equations import LoweredEq
from devito.symbolics import retrieve_indexed
from devito.symbolics import retrieve_indexed, retrieve_function_carriers
from devito.logger import perf

# ------------- devito-xdsl SSA imports -------------#
from devito.ir.ietxdsl import iet_ssa
from devito.ir.ietxdsl.utils import is_int, is_float
from devito.ir.ietxdsl.ietxdsl_functions import dtypes_to_xdsltypes
from devito.ir.ietxdsl.lowering import LowerIetForToScfFor

# flake8: noqa

Expand Down Expand Up @@ -50,34 +61,34 @@ def _convert_eq(self, eq: LoweredEq):
function = eq.lhs.function
mlir_type = dtypes_to_xdsltypes[function.dtype]
grid: Grid = function.grid
# get the halo of the space dimensions only e.g [(2, 2), (2, 2)] for the 2d case

# Get the halo of the grid.dimensions
# e.g [(2, 2), (2, 2)] for the 2D case
# Do not forget the issue with Devito adding an extra point!
# Check 'def halo_setup' for more
# (for derivative regions)
halo = [function.halo[function.dimensions.index(d)] for d in grid.dimensions]
halo = [function.halo[d] for d in grid.dimensions]

# Shift all time values so that for all accesses at t + n, n>=0.
self.time_offs = min(
int(idx.indices[0] - grid.stepping_dim) for idx in retrieve_indexed(eq)
)

# Calculate the actual size of our time dimension
actual_time_size = (
max(int(idx.indices[0] - grid.stepping_dim) for idx in retrieve_indexed(eq))
- self.time_offs
+ 1
)

# Get the time_size
time_size = max(d.function.time_size for d in retrieve_function_carriers(eq))

# Build the for loop
perf("Build Time Loop")
loop = self._build_iet_for(grid.stepping_dim, actual_time_size)
loop = self._build_iet_for(grid.stepping_dim, time_size)

# build stencil
perf("Initialize a stencil Op")
stencil_op = iet_ssa.Stencil.get(
loop.subindice_ssa_vals(),
grid.shape_local,
halo,
actual_time_size,
time_size,
mlir_type,
eq.lhs.function._C_name,
)
Expand All @@ -87,7 +98,7 @@ def _convert_eq(self, eq: LoweredEq):
# dims -> ssa vals
perf("Apply time offsets")
time_offset_to_field: dict[str, SSAValue] = {
i: stencil_op.block.args[i] for i in range(actual_time_size - 1)
i: stencil_op.block.args[i] for i in range(time_size - 1)
}

# reset loaded values
Expand All @@ -103,8 +114,9 @@ def _convert_eq(self, eq: LoweredEq):

# emit return
offsets = _get_dim_offsets(eq.lhs, self.time_offs)

assert (
offsets[0] == actual_time_size - 1
offsets[0] == time_size - 1
), "result should be written to last time buffer"
assert all(
o == 0 for o in offsets[1:]
Expand All @@ -118,51 +130,47 @@ def _convert_eq(self, eq: LoweredEq):
)

def _visit_math_nodes(self, node: Expr) -> SSAValue:
# Handle Indexeds
if isinstance(node, Indexed):
offsets = _get_dim_offsets(node, self.time_offs)
return self.loaded_values[offsets]
if isinstance(node, Integer):
# Handle Integers
elif isinstance(node, Integer):
cst = arith.Constant.from_int_and_width(int(node), builtin.i64)
self.block.add_op(cst)
return cst.result
if isinstance(node, Float):
# Handle Floats
elif isinstance(node, Float):
cst = arith.Constant.from_float_and_width(float(node), builtin.f32)
self.block.add_op(cst)
return cst.result
# if isinstance(math, Constant):
# symb = iet_ssa.LoadSymbolic.get(math.name, dtypes_to_xdsltypes[math.dtype])
# self.block.add_op(symb)
# return symb.result
if isinstance(node, Symbol):
# Handle Symbols
elif isinstance(node, Symbol):
symb = iet_ssa.LoadSymbolic.get(node.name, builtin.f32)
self.block.add_op(symb)
return symb.result

# handle all of the math
if not isinstance(node, (Add, Mul, Pow, Mod)):
raise ValueError(f"Unknown math: {node}", node)

args = [self._visit_math_nodes(arg) for arg in node.args]

# make sure all args are the same type:
if isinstance(node, (Add, Mul)):
return symb.result
# Handle Add Mul
elif isinstance(node, (Add, Mul)):
args = [self._visit_math_nodes(arg) for arg in node.args]
# add casts when necessary
# get first element out, store the rest in args
# this makes the reduction easier
carry, *args = self._ensure_same_type(*args)
# select the correct op from arith.addi, arith.addf, arith.muli, arith.mulf
if isinstance(carry.type, builtin.IntegerType):
op_cls = arith.Addi if isinstance(node, Add) else arith.Muli
else:
elif isinstance(carry.type, builtin.Float32Type):
op_cls = arith.Addf if isinstance(node, Add) else arith.Mulf

else:
raise("Add support for another type")
for arg in args:
op = op_cls(carry, arg)
self.block.add_op(op)
carry = op.result
return carry

if isinstance(node, Pow):
# Handle Pow
elif isinstance(node, Pow):
args = [self._visit_math_nodes(arg) for arg in node.args]
assert len(args) == 2, "can't pow with != 2 args!"
base, ex = args
if is_int(base):
Expand All @@ -183,11 +191,12 @@ def _visit_math_nodes(self, node: Expr) -> SSAValue:
op = op_cls.get(base, ex)
self.block.add_op(op)
return op.result
# Handle Mod
elif isinstance(node, Mod):
raise NotImplementedError("Go away, no mod here. >:(")
else:
raise NotImplementedError(f"Unknown math: {node}", node)

if isinstance(node, Mod):
raise ValueError("Go away, no mod here. >:(")

raise ValueError("Unknown math!")

def _add_access_ops(
self, reads: list[Indexed], time_offset_to_field: dict[int, SSAValue]
Expand All @@ -202,10 +211,11 @@ def _add_access_ops(
"""
# get the compile time constant offsets for this read
offsets = _get_dim_offsets(read, self.time_offs)

if offsets in self.loaded_values:
continue

# assume time dimension is first dimension
# Assume time dimension is first dimension
t_offset = offsets[0]
space_offsets = offsets[1:]

Expand Down Expand Up @@ -251,10 +261,10 @@ def _ensure_same_type(self, *vals: SSAValue):
if all(is_float(val) for val in vals):
return vals
# not everything homogeneous
new_vals = []
processed = []
for val in vals:
if is_float(val):
new_vals.append(val)
processed.append(val)
continue
# if the val is the result of a arith.constant with no uses,
# we change the type of the arith.constant to our desired type
Expand All @@ -267,19 +277,20 @@ def _ensure_same_type(self, *vals: SSAValue):
val.op.attributes["value"] = builtin.FloatAttr(
float(val.op.value.value.data), builtin.f32
)
new_vals.append(val)
processed.append(val)
continue
# insert an integer to float cast op
conv = arith.SIToFPOp(val, builtin.f32)
self.block.add_op(conv)
new_vals.append(conv.result)
return new_vals
processed.append(conv.result)
return processed


def _get_dim_offsets(idx: Indexed, t_offset: int) -> tuple:
# shift all time values so that for all accesses at t + n, n>=0.
# time_offs = min(int(i - d) for i, d in zip(idx.indices, idx.function.dimensions))
halo = ((t_offset, 0), *idx.function.halo[1:])

try:
return tuple(
int(i - d - halo_offset)
Expand All @@ -291,36 +302,13 @@ def _get_dim_offsets(idx: Indexed, t_offset: int) -> tuple:
raise ValueError("Indices must be constant offset from dimension!") from ex


def is_int(val: SSAValue):
return isinstance(val.type, builtin.IntegerType)


def is_float(val: SSAValue):
return val.type in (builtin.f32, builtin.f64)


# -------------------------------------------------------- ####
# ####
# devito.stencil ---> stencil dialect ####
# ####
# -------------------------------------------------------- ####

from dataclasses import dataclass, field

from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)

from devito.ir.ietxdsl.lowering import (
LowerIetForToScfFor,
)

from xdsl.dialects import llvm

@dataclass
class WrapFunctionWithTransfers(RewritePattern):
func_name: str
Expand Down Expand Up @@ -545,8 +533,10 @@ def finalize_module_with_globals(module: builtin.ModuleOp, known_symbols: dict[s
_InsertSymbolicConstants(known_symbols),
_LowerLoadSymbolidToFuncArgs(),
]
grpa = GreedyRewritePatternApplier(patterns)
PatternRewriteWalker(grpa).rewrite_module(module)
rewriter = GreedyRewritePatternApplier(patterns)
PatternRewriteWalker(rewriter).rewrite_module(module)

# GPU boilerplate
if gpu_boilerplate:
walker = PatternRewriteWalker(GreedyRewritePatternApplier([WrapFunctionWithTransfers('apply_kernel')]))
walker.rewrite_module(module)
Loading

0 comments on commit 107aefd

Please sign in to comment.