diff --git a/test/test_ops.py b/test/test_ops.py
index 44f503ae9b6ed..cbec88136ed27 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -2522,8 +2522,8 @@ def map_to_fake(e):
                     or name in sometimes_dynamic_output_op_test
                 )
                 self.assertTrue(
-                    mode.shape_env is None
-                    or not mode.shape_env.allow_dynamic_output_shape_ops
+                    fake_mode.shape_env is None
+                    or not fake_mode.shape_env.allow_dynamic_output_shape_ops
                     or name not in supported_dynamic_output_op_tests
                 )
             except torch._subclasses.fake_tensor.DataDependentOutputException:
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index d8aa8863d5666..c7b2e51ced20e 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -2003,7 +2003,6 @@ def f(t):
     xfail('nn.functional.ctc_loss'),  # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
     xfail('quantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
     xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
-    xfail('unique', ''),  # aten._unique2.default - couldn't find symbolic meta function/decomposition
 
     xfail('max_pool2d_with_indices_backward', ''),  # Expected a value of type 'List[int]' for argument 'kernel_size' but...
 
@@ -2034,8 +2033,6 @@ def f(t):
 inplace_symbolic_tensor_failures = {
     # bugs
     xfail('float_power', ''),  # base given to float_power_ has dtype Float but the operation's result requires dtype Double
-    # decomp not implemented
-    xfail('unique', ''),
 }
 
 out_symbolic_tensor_failures = {
diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py
index 4376d24255ef7..2b1cf13cc9358 100644
--- a/torch/_subclasses/fake_impls.py
+++ b/torch/_subclasses/fake_impls.py
@@ -258,9 +258,8 @@ def dyn_shape(fake_mode, func, *args, **kwargs):
     raise DynamicOutputShapeException(func)
 
 
-@register_op_impl(aten._unique2.default)
-def unique2(
-    fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
+def _unique(
+    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
 ):
     if (
         fake_mode.shape_env is None
@@ -269,7 +268,8 @@ def unique2(
         # Without symints/symfloats, cannot handle this
         raise DynamicOutputShapeException(func)
 
-    if (nnz := arg.unique_memo) is None:
+    # Do not use a memo for unique_dim
+    if dim is not None or (nnz := arg.unique_memo) is None:
         # Avoid importing sympy at a module level
         from torch.fx.experimental.symbolic_shapes import (
             _constrain_range_for_size,
@@ -291,28 +291,59 @@ def unique2(
 
             maxval = sys.maxsize - 1
 
-            if not has_free_symbols(arg.numel()):
-                maxval = int(arg.numel())
+            numel = arg.numel() if dim is None else arg.size(dim)
+            if not has_free_symbols(numel):
+                maxval = int(numel)
 
             _constrain_range_for_size(nnz, max=maxval)
 
-        arg.unique_memo = nnz
+        if dim is None:
+            arg.unique_memo = nnz
 
-    ret = [arg.new_empty((nnz,))]
+    if dim is None:
+        ret = [arg.new_empty((nnz,))]
+    else:
+        ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])]
 
-    if return_inverse:
-        ret.append(torch.empty_like(arg))
+    return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu")
+    if return_inverse or return_if_dim_and_cpu:
+        inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],))
     else:
-        ret.append(arg.new_empty(0))
+        inverse = arg.new_empty(0)
+    ret.append(inverse)
 
-    if return_counts:
-        ret.append(torch.empty_like(arg))
+    if return_counts or return_if_dim_and_cpu:
+        counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],))
     else:
-        ret.append(arg.new_empty(0))
+        counts = arg.new_empty(0)
+    ret.append(counts)
 
     return tuple(ret)
 
 
+@register_op_impl(aten._unique2.default)
+def unique2(
+    fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
+):
+    return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)
+
+
+@register_op_impl(aten.unique_dim.default)
+def unique_dim(
+    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
+):
+    return _unique(
+        fake_mode,
+        func,
+        arg,
+        # normalize dim to be non-negative
+        dim if dim >= 0 else dim % max(arg.ndim, 1),
+        sorted,
+        return_inverse,
+        return_counts,
+    )
+
+
 @register_op_impl(aten.repeat_interleave.Tensor)
 def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
     if output_size is None:
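
A minimal sketch, not part of the patch, of how the new fake kernel can be exercised; it assumes a PyTorch build that contains this change. FakeTensorMode, ShapeEnv, and FakeTensorMode.from_tensor are existing APIs; the unbacked symbol name shown in the output (e.g. u0) is illustrative only.

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    # With a ShapeEnv attached (allow_dynamic_output_shape_ops defaults to True),
    # aten.unique_dim now yields an output whose size along `dim` is an unbacked
    # SymInt instead of raising DynamicOutputShapeException.
    mode = FakeTensorMode(shape_env=ShapeEnv())
    x = torch.randint(0, 3, (8, 4))
    with mode:
        fake_x = mode.from_tensor(x)
        values = torch.unique(fake_x, dim=0)  # dispatches to aten.unique_dim.default
        print(values.shape)  # e.g. torch.Size([u0, 4]); the size along dim 0 is symbolic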