diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py index eb09e0b..346e5a2 100644 --- a/python/mpact/mpactbackend.py +++ b/python/mpact/mpactbackend.py @@ -320,7 +320,7 @@ def export_and_import(f, *args, **kwargs): def mpact_linalg(f, *args, **kwargs): - """Imports a function as module and lowers it into Linalg IR.""" + """Imports a callable as module and lowers it into Linalg IR.""" module = export_and_import(f, *args, **kwargs) run_pipeline_with_repro_report( module, @@ -335,6 +335,22 @@ def mpact_linalg(f, *args, **kwargs): return module +def mpact_stablehlo(f, *args, **kwargs): + """Imports a callable as module and lowers it into StableHLO IR.""" + module = export_and_import(f, *args, **kwargs) + run_pipeline_with_repro_report( + module, + ( + "builtin.module(" + "func.func(torch-decompose-complex-ops)," + "torch-backend-to-stablehlo-backend-pipeline)" + ), + "Lowering TorchFX IR -> StableHLO IR", + enable_ir_printing=False, + ) + return module + + def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, **kwargs): """This method compiles the given callable using the MPACT backend.""" module = mpact_linalg(f, *args, **kwargs) diff --git a/test/python/mm_print.py b/test/python/mm_print.py index 976c10c..8065e77 100644 --- a/test/python/mm_print.py +++ b/test/python/mm_print.py @@ -3,7 +3,7 @@ import torch import numpy as np -from mpact.mpactbackend import mpact_linalg +from mpact.mpactbackend import mpact_linalg, mpact_stablehlo from mpact.models.kernels import MMNet @@ -29,3 +29,14 @@ linalg = mpact_linalg(net, X, Y) print(linalg) + +# +# CHECK: module { +# CHECK: func.func @main(%[[A0:.*]]: tensor<4x4xf32>, %[[A1:.*]]: tensor<4x4xf32>) -> tensor<4x4xf32> { +# CHECK: %[[T0:.*]] = stablehlo.dot_general %[[A0]], %[[A1]], contracting_dims = [1] x [0] : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> +# CHECK: return %[[T0]] : tensor<4x4xf32> +# CHECK: } +# CHECK: } + +stablehlo = mpact_stablehlo(net, X, Y) +print(stablehlo)