Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix implementation of index_put, validate it with dort | fix(torchlib) #1277

Merged
merged 14 commits into from
Mar 13, 2024
Merged
22 changes: 7 additions & 15 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4046,25 +4046,17 @@ def aten_index_put(
) -> TReal:
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""

index = op.SequenceAt(indices, 0) # assume indices only have 1 element
# change array([1,3]) to array([[1,1,1,1,1],[3,3,3,3,3]])
self_dim_1 = op.Gather(op.Shape(self), 1)
index_dim_0 = op.Gather(op.Shape(index), 0)
neg_1 = op.Constant(value_ints=[-1])
shape = op.Concat(op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0)
new_ind = op.Expand(index, shape)
new_ind_t = op.Transpose(new_ind)
index = op.SequenceAt(indices, 0)
xadupre marked this conversation as resolved.
Show resolved Hide resolved
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
new_index = op.Unsqueeze(index, -1)
xadupre marked this conversation as resolved.
Show resolved Hide resolved
shape_self = op.Shape(self)

if op.Cast(accumulate, to=BOOL.dtype):
# put values into zeros array first, then add to input
zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
zeros = op.CastLike(zeros, values)
result = op.ScatterElements(zeros, new_ind_t, values)
# FIXME: type promotion
result = op.CastLike(result, self)
zeros = op.CastLike(op.ConstantOfShape(shape_self), values)
result = op.ScatterND(zeros, new_index, values, reduction="add")
result = op.Add(result, self)
else:
result = op.ScatterElements(self, new_ind_t, values)
result = op.ScatterND(self, new_index, values)

return result


Expand Down
177 changes: 177 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/dynamo_export_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import copy
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
import inspect
import itertools
import sys
import unittest

import torch
from torch.onnx import ExportOptions
from torch.onnx import _OrtBackend as OrtBackend
from torch.onnx import _OrtBackendOptions as OrtBackendOptions


class FuncModule(torch.nn.Module):
    """Expose a plain callable as a ``torch.nn.Module`` that owns parameters.

    The module always holds one scalar parameter, ``ppp``, which is added to
    the first call argument. Any extra ``params`` are stored in a
    ``ParameterList`` and appended to the positional arguments passed to ``f``.
    """

    def __init__(self, f, params=None):
        super().__init__()
        self.f = f
        self.ppp = torch.nn.Parameter(torch.Tensor([1]))
        extra_params = () if params is None else params
        self.params = torch.nn.ParameterList(list(extra_params))

    def forward(self, *args):
        """Call ``f`` on ``args`` followed by the stored parameters.

        The first argument is shifted by ``ppp`` before the call.
        """
        call_args = [*itertools.chain(args, self.params)]
        shifted_first = call_args[0] + self.ppp
        return self.f(shifted_first, *call_args[1:])


class FuncModuleModule(torch.nn.Module):
    """Wrap an existing module and add a scalar parameter to its first input.

    The wrapped object is stored under both ``f`` and ``mod``; ``forward``
    calls it through ``mod``.
    """

    def __init__(self, f):
        super().__init__()
        self.f = f
        self.mod = f
        self.ppp = torch.nn.Parameter(torch.Tensor([1]))

    def forward(self, *args):
        """Call the wrapped module with ``args[0] + ppp`` and the remaining args."""
        head, tail = args[0], args[1:]
        return self.mod(head + self.ppp, *tail)


def make_aot_ort(dynamic: bool = False):
    """Build an ONNX Runtime backend usable with ``torch.compile``.

    Args:
        dynamic: Forwarded to ``ExportOptions(dynamic_shapes=...)``.

    Returns:
        A 2-tuple holding the same backend instance twice.
    """
    export_options = ExportOptions(dynamic_shapes=dynamic)
    backend_options = OrtBackendOptions(export_options=export_options)
    backend = OrtBackend(options=backend_options)
    return backend, backend


