
Fix implementation of index_put, validate it with dort | fix(torchlib) #1277

Merged 14 commits on Mar 13, 2024
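
For context, the new test file below validates `torch.Tensor.index_put` through the dynamo/ORT export path. A minimal sketch of the eager semantics being exercised (a standalone snippet for illustration, no export involved):

import torch

x = torch.zeros((10, 3), dtype=torch.float32)
indices = torch.arange(8, dtype=torch.int64).reshape((-1, 4))       # shape (2, 4)
values = torch.arange(24, dtype=torch.float32).reshape((-1, 4, 3))  # shape (2, 4, 3)

# With a single index tensor, index_put selects along dim 0, so this is
# the out-of-place equivalent of the assignment x[indices] = values.
y = x.index_put((indices,), values)
assert y.shape == (10, 3)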
170 changes: 170 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/dynamo_export_test.py
@@ -0,0 +1,170 @@
import copy
import inspect
import itertools
import sys
import unittest

import torch
from torch.onnx import ExportOptions
from torch.onnx import _OrtBackend as OrtBackend
from torch.onnx import _OrtBackendOptions as OrtBackendOptions


class FuncModule(torch.nn.Module):
def __init__(self, f, params=None):
if params is None:
params = ()
super().__init__()
self.f = f
self.ppp = torch.nn.Parameter(torch.Tensor([1]))
self.params = torch.nn.ParameterList(list(params))

def forward(self, *args):
f_args = list(itertools.chain(args, self.params))
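        # Folding the trainable parameter `ppp` into the first input guarantees
        # that every wrapped function has at least one gradient to verify.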
f_args[0] = f_args[0] + self.ppp
res = self.f(*f_args)
return res


class FuncModuleModule(torch.nn.Module):
def __init__(self, f):
super().__init__()
self.f = f
self.mod = f
self.ppp = torch.nn.Parameter(torch.Tensor([1]))

def forward(self, *args):
x = args[0] + self.ppp
res = self.mod(x, *args[1:])
return res


def make_aot_ort(dynamic: bool = False):
    """Returns the ORT backend twice, matching the (aot_ort, ort) pair
    unpacked by the callers in this test."""
    ort_backend = OrtBackend(
        options=OrtBackendOptions(
            export_options=ExportOptions(
                dynamic_shapes=dynamic,
            )
        )
    )
    return ort_backend, ort_backend


class TestOperatorsOnnxrt(unittest.TestCase):
def setUp(self):
super().setUp()
torch._dynamo.reset()

    def assertONNX(
        self,
        f,
        args,
        onnx_export: str,
        params=None,
        fullgraph: bool = True,
        atol=1e-6,
        rtol=1e-6,
        opset_version=None,
        test_backward=True,
        operator_export_type=None,
        impl="ort",
        # The arguments below mirror the legacy exporter test API and are
        # accepted but currently unused by this helper.
        input_names=None,
        dynamic_axes=None,
        keep_initializers_as_inputs=None,
        training=None,
    ):
if sys.platform == "win32":
raise unittest.SkipTest("Windows not supported yet.")
assert isinstance(onnx_export, str), f"Export onnx is wrong for f={f}"
if isinstance(args, torch.Tensor):
args = [args]
if params is None:
params = ()
if isinstance(f, torch.nn.Module):
model = FuncModuleModule(f)
else:
model = FuncModule(f, params)
model.eval()

if test_backward:
# forward/backward
local_aot_ort, local_ort = make_aot_ort(dynamic=False)

compiled_model = torch.compile(
copy.deepcopy(model),
backend=local_aot_ort,
dynamic=False,
fullgraph=fullgraph,
)

baseline_result = model(*args)
result = compiled_model(*args)

if isinstance(baseline_result, tuple):
baseline_result = baseline_result[0]
result = result[0]
if isinstance(baseline_result, torch.Tensor):
torch.testing.assert_close(
baseline_result, result, atol=atol, rtol=rtol, equal_nan=True
)

baseline_result.sum().backward()
result.sum().backward()

l1 = list(model.parameters())
l2 = list(compiled_model.parameters())
self.assertEqual(len(l1), len(l2))
assert len(l1) > 0, "No gradient to test"
n_gradient = 0
for baseline_param, param in zip(l1, l2):
n_gradient += 1
torch.testing.assert_close(
baseline_param.grad,
param.grad,
atol=atol,
rtol=rtol,
equal_nan=True,
)
assert n_gradient > 0, "No gradient was checked"
else:
raise AssertionError(f"Unexpected type {type(baseline_result)}.")
else:
# forward only
compiled_model = torch.compile(
copy.deepcopy(model),
backend="onnxrt",
dynamic=False,
fullgraph=fullgraph,
)

baseline_result = model(*args)
result = compiled_model(*args)

if isinstance(baseline_result, torch.Tensor):
torch.testing.assert_close(
baseline_result, result, atol=atol, rtol=rtol, equal_nan=True
)

def test_add(self):
x = torch.zeros((10, 3), requires_grad=True, dtype=torch.float32)
self.assertONNX(lambda x: x + x, x, onnx_export=inspect.currentframe().f_code.co_name)

def test_index_put(self):
        x = torch.zeros((10, 3), requires_grad=True, dtype=torch.float32)
        indices = torch.arange(8, dtype=torch.int64).reshape((-1, 4))  # shape (2, 4)
        values = torch.arange(24, dtype=torch.float32).reshape((-1, 4, 3))  # shape (2, 4, 3)

        # Redundant check to make sure this expression is valid for torch.
        assert x.index_put((indices,), values) is not None

self.assertONNX(
lambda x, indices, values: x.index_put((indices,), values),
(x, indices, values),
onnx_export=inspect.currentframe().f_code.co_name,
)


if __name__ == "__main__":
unittest.main(verbosity=2)
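
Note: the file can also be executed directly thanks to the `unittest.main` entry point (this assumes onnxruntime and a torch build exposing the private `_OrtBackend` API are installed).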