diff --git a/python/mpact/mpactbackend.py b/python/mpact/mpactbackend.py index 425413a..72b440d 100644 --- a/python/mpact/mpactbackend.py +++ b/python/mpact/mpactbackend.py @@ -16,7 +16,7 @@ from mpact.dialects import torch as torch_d from mpact.execution_engine import * from mpact.extras.fx_decomp_util import get_decomposition_table -from mpact.extras.fx_importer import FxImporter, SparsityMeta +from mpact.extras.fx_importer import FxImporter from mpact.ir import * from mpact.passmanager import * from mpact.runtime import * @@ -124,14 +124,6 @@ def assert_arg_type_is_supported(ty): CONSUME_RETURN_FUNC_PREFIX = "refbackend_consume_func_return_" -SPARSE_LAYOUTS = [ - torch.sparse_coo, - torch.sparse_csr, - torch.sparse_csc, - torch.sparse_bsr, - torch.sparse_bsc, -] - def get_return_funcs(module): return_prefix_len = len(CONSUME_RETURN_FUNC_PREFIX) @@ -314,149 +306,15 @@ def load(self, module: MpactCompiledArtifact) -> MpactBackendInvoker: return MpactBackendInvoker(module, self.opt_level) -def sparse_metadata(a: torch.Tensor) -> SparsityMeta: - """ - Returns a meta data tuple for the given sparse tensor. - - NOTE: this will be fully replaced by fx graph SparseTensorMetadata - """ - sparse_dim = a.sparse_dim() - dense_dim = a.dense_dim() - batch_dim = a.ndim - dense_dim - sparse_dim - blocksize = None - if a.layout is torch.sparse_coo: - return SparsityMeta( - a.layout, - batch_dim, - sparse_dim, - dense_dim, - blocksize, - a._indices().dtype, - a._indices().dtype, - ) - elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: - if a.layout is torch.sparse_bsr: - blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3] - return SparsityMeta( - a.layout, - batch_dim, - sparse_dim, - dense_dim, - blocksize, - a.crow_indices().dtype, - a.col_indices().dtype, - ) - elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: - if a.layout is torch.sparse_bsc: - blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3] - return SparsityMeta( - a.layout, - batch_dim, - sparse_dim, - dense_dim, - blocksize, - a.ccol_indices().dtype, - a.row_indices().dtype, - ) - else: - raise RuntimeError(f"Unsupported sparse layout for {a}") - - -def sparse_arg(args, i): - if isinstance(args[i], torch.fx.node.Node): - return args[i].meta.get("sparsity", None) - return None - - -def sparse_export( - f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None -) -> torch.export.ExportedProgram: - """ - This is a ***temporary*** wrapper around `torch.export.export` - that eventually should be removed and simply replaced by the - standard API for exporting traced graphs. - - But until issue - - https://github.com/pytorch/pytorch/pull/117907 - - is addressed, this wrapper provides support for the sparse - tensor types by first converting all operands to dense tensors, - building the traced graph as for the dense case, then annotating - sparse parameters with their actual sparse layout attributes, - followed by some simple propagation rules. This temporary solution - accelerates testing torch-mlir with PyTorch sparse tensors until - the issue is resolved upstream. - """ - # Convert all arguments to dense. - dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args) - mask = [a.layout in SPARSE_LAYOUTS for a in args] - # Build the regular FX traced graph with only dense arguments - # (the current version would crash otherwise, see issue above). - prog = torch.export.export(f, dargs, kwargs) - decomposition_table = get_decomposition_table() - if decomposition_table: - prog = prog.run_decompositions(decomposition_table) - # Annotate sparse arguments in the graph and apply some very - # basic propagation rules for sparsity. - specs = prog.graph_signature.input_specs - alen = len(specs) - k = 0 - for i, node in enumerate(prog.graph.nodes): - if node.op == "placeholder": - # Argument. - spec = specs[i] - if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: - if mask[k]: - node.meta["sparsity"] = sparse_metadata(args[k]) - k = k + 1 - elif node.op == "call_function": - opname = node.target._schema.name.split("::")[1] - # Zero preserving elt-wise unary op. - if opname in {"abs", "neg", "relu", "sin"}: - node.meta["sparsity"] = sparse_arg(node.args, 0) - # Some simplistic rules for preserving sparsity. Soon - # to be replaced by proper FX graph propagation. - elif opname in {"mul"}: - m0 = sparse_arg(node.args, 0) - m1 = sparse_arg(node.args, 1) - if m0 is not None: - node.meta["sparsity"] = m0 - elif m1 is not None: - node.meta["sparsity"] = m1 - elif opname in {"add", "mm"}: - m0 = sparse_arg(node.args, 0) - m1 = sparse_arg(node.args, 1) - if m0 is not None and m1 is not None: - node.meta["sparsity"] = m0 - elif opname == "_to_sparse" or opname == "to_sparse": - dim = len(node.meta.get("val").shape) - node.meta["sparsity"] = SparsityMeta( - torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64 - ) - # TODO: Uncomment this to hack sparsity into the network. - # elif opname == "_to_dense" or opname == "to_dense": - # # hack (assumes we never really want the to_dense for now) - # node.meta["sparsity"] = sparse_arg(node.args, 0) - elif opname == "select" and sparse_arg(node.args, 0): - dim = len(node.meta.get("val").shape) - node.meta["sparsity"] = SparsityMeta( - torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64 - ) - elif opname == "stack" and sparse_arg(node.args[0], 0): - dim = len(node.meta.get("val").shape) - node.meta["sparsity"] = SparsityMeta( - torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64 - ) - return prog - - def export_and_import(f, *args, **kwargs): - """This method implements Stella's importer, stripped down to essentials.""" + """A FX graph importer, stripped down to essentials.""" context = ir.Context() torch_d.register_dialect(context) fx_importer = FxImporter(context=context) - prog = sparse_export(f, args, kwargs) + prog = torch.export.export(f, args, kwargs) + decomposition_table = get_decomposition_table() + if decomposition_table: + prog = prog.run_decompositions(decomposition_table) fx_importer.import_frozen_program(prog) return fx_importer.module