class TestOperatorsOnnxrt(unittest.TestCase):
    """Compare eager PyTorch results with ``torch.compile`` on an ONNX Runtime backend."""

    def setUp(self):
        super().setUp()
        # Reset TorchDynamo so every test starts from a clean state.
        torch._dynamo.reset()  # pylint: disable=protected-access

    def _assert_outputs_close(self, baseline_result, result, atol, rtol):
        """Check that the eager and compiled outputs match.

        When the eager output is a tuple, only the first element of each
        output is compared.

        Returns:
            The ``(baseline, compiled)`` tensors that were compared.

        Raises:
            AssertionError: If the eager output is not a tensor, or if the
                two tensors differ beyond the tolerances.
        """
        if isinstance(baseline_result, tuple):
            baseline_result = baseline_result[0]
            result = result[0]
        if not isinstance(baseline_result, torch.Tensor):
            raise AssertionError(f"Unexpected type {type(baseline_result)}.")
        torch.testing.assert_close(
            baseline_result, result, atol=atol, rtol=rtol, equal_nan=True
        )
        return baseline_result, result

    def assertONNX(
        self,
        f,
        args,
        onnx_export: str,
        params=None,
        fullgraph: bool = True,
        atol=1e-6,
        rtol=1e-6,
        opset_version=None,
        test_backward=True,
    ):
        """Run ``f`` eagerly and compiled, then assert that both agree.

        Args:
            f: A callable or a ``torch.nn.Module``. It is wrapped so that the
                resulting model owns at least one parameter.
            args: A single tensor or a sequence of positional inputs.
            onnx_export: A name for the case; only its type is validated here.
            params: Extra parameters passed to ``FuncModule`` when ``f`` is
                not a module.
            fullgraph: Forwarded to ``torch.compile``.
            atol: Absolute tolerance for the comparisons.
            rtol: Relative tolerance for the comparisons.
            opset_version: Must be ``None``; only the default opset is supported.
            test_backward: When true, parameter gradients are compared as well
                as the forward outputs.

        Raises:
            unittest.SkipTest: On Windows.
            AssertionError: On invalid arguments, an unexpected output type,
                or mismatching outputs or gradients.
        """
        if sys.platform == "win32":
            raise unittest.SkipTest("Windows not supported yet.")
        assert isinstance(onnx_export, str), f"Export onnx is wrong for f={f}"
        assert opset_version is None, f"opset={opset_version}, only default opset is supported"
        if isinstance(args, torch.Tensor):
            args = [args]
        if params is None:
            params = ()
        if isinstance(f, torch.nn.Module):
            model = FuncModuleModule(f)
        else:
            model = FuncModule(f, params)
        model.eval()

        if test_backward:
            # forward/backward
            local_aot_ort, _ = make_aot_ort(dynamic=False)

            # Compile a deep copy so the eager and compiled models accumulate
            # gradients on separate parameter objects.
            compiled_model = torch.compile(
                copy.deepcopy(model),
                backend=local_aot_ort,
                dynamic=False,
                fullgraph=fullgraph,
            )

            baseline_result = model(*args)
            result = compiled_model(*args)
            baseline_result, result = self._assert_outputs_close(
                baseline_result, result, atol, rtol
            )

            baseline_result.sum().backward()
            result.sum().backward()

            l1 = list(model.parameters())
            l2 = list(compiled_model.parameters())
            self.assertEqual(len(l1), len(l2))
            assert len(l1) > 0, "No gradient to test"
            for baseline_param, param in zip(l1, l2):
                torch.testing.assert_close(
                    baseline_param.grad,
                    param.grad,
                    atol=atol,
                    rtol=rtol,
                    equal_nan=True,
                )
        else:
            # forward only
            compiled_model = torch.compile(
                copy.deepcopy(model),
                backend="onnxrt",
                dynamic=False,
                fullgraph=fullgraph,
            )

            baseline_result = model(*args)
            result = compiled_model(*args)
            # Use the same checks as the backward path so that a non-tensor
            # output fails loudly instead of passing without any comparison.
            self._assert_outputs_close(baseline_result, result, atol, rtol)

    def test_add(self):
        x = torch.zeros((10, 3), requires_grad=True, dtype=torch.float32)
        self.assertONNX(lambda x: x + x, x, onnx_export=inspect.currentframe().f_code.co_name)

    def test_index_put(self):
        x = torch.zeros((10, 3), requires_grad=True, dtype=torch.float32)
        indices = torch.arange(8, dtype=torch.int64).reshape((-1, 4))
        values = torch.arange(24, dtype=torch.float32).reshape((-1, 4, 3))

        # redundant test to make sure this expression is valid for torch
        assert x.index_put((indices,), values) is not None

        self.assertONNX(
            lambda x, indices, values: x.index_put((indices,), values),
            (x, indices, values),
            onnx_export=inspect.currentframe().f_code.co_name,
        )

    def test_index_put_accumulate(self):
        x = torch.zeros((10, 3), requires_grad=True, dtype=torch.float32)
        indices = torch.arange(8, dtype=torch.int64).reshape((-1, 4))
        values = torch.arange(24, dtype=torch.float32).reshape((-1, 4, 3))

        # redundant test to make sure this expression is valid for torch
        assert x.index_put((indices,), values, accumulate=True) is not None

        self.assertONNX(
            lambda x, indices, values: x.index_put((indices,), values, accumulate=True),
            (x, indices, values),
            onnx_export=inspect.currentframe().f_code.co_name,
        )
justinchuby marked this conversation as resolved.
Show resolved Hide resolved


# Run the test suite with verbose output when this file is executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
10 changes: 7 additions & 3 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,14 +836,18 @@ def _where_input_wrangler(
core_ops.aten_index_put_bool,
).skip(
matcher=lambda sample: not (sample.args[0][0].dtype == torch.bool),
reason="this Aten overload only support tensor(bool) as args",
reason="this Aten overload only support tensor(bool) as indices",
),
TorchLibOpInfo(
"index_put",
core_ops.aten_index_put,
).skip(
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
matcher=lambda sample: not (sample.args[0][0].dtype == torch.int64),
reason="this Aten overload only support tensor(int) as args",
matcher=lambda sample: not (
(sample.args[0][0].dtype == torch.int64)
# onnxruntime: MLFloat16 data type is not supported with ScatterND when reduction is 'add'
and (sample.args[1].dtype != torch.float16 or not sample.kwargs.get("accumulate", False))
),
reason="this Aten overload only support tensor(int) as indices or float16 when reduction is 'add'",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested this matcher with the original code and it shows no difference in results (still 4 passes and 20 skipped). It seems that you found a corner case that is not in the current op_DB_test? If that's the case, we usually use https://github.com/microsoft/onnxscript/blob/main/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py to create corner cases to test the implementation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did. I hope I did it right.

),
TorchLibOpInfo("index_select", core_ops.aten_index_select),
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
TorchLibOpInfo("isclose", core_ops.aten_isclose),
Expand Down
Loading