[WIP] [mpact][sparse] replace ad-hoc mechanism with proper FX [WIP]
NOTE NOTE NOTE
this is still waiting for a torch-mlir bump
NOTE NOTE NOTE
aartbik committed Aug 14, 2024
1 parent 6e2dc79 commit e0ee62a
Showing 1 changed file with 6 additions and 148 deletions.

python/mpact/mpactbackend.py: 154 changes (6 additions, 148 deletions)
@@ -16,7 +16,7 @@
 from mpact.dialects import torch as torch_d
 from mpact.execution_engine import *
 from mpact.extras.fx_decomp_util import get_decomposition_table
-from mpact.extras.fx_importer import FxImporter, SparsityMeta
+from mpact.extras.fx_importer import FxImporter
 from mpact.ir import *
 from mpact.passmanager import *
 from mpact.runtime import *
@@ -124,14 +124,6 @@ def assert_arg_type_is_supported(ty):
 
 CONSUME_RETURN_FUNC_PREFIX = "refbackend_consume_func_return_"
 
-SPARSE_LAYOUTS = [
-    torch.sparse_coo,
-    torch.sparse_csr,
-    torch.sparse_csc,
-    torch.sparse_bsr,
-    torch.sparse_bsc,
-]
-
 
 def get_return_funcs(module):
     return_prefix_len = len(CONSUME_RETURN_FUNC_PREFIX)
@@ -314,149 +306,15 @@ def load(self, module: MpactCompiledArtifact) -> MpactBackendInvoker:
         return MpactBackendInvoker(module, self.opt_level)
 
 
-def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
-    """
-    Returns a meta data tuple for the given sparse tensor.
-    NOTE: this will be fully replaced by fx graph SparseTensorMetadata
-    """
-    sparse_dim = a.sparse_dim()
-    dense_dim = a.dense_dim()
-    batch_dim = a.ndim - dense_dim - sparse_dim
-    blocksize = None
-    if a.layout is torch.sparse_coo:
-        return SparsityMeta(
-            a.layout,
-            batch_dim,
-            sparse_dim,
-            dense_dim,
-            blocksize,
-            a._indices().dtype,
-            a._indices().dtype,
-        )
-    elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
-        if a.layout is torch.sparse_bsr:
-            blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
-        return SparsityMeta(
-            a.layout,
-            batch_dim,
-            sparse_dim,
-            dense_dim,
-            blocksize,
-            a.crow_indices().dtype,
-            a.col_indices().dtype,
-        )
-    elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
-        if a.layout is torch.sparse_bsc:
-            blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
-        return SparsityMeta(
-            a.layout,
-            batch_dim,
-            sparse_dim,
-            dense_dim,
-            blocksize,
-            a.ccol_indices().dtype,
-            a.row_indices().dtype,
-        )
-    else:
-        raise RuntimeError(f"Unsupported sparse layout for {a}")
-
-
-def sparse_arg(args, i):
-    if isinstance(args[i], torch.fx.node.Node):
-        return args[i].meta.get("sparsity", None)
-    return None
-
-
-def sparse_export(
-    f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
-) -> torch.export.ExportedProgram:
-    """
-    This is a ***temporary*** wrapper around `torch.export.export`
-    that eventually should be removed and simply replaced by the
-    standard API for exporting traced graphs.
-    But until issue
-    https://github.com/pytorch/pytorch/pull/117907
-    is addressed, this wrapper provides support for the sparse
-    tensor types by first converting all operands to dense tensors,
-    building the traced graph as for the dense case, then annotating
-    sparse parameters with their actual sparse layout attributes,
-    followed by some simple propagation rules. This temporary solution
-    accelerates testing torch-mlir with PyTorch sparse tensors until
-    the issue is resolved upstream.
-    """
-    # Convert all arguments to dense.
-    dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
-    mask = [a.layout in SPARSE_LAYOUTS for a in args]
-    # Build the regular FX traced graph with only dense arguments
-    # (the current version would crash otherwise, see issue above).
-    prog = torch.export.export(f, dargs, kwargs)
-    decomposition_table = get_decomposition_table()
-    if decomposition_table:
-        prog = prog.run_decompositions(decomposition_table)
-    # Annotate sparse arguments in the graph and apply some very
-    # basic propagation rules for sparsity.
-    specs = prog.graph_signature.input_specs
-    alen = len(specs)
-    k = 0
-    for i, node in enumerate(prog.graph.nodes):
-        if node.op == "placeholder":
-            # Argument.
-            spec = specs[i]
-            if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
-                if mask[k]:
-                    node.meta["sparsity"] = sparse_metadata(args[k])
-                k = k + 1
-        elif node.op == "call_function":
-            opname = node.target._schema.name.split("::")[1]
-            # Zero preserving elt-wise unary op.
-            if opname in {"abs", "neg", "relu", "sin"}:
-                node.meta["sparsity"] = sparse_arg(node.args, 0)
-            # Some simplistic rules for preserving sparsity. Soon
-            # to be replaced by proper FX graph propagation.
-            elif opname in {"mul"}:
-                m0 = sparse_arg(node.args, 0)
-                m1 = sparse_arg(node.args, 1)
-                if m0 is not None:
-                    node.meta["sparsity"] = m0
-                elif m1 is not None:
-                    node.meta["sparsity"] = m1
-            elif opname in {"add", "mm"}:
-                m0 = sparse_arg(node.args, 0)
-                m1 = sparse_arg(node.args, 1)
-                if m0 is not None and m1 is not None:
-                    node.meta["sparsity"] = m0
-            elif opname == "_to_sparse" or opname == "to_sparse":
-                dim = len(node.meta.get("val").shape)
-                node.meta["sparsity"] = SparsityMeta(
-                    torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
-                )
-            # TODO: Uncomment this to hack sparsity into the network.
-            # elif opname == "_to_dense" or opname == "to_dense":
-            #     # hack (assumes we never really want the to_dense for now)
-            #     node.meta["sparsity"] = sparse_arg(node.args, 0)
-            elif opname == "select" and sparse_arg(node.args, 0):
-                dim = len(node.meta.get("val").shape)
-                node.meta["sparsity"] = SparsityMeta(
-                    torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
-                )
-            elif opname == "stack" and sparse_arg(node.args[0], 0):
-                dim = len(node.meta.get("val").shape)
-                node.meta["sparsity"] = SparsityMeta(
-                    torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
-                )
-    return prog
-
-
 def export_and_import(f, *args, **kwargs):
-    """This method implements Stella's importer, stripped down to essentials."""
+    """A FX graph importer, stripped down to essentials."""
     context = ir.Context()
     torch_d.register_dialect(context)
     fx_importer = FxImporter(context=context)
-    prog = sparse_export(f, args, kwargs)
+    prog = torch.export.export(f, args, kwargs)
+    decomposition_table = get_decomposition_table()
+    if decomposition_table:
+        prog = prog.run_decompositions(decomposition_table)
     fx_importer.import_frozen_program(prog)
     return fx_importer.module
 
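For context, a minimal usage sketch of the new path. Everything below is illustrative only: the toy MatVec module, the tensor shapes, and the mpact.mpactbackend import path are assumptions, and the call only works once the torch-mlir bump mentioned in the commit message has landed, so that torch.export.export accepts sparse operands directly.

import torch

from mpact.mpactbackend import export_and_import  # assumed import path

# Hypothetical toy module; any nn.Module accepted by torch.export.export works.
class MatVec(torch.nn.Module):
    def forward(self, a, x):
        return torch.mm(a, x)

# A CSR operand. The removed sparse_export() used to round-trip this through
# a dense tensor and re-annotate the placeholder by hand; after this change,
# sparsity metadata is expected to flow through torch.export.export itself.
a = torch.tensor([[1.0, 0.0], [0.0, 2.0]]).to_sparse_csr()
x = torch.tensor([[3.0], [4.0]])

m = export_and_import(MatVec(), a, x)
print(m)  # the imported MLIR module, sparse encodings included

The hand-written propagation rules visible in the deleted sparse_export (zero-preserving unary ops, mul, add/mm, to_sparse, select, stack) are exactly the ad-hoc mechanism this commit retires.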
