Skip to content

Commit

Permalink
[mpact][compiler] add stable hlo pipeline (#78)
Browse files Browse the repository at this point in the history
adds a lowering to stable hlo method in addition
to lowering to linalg; note that this can be used
as an alternative path into the mpact pipeline
  • Loading branch information
aartbik authored Sep 12, 2024
1 parent 403b89a commit 556009c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
18 changes: 17 additions & 1 deletion python/mpact/mpactbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def export_and_import(f, *args, **kwargs):


def mpact_linalg(f, *args, **kwargs):
"""Imports a function as module and lowers it into Linalg IR."""
"""Imports a callable as module and lowers it into Linalg IR."""
module = export_and_import(f, *args, **kwargs)
run_pipeline_with_repro_report(
module,
Expand All @@ -335,6 +335,22 @@ def mpact_linalg(f, *args, **kwargs):
return module


def mpact_stablehlo(f, *args, **kwargs):
"""Imports a callable as module and lowers it into StableHLO IR."""
module = export_and_import(f, *args, **kwargs)
run_pipeline_with_repro_report(
module,
(
"builtin.module("
"func.func(torch-decompose-complex-ops),"
"torch-backend-to-stablehlo-backend-pipeline)"
),
"Lowering TorchFX IR -> StableHLO IR",
enable_ir_printing=False,
)
return module


def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, **kwargs):
"""This method compiles the given callable using the MPACT backend."""
module = mpact_linalg(f, *args, **kwargs)
Expand Down
13 changes: 12 additions & 1 deletion test/python/mm_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
import numpy as np

from mpact.mpactbackend import mpact_linalg
from mpact.mpactbackend import mpact_linalg, mpact_stablehlo

from mpact.models.kernels import MMNet

Expand All @@ -29,3 +29,14 @@

linalg = mpact_linalg(net, X, Y)
print(linalg)

#
# CHECK: module {
# CHECK: func.func @main(%[[A0:.*]]: tensor<4x4xf32>, %[[A1:.*]]: tensor<4x4xf32>) -> tensor<4x4xf32> {
# CHECK: %[[T0:.*]] = stablehlo.dot_general %[[A0]], %[[A1]], contracting_dims = [1] x [0] : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
# CHECK: return %[[T0]] : tensor<4x4xf32>
# CHECK: }
# CHECK: }

stablehlo = mpact_stablehlo(net, X, Y)
print(stablehlo)

0 comments on commit 556009c

Please sign in to comment.