Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mpact] bump torch-mlir and adjust mpact #70

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions python/mpact/models/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@ def forward(self, x, v):


class MMNet(torch.nn.Module):
def forward(self, x, v):
return torch.mm(x, v)
def forward(self, x, y):
return torch.mm(x, y)


class AddNet(torch.nn.Module):
def forward(self, x, v):
return torch.add(x, v)
def forward(self, x, y):
return torch.add(x, y)


class MulNet(torch.nn.Module):
def forward(self, x, v):
return torch.mul(x, v)
def forward(self, x, y):
return torch.mul(x, y)


class SelfNet(torch.nn.Module):
Expand Down
154 changes: 6 additions & 148 deletions python/mpact/mpactbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from mpact.dialects import torch as torch_d
from mpact.execution_engine import *
from mpact.extras.fx_decomp_util import get_decomposition_table
from mpact.extras.fx_importer import FxImporter, SparsityMeta
from mpact.extras.fx_importer import FxImporter
from mpact.ir import *
from mpact.passmanager import *
from mpact.runtime import *
Expand Down Expand Up @@ -124,14 +124,6 @@ def assert_arg_type_is_supported(ty):

CONSUME_RETURN_FUNC_PREFIX = "refbackend_consume_func_return_"

SPARSE_LAYOUTS = [
torch.sparse_coo,
torch.sparse_csr,
torch.sparse_csc,
torch.sparse_bsr,
torch.sparse_bsc,
]


def get_return_funcs(module):
return_prefix_len = len(CONSUME_RETURN_FUNC_PREFIX)
Expand Down Expand Up @@ -314,149 +306,15 @@ def load(self, module: MpactCompiledArtifact) -> MpactBackendInvoker:
return MpactBackendInvoker(module, self.opt_level)


def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
"""
Returns a meta data tuple for the given sparse tensor.

NOTE: this will be fully replaced by fx graph SparseTensorMetadata
"""
sparse_dim = a.sparse_dim()
dense_dim = a.dense_dim()
batch_dim = a.ndim - dense_dim - sparse_dim
blocksize = None
if a.layout is torch.sparse_coo:
return SparsityMeta(
a.layout,
batch_dim,
sparse_dim,
dense_dim,
blocksize,
a._indices().dtype,
a._indices().dtype,
)
elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
if a.layout is torch.sparse_bsr:
blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
return SparsityMeta(
a.layout,
batch_dim,
sparse_dim,
dense_dim,
blocksize,
a.crow_indices().dtype,
a.col_indices().dtype,
)
elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
if a.layout is torch.sparse_bsc:
blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
return SparsityMeta(
a.layout,
batch_dim,
sparse_dim,
dense_dim,
blocksize,
a.ccol_indices().dtype,
a.row_indices().dtype,
)
else:
raise RuntimeError(f"Unsupported sparse layout for {a}")


def sparse_arg(args, i):
if isinstance(args[i], torch.fx.node.Node):
return args[i].meta.get("sparsity", None)
return None


