Add fake impl for aten.unique_dim (#126561)
Follow-up to #113118 and #124306.

Developed in coordination with the solution to microsoft/onnxscript#1547.

This PR adds the missing fake tensor implementation for `aten.unique_dim`, thus enabling tracing and compilation of `torch.unique` when `dim` is not `None`.
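
With the fake impl in place, `torch.compile` can capture `torch.unique(..., dim=...)` once dynamic output shapes are enabled. A minimal sketch, not part of this PR (the config flag below is Dynamo's standard switch for data-dependent output shapes; printed shapes are illustrative):

```python
import torch

# The number of unique rows is data-dependent, so Dynamo must be allowed
# to allocate an unbacked symbolic size for the output.
torch._dynamo.config.capture_dynamic_output_shape_ops = True

@torch.compile
def dedup_rows(x):
    return torch.unique(x, dim=0, return_inverse=True)

values, inverse = dedup_rows(torch.randint(0, 3, (8, 4)))
print(values.shape, inverse.shape)  # e.g. torch.Size([5, 4]) torch.Size([8])
```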

Local testing was done with the following simple script (it assumes the changes from microsoft/onnxscript#1547 have been checked out):

```python
import logging

import numpy as np
import onnx
import onnxruntime as ort
import torch

onnx_program = torch.onnx.dynamo_export(
    lambda x: torch.unique(x, dim=0, return_inverse=True),
    torch.arange(10),
    export_options=torch.onnx.ExportOptions(
        dynamic_shapes=True,
        diagnostic_options=torch.onnx.DiagnosticOptions(
            verbosity_level=logging.DEBUG
        ),
    ),
)
onnx_program.save("torch_unique.onnx")
onnx_inputs = onnx_program.adapt_torch_inputs_to_onnx(torch.arange(10))
onnx_outputs = onnx_program(*onnx_inputs)
loaded_onnx_program = onnx.load("torch_unique.onnx")
onnx.checker.check_model(loaded_onnx_program)
ort_session = ort.InferenceSession("torch_unique.onnx")
inputs = np.random.randint(0, 10, 10)
print(f"Inputs: {inputs}")
outputs = ort_session.run(None, {"l_x_": inputs})
print(f"Outputs: {outputs}")
print("Success")
```

Co-authored-by: Edward Z. Yang <[email protected]>
Pull Request resolved: #126561
Approved by: https://github.com/ezyang
a-gardner1 authored and pytorchmergebot committed Jun 1, 2024
1 parent 25447ba commit 3c1cf03
Showing 3 changed files with 47 additions and 19 deletions.
test/test_ops.py (2 additions, 2 deletions)
```diff
@@ -2522,8 +2522,8 @@ def map_to_fake(e):
                     or name in sometimes_dynamic_output_op_test
                 )
                 self.assertTrue(
-                    mode.shape_env is None
-                    or not mode.shape_env.allow_dynamic_output_shape_ops
+                    fake_mode.shape_env is None
+                    or not fake_mode.shape_env.allow_dynamic_output_shape_ops
                     or name not in supported_dynamic_output_op_tests
                 )
             except torch._subclasses.fake_tensor.DataDependentOutputException:
```
test/test_proxy_tensor.py (0 additions, 3 deletions)
```diff
@@ -2003,7 +2003,6 @@ def f(t):
     xfail('nn.functional.ctc_loss'),  # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
     xfail('quantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
     xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
-    xfail('unique', ''),  # aten._unique2.default - couldn't find symbolic meta function/decomposition
 
     xfail('max_pool2d_with_indices_backward', ''),  # Expected a value of type 'List[int]' for argument 'kernel_size' but...
 
@@ -2034,8 +2033,6 @@ def f(t):
 inplace_symbolic_tensor_failures = {
     # bugs
     xfail('float_power', ''),  # base given to float_power_ has dtype Float but the operation's result requires dtype Double
-    # decomp not implemented
-    xfail('unique', ''),
 }
 
 out_symbolic_tensor_failures = {
```
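
With the fake impl available, symbolic tracing via `make_fx` now succeeds, which is why the `xfail('unique', '')` entries above can be removed. A minimal sketch of the newly passing path (input values are arbitrary):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.unique(x, dim=0)

# Previously this failed with "couldn't find symbolic meta function/decomposition";
# it now yields a graph whose output size along dim 0 is an unbacked SymInt.
gm = make_fx(f, tracing_mode="symbolic")(torch.randint(0, 3, (8, 4)))
print(gm.graph)
```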
torch/_subclasses/fake_impls.py (45 additions, 14 deletions)
```diff
@@ -258,9 +258,8 @@ def dyn_shape(fake_mode, func, *args, **kwargs):
     raise DynamicOutputShapeException(func)
 
 
-@register_op_impl(aten._unique2.default)
-def unique2(
-    fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
+def _unique(
+    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
 ):
     if (
         fake_mode.shape_env is None
@@ -269,7 +268,8 @@ def unique2(
         # Without symints/symfloats, cannot handle this
         raise DynamicOutputShapeException(func)
 
-    if (nnz := arg.unique_memo) is None:
+    # Do not use a memo for unique_dim
+    if dim is not None or (nnz := arg.unique_memo) is None:
         # Avoid importing sympy at a module level
         from torch.fx.experimental.symbolic_shapes import (
             _constrain_range_for_size,
@@ -291,28 +291,59 @@
 
         maxval = sys.maxsize - 1
 
-        if not has_free_symbols(arg.numel()):
-            maxval = int(arg.numel())
+        numel = arg.numel() if dim is None else arg.size(dim)
+        if not has_free_symbols(numel):
+            maxval = int(numel)
 
         _constrain_range_for_size(nnz, max=maxval)
 
-        arg.unique_memo = nnz
+        if dim is None:
+            arg.unique_memo = nnz
 
-    ret = [arg.new_empty((nnz,))]
+    if dim is None:
+        ret = [arg.new_empty((nnz,))]
+    else:
+        ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])]
 
-    if return_inverse:
-        ret.append(torch.empty_like(arg))
+    return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu")
+    if return_inverse or return_if_dim_and_cpu:
+        inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],))
     else:
-        ret.append(arg.new_empty(0))
+        inverse = arg.new_empty(0)
+    ret.append(inverse)
 
-    if return_counts:
-        ret.append(torch.empty_like(arg))
+    if return_counts or return_if_dim_and_cpu:
+        counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],))
     else:
-        ret.append(arg.new_empty(0))
+        counts = arg.new_empty(0)
+    ret.append(counts)
 
     return tuple(ret)
 
 
+@register_op_impl(aten._unique2.default)
+def unique2(
+    fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
+):
+    return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)
+
+
+@register_op_impl(aten.unique_dim.default)
+def unique_dim(
+    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
+):
+    return _unique(
+        fake_mode,
+        func,
+        arg,
+        # normalize dim to be non-negative
+        dim if dim >= 0 else dim % max(arg.ndim, 1),
+        sorted,
+        return_inverse,
+        return_counts,
+    )
+
+
 @register_op_impl(aten.repeat_interleave.Tensor)
 def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
     if output_size is None:
```
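
To make the shape logic concrete: under a `FakeTensorMode` backed by a `ShapeEnv`, the new impl allocates an unbacked SymInt for the number of unique slices along `dim`. A minimal sketch (the `u0` in the comments stands for that unbacked symbol; exact printing may vary):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# Without a ShapeEnv the impl raises DynamicOutputShapeException,
# per the guard at the top of _unique.
with FakeTensorMode(shape_env=ShapeEnv()):
    x = torch.empty(8, 4)
    values, inverse, counts = torch.ops.aten.unique_dim(
        x, dim=0, sorted=True, return_inverse=True, return_counts=True
    )
    print(values.shape)   # torch.Size([u0, 4])
    print(inverse.shape)  # torch.Size([8])
    print(counts.shape)   # torch.Size([u0])
```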
