From 556009cda5b9d1befb943cb439d5aab5aaa28a7b Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Thu, 12 Sep 2024 13:03:26 -0700 Subject: [PATCH] [mpact][compiler] add stable hlo pipeline (#78) adds a lowering to stable hlo method in addition to lowering to linalg; note that this can be used as an alternative path into the mpact pipeline --- python/mpact/mpactbackend.py | 18 +++++++++++++++++- test/python/mm_print.py | 13 ++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py index eb09e0b..346e5a2 100644 --- a/python/mpact/mpactbackend.py +++ b/python/mpact/mpactbackend.py @@ -320,7 +320,7 @@ def export_and_import(f, *args, **kwargs): def mpact_linalg(f, *args, **kwargs): - """Imports a function as module and lowers it into Linalg IR.""" + """Imports a callable as module and lowers it into Linalg IR.""" module = export_and_import(f, *args, **kwargs) run_pipeline_with_repro_report( module, @@ -335,6 +335,22 @@ def mpact_linalg(f, *args, **kwargs): return module +def mpact_stablehlo(f, *args, **kwargs): + """Imports a callable as module and lowers it into StableHLO IR.""" + module = export_and_import(f, *args, **kwargs) + run_pipeline_with_repro_report( + module, + ( + "builtin.module(" + "func.func(torch-decompose-complex-ops)," + "torch-backend-to-stablehlo-backend-pipeline)" + ), + "Lowering TorchFX IR -> StableHLO IR", + enable_ir_printing=False, + ) + return module + + def mpact_jit_compile(f, *args, opt_level=2, use_sp_it=False, **kwargs): """This method compiles the given callable using the MPACT backend.""" module = mpact_linalg(f, *args, **kwargs) diff --git a/test/python/mm_print.py b/test/python/mm_print.py index 976c10c..8065e77 100644 --- a/test/python/mm_print.py +++ b/test/python/mm_print.py @@ -3,7 +3,7 @@ import torch import numpy as np -from mpact.mpactbackend import mpact_linalg +from mpact.mpactbackend import mpact_linalg, mpact_stablehlo from mpact.models.kernels import MMNet @@ -29,3 +29,14 @@ linalg = mpact_linalg(net, X, Y) print(linalg) + +# +# CHECK: module { +# CHECK: func.func @main(%[[A0:.*]]: tensor<4x4xf32>, %[[A1:.*]]: tensor<4x4xf32>) -> tensor<4x4xf32> { +# CHECK: %[[T0:.*]] = stablehlo.dot_general %[[A0]], %[[A1]], contracting_dims = [1] x [0] : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> +# CHECK: return %[[T0]] : tensor<4x4xf32> +# CHECK: } +# CHECK: } + +stablehlo = mpact_stablehlo(net, X, Y) +print(stablehlo)