def sparse_export(
f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
) -> torch.export.ExportedProgram:
"""
This is a ***temporary*** wrapper around `torch.export.export`
that eventually should be removed and simply replaced by the
standard API for exporting traced graphs.

But until issue

https://github.com/pytorch/pytorch/pull/117907

is addressed, this wrapper provides support for the sparse
tensor types by first converting all operands to dense tensors,
building the traced graph as for the dense case, then annotating
sparse parameters with their actual sparse layout attributes,
followed by some simple propagation rules. This temporary solution
accelerates testing torch-mlir with PyTorch sparse tensors until
the issue is resolved upstream.
"""
# Convert all arguments to dense.
dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
mask = [a.layout in SPARSE_LAYOUTS for a in args]
# Build the regular FX traced graph with only dense arguments
# (the current version would crash otherwise, see issue above).
prog = torch.export.export(f, dargs, kwargs)
decomposition_table = get_decomposition_table()
if decomposition_table:
prog = prog.run_decompositions(decomposition_table)
# Annotate sparse arguments in the graph and apply some very
# basic propagation rules for sparsity.
specs = prog.graph_signature.input_specs
alen = len(specs)
k = 0
for i, node in enumerate(prog.graph.nodes):
if node.op == "placeholder":
# Argument.
spec = specs[i]
if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
if mask[k]:
node.meta["sparsity"] = sparse_metadata(args[k])
k = k + 1
elif node.op == "call_function":
opname = node.target._schema.name.split("::")[1]
# Zero preserving elt-wise unary op.
if opname in {"abs", "neg", "relu", "sin"}:
node.meta["sparsity"] = sparse_arg(node.args, 0)
# Some simplistic rules for preserving sparsity. Soon
# to be replaced by proper FX graph propagation.
elif opname in {"mul"}:
m0 = sparse_arg(node.args, 0)
m1 = sparse_arg(node.args, 1)
if m0 is not None:
node.meta["sparsity"] = m0
elif m1 is not None:
node.meta["sparsity"] = m1
elif opname in {"add", "mm"}:
m0 = sparse_arg(node.args, 0)
m1 = sparse_arg(node.args, 1)
if m0 is not None and m1 is not None:
node.meta["sparsity"] = m0
elif opname == "_to_sparse" or opname == "to_sparse":
dim = len(node.meta.get("val").shape)
node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
)
# TODO: Uncomment this to hack sparsity into the network.
# elif opname == "_to_dense" or opname == "to_dense":
# # hack (assumes we never really want the to_dense for now)
# node.meta["sparsity"] = sparse_arg(node.args, 0)
elif opname == "select" and sparse_arg(node.args, 0):
dim = len(node.meta.get("val").shape)
node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
)
elif opname == "stack" and sparse_arg(node.args[0], 0):
dim = len(node.meta.get("val").shape)
node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
)
return prog


def export_and_import(f, *args, **kwargs):
"""This method implements Stella's importer, stripped down to essentials."""
"""A FX graph importer, stripped down to essentials."""
context = ir.Context()
torch_d.register_dialect(context)
fx_importer = FxImporter(context=context)
prog = sparse_export(f, args, kwargs)
prog = torch.export.export(f, args, kwargs)
decomposition_table = get_decomposition_table()
if decomposition_table:
prog = prog.run_decompositions(decomposition_table)
fx_importer.import_frozen_program(prog)
return fx_importer.module

Expand Down
25 changes: 13 additions & 12 deletions test/python/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ def print_sparse(res):
# CHECK: [24. 26. 28. 30.]
# CHECK: [32. 34. 36. 38.]
# CHECK: [40. 42. 44. 46.]{{\]}}
# CHECK: {{\[}}[16. 18. 18. 19.]
# CHECK: [20. 21. 22. 25.]
# CHECK: [24. 25. 26. 27.]
# CHECK: [31. 29. 30. 31.]{{\]}}
# CHECK: {{\[}}[ 0. 2. 2. 3.]
# CHECK: [ 4. 5. 6. 9.]
# CHECK: [ 8. 9. 10. 11.]
# CHECK: [15. 13. 14. 15.]{{\]}}
# CH_ECK: {{\[}}[16. 18. 18. 19.]
# CH_ECK: [20. 21. 22. 25.]
# CH_ECK: [24. 25. 26. 27.]
# CH_ECK: [31. 29. 30. 31.]{{\]}}
# CH_ECK: {{\[}}[ 0. 2. 2. 3.]
# CH_ECK: [ 4. 5. 6. 9.]
# CH_ECK: [ 8. 9. 10. 11.]
# CH_ECK: [15. 13. 14. 15.]{{\]}}
# CHECK: [0 1 2 2 3]
# CHECK: [1 3 0]
# CHECK: [2. 4. 6.]
Expand All @@ -81,9 +81,10 @@ def print_sparse(res):
print("mpact")
res = mpact_jit(net, X, Y)
print(res)
res = mpact_jit(net, S, Y)
print(res)
res = mpact_jit(net, X, S)
print(res)
# TODO: fix in pydev
# res = mpact_jit(net, S, Y)
# print(res)
# res = mpact_jit(net, X, S)
# print(res)
res = mpact_jit(net, S, S)
print_sparse(res)
Loading