Skip to content

Commit

Permalink
More factoring.
Browse files (browse the repository at this point in the history)
  • Loading branch information
PapyChacal committed Aug 2, 2023
1 parent a01aca4 commit 7153632
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions devito/ir/ietxdsl/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,13 +543,13 @@ def convert_devito_stencil_to_xdsl_stencil(module, timed:bool=True):



def finalize_module_with_globals(module: builtin.ModuleOp, known_symbols: dict[str, Any]):
def finalize_module_with_globals(module: builtin.ModuleOp, known_symbols: dict[str, Any], gpu_boilerplate):
patterns = [
_InsertSymbolicConstants(known_symbols),
_LowerLoadSymbolidToFuncArgs(),
]
grpa = GreedyRewritePatternApplier(patterns)
PatternRewriteWalker(grpa).rewrite_module(module)
if isinstance(configuration['platform'], NvidiaDevice):
if gpu_boilerplate:
walker = PatternRewriteWalker(GreedyRewritePatternApplier([WrapFunctionWithTransfers('apply_kernel')]))
walker.rewrite_module(module)
6 changes: 3 additions & 3 deletions devito/operator/xdsl_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
MLIR_CPU_PIPELINE = '"builtin.module(canonicalize, cse, loop-invariant-code-motion, canonicalize, cse, loop-invariant-code-motion,cse,canonicalize,fold-memref-alias-ops,expand-strided-metadata, loop-invariant-code-motion,lower-affine,convert-scf-to-cf,convert-math-to-llvm,convert-func-to-llvm{use-bare-ptr-memref-call-conv},finalize-memref-to-llvm,canonicalize,cse)"'
MLIR_OPENMP_PIPELINE = '"builtin.module(canonicalize, cse, loop-invariant-code-motion, canonicalize, cse, loop-invariant-code-motion,cse,canonicalize,fold-memref-alias-ops,expand-strided-metadata, loop-invariant-code-motion,lower-affine,finalize-memref-to-llvm,loop-invariant-code-motion,canonicalize,cse,convert-scf-to-openmp,finalize-memref-to-llvm,convert-scf-to-cf,convert-func-to-llvm{use-bare-ptr-memref-call-conv},convert-openmp-to-llvm,convert-math-to-llvm,reconcile-unrealized-casts,canonicalize,cse)"'
# gpu-launch-sink-index-computations seemed to have no impact
MLIR_GPU_PIPELINE = '"builtin.module(test-math-algebraic-simplification,scf-parallel-loop-tiling{parallel-loop-tile-sizes=128,1,1},func.func(gpu-map-parallel-loops),convert-parallel-loops-to-gpu,fold-memref-alias-ops,expand-strided-metadata,lower-affine,gpu-kernel-outlining,func.func(gpu-async-region),canonicalize,cse,convert-arith-to-llvm{index-bitwidth=64},finalize-memref-to-llvm{index-bitwidth=64},convert-scf-to-cf,convert-cf-to-llvm{index-bitwidth=64},canonicalize,cse,convert-func-to-llvm{use-bare-ptr-memref-call-conv},gpu.module(convert-gpu-to-nvvm,reconcile-unrealized-casts,canonicalize,gpu-to-cubin),gpu-to-llvm,canonicalize,cse)"'
MLIR_GPU_PIPELINE = '"builtin.module(test-math-algebraic-simplification,scf-parallel-loop-tiling{parallel-loop-tile-sizes=128,1,1},func.func(gpu-map-parallel-loops),convert-parallel-loops-to-gpu,fold-memref-alias-ops,expand-strided-metadata,lower-affine,canonicalize,cse,gpu-kernel-outlining,func.func(gpu-async-region),canonicalize,cse,convert-arith-to-llvm{index-bitwidth=64},finalize-memref-to-llvm{index-bitwidth=64},convert-scf-to-cf,convert-cf-to-llvm{index-bitwidth=64},canonicalize,cse,convert-func-to-llvm{use-bare-ptr-memref-call-conv},gpu.module(convert-gpu-to-nvvm{use-bare-ptr-memref-call-conv},reconcile-unrealized-casts,canonicalize,gpu-to-cubin),gpu-to-llvm,canonicalize,cse)"'

XDSL_CPU_PIPELINE = "stencil-shape-inference,convert-stencil-to-ll-mlir,reconcile-unrealized-casts,printf-to-llvm"
XDSL_GPU_PIPELINE = "stencil-shape-inference,convert-stencil-to-ll-mlir{target=gpu},reconcile-unrealized-casts,printf-to-llvm"
Expand Down Expand Up @@ -107,7 +107,7 @@ def _jit_compile(self):
raise RuntimeError("Cannot run OMP+GPU!")

# specialize the code for the specific apply parameters
finalize_module_with_globals(self._module, self._jit_kernel_constants)
finalize_module_with_globals(self._module, self._jit_kernel_constants, gpu_boilerplate=True)

# print module as IR
module_str = StringIO()
Expand Down Expand Up @@ -268,7 +268,7 @@ def _lower(cls, expressions, **kwargs):
from devito.ir.ietxdsl.cluster_to_ssa import ExtractDevitoStencilConversion, convert_devito_stencil_to_xdsl_stencil
conv = ExtractDevitoStencilConversion(expressions)
module = conv.convert()
convert_devito_stencil_to_xdsl_stencil(module)
convert_devito_stencil_to_xdsl_stencil(module, timed=True)

# [LoweredEq] -> [Clusters]
clusters = cls._lower_clusters(expressions, **kwargs)
Expand Down

0 comments on commit 7153632

Please sign in to comment.