diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 30e9b7d33..44c6c0a87 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -3974,8 +3974,6 @@ def aten_hspmm(mat1: TensorType, mat2: TensorType) -> TensorType:
 
 
 # Do not register hstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918
-
-
 def aten_hstack(tensors: Sequence[TTensor]) -> TTensor:
     """hstack(Tensor[] tensors) -> Tensor"""
 
@@ -7887,14 +7885,14 @@ def aten_stack(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrStr
     return op.ConcatFromSequence(tensors, axis=dim, new_axis=1)
 
 
-@torch_op("aten::std", trace_only=True)
+# std is decomposed by PyTorch
 def aten_std(self: TReal, unbiased: bool = True) -> TReal:
     """std(Tensor self, bool unbiased=True) -> Tensor"""
     var = _aten_var_onnx(self, correction=float(unbiased), keepdim=False)
     return op.Sqrt(var)
 
 
-@torch_op("aten::std.dim", trace_only=True)
+# std_dim is decomposed by PyTorch
 def aten_std_dim(
     self: TReal,
     dim: Sequence[int],
@@ -7907,7 +7905,7 @@ def aten_std_dim(
     return op.Sqrt(var)
 
 
-@torch_op("aten::var.correction", trace_only=True)
+# std_correction is decomposed by PyTorch
 def aten_std_correction(
     self: TReal,
     # FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -7927,7 +7925,7 @@ def aten_std_correction(
     return op.Sqrt(var)
 
 
-@torch_op("aten::std_mean", trace_only=True)
+# std_mean is decomposed by PyTorch
 def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
     """std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""
 
@@ -7937,7 +7935,7 @@ def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
     return op.Sqrt(var), mean
 
 
-@torch_op("aten::std_mean.dim", trace_only=True)
+# std_mean_dim is decomposed by PyTorch
 def aten_std_mean_dim(
     self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False
 ) -> Tuple[TReal, TReal]:
@@ -7951,7 +7949,7 @@ def aten_std_mean_dim(
     return op.Sqrt(var), mean
 
 
-@torch_op("aten::std_mean.correction", trace_only=True)
+# std_mean_correction is decomposed by PyTorch
 def aten_std_mean_correction(
     self: TReal,
     # FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -7973,139 +7971,6 @@ def aten_std_mean_correction(
     return op.Sqrt(var), mean
 
 
-@torch_op("aten::stft", private=True)
-def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]:
-    signal_rank = Rank(self)
-    if signal_rank == 1:
-        # Add a batch dimension
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
-    return op.Identity(self), signal_rank
-
-
-@torch_op("aten::stft", private=True)
-def _center_window_around_zeros_if_needed(
-    window: TFloatOrBFloat16, n_fft: int
-) -> TFloatOrBFloat16:
-    # first dimension
-    n_win = op.Shape(window, start=0, end=1)
-    # Center window around zeros if needed (required by ONNX's STFT)
-    if n_win < n_fft:
-        left = (n_fft - n_win) / 2
-
-        right = n_fft - left - n_win
-        left = op.Reshape(left, op.Constant(value_ints=[1]))
-        right = op.Reshape(right, op.Constant(value_ints=[1]))
-
-        left_win = op.Expand(op.Constant(value_ints=[0]), left)
-        right_win = op.Expand(op.Constant(value_ints=[0]), right)
-        right_win = op.CastLike(right_win, window)
-        left_win = op.CastLike(left_win, window)
-        window = op.Concat(left_win, window, right_win, axis=0)
-    return window
-
-
-@torch_op("aten::stft", private=True)
-def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16:
-    left = (n_fft - win_length) / 2
-
-    right = n_fft - left - win_length
-    left = op.Reshape(left, op.Constant(value_ints=[1]))
-    right = op.Reshape(right, op.Constant(value_ints=[1]))
-    win_length = op.Reshape(win_length, op.Constant(value_ints=[1]))
-
-    left_win = op.Expand(op.Constant(value_ints=[0]), left)
-    right_win = op.Expand(op.Constant(value_ints=[0]), right)
-    window_list = op.Expand(op.Constant(value_ints=[1]), win_length)
-    return op.Concat(left_win, window_list, right_win, axis=0)
-
-
-@torch_op("aten::stft", private=True)
-def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16:
-    n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
-    window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor)
-    return window
-
-
-@torch_op("aten::stft", private=True)
-def _normalize_fft_result(
-    signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int
-) -> TFloatOrBFloat16:
-    n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
-    sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
-    result = result / sqrt_nfft
-    return result
-
-
-@torch_op("aten::stft", private=True)
-def _aten_stft_onnx(
-    signal: TFloatOrBFloat16,
-    frame_step_const: INT64,
-    window: Union[TFloatOrBFloat16, INT64],
-    frame_length_const: INT64,
-    signal_rank: INT64,
-    onesided: int,
-) -> TFloatOrBFloat16:
-    window = op.CastLike(window, signal)
-    result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided)
-    result = op.Transpose(result, perm=[0, 2, 1, 3])
-    # Remove batch dimension, if needed
-    if signal_rank == 1:
-        result = op.Squeeze(result, op.Constant(value_ints=[0]))
-    return result
-
-
-@torch_op("aten::stft", trace_only=True)
-def aten_stft(
-    self: TFloatOrBFloat16,
-    n_fft: int,
-    hop_length: Optional[int] = None,
-    win_length: Optional[int] = None,
-    window: Optional[TFloatOrBFloat16] = None,
-    normalized: bool = False,
-    onesided: Optional[bool] = None,
-    return_complex: Optional[bool] = None,
-) -> TFloatOrBFloat16:
-    """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
-
-    # NOTE: regarless of the value of return_complex, we always return a real representation.
-    del return_complex
-
-    # Get STFT sizes
-    if hop_length is None:
-        # core dump
-        # hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
-        hop_length = n_fft // 4
-    frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
-    frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))
-
-    # Pre-process input if needed
-    self, signal_rank = _add_batch_dimension(self)
-
-    # Get window and make sure it's the same size as `win_length` or `n_fft`
-    if window is not None and window.shape[0] is not None:
-        window = _center_window_around_zeros_if_needed(window, n_fft)
-    elif window is None:
-        if win_length is not None:
-            window = _create_window_from_win_length(win_length, n_fft)
-        else:
-            window = _create_window_from_n_fft(n_fft)
-
-    if onesided is None or onesided:
-        onesided = 1
-    else:
-        onesided = 0
-    # remove batch dimension included
-    result = _aten_stft_onnx(
-        self, frame_step_const, window, frame_length_const, signal_rank, onesided
-    )
-
-    # Normalize, if needed
-    if normalized:
-        result = _normalize_fft_result(self, result, n_fft)
-
-    return result
-
-
 @torch_op(
     (
         "aten::sub.Tensor",
@@ -8738,7 +8603,7 @@ def aten_vander(
     raise NotImplementedError()
 
 
-@torch_op("aten::var", trace_only=True)
+# var is decomposed by PyTorch
 def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal:
     """var(Tensor self, bool unbiased=True) -> Tensor"""
 
@@ -8747,7 +8612,7 @@ def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal:
     return _aten_var_onnx(self, correction=float(unbiased), keepdim=False)
 
 
-@torch_op("aten::var.dim", trace_only=True)
+# var is decomposed by PyTorch
 def aten_var_dim(
     self: TReal,
     dim: Sequence[int],
@@ -8759,7 +8624,7 @@ def aten_var_dim(
     return _aten_var_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim)
 
 
-@torch_op("aten::var.correction", trace_only=True)
+# var is decomposed by PyTorch
 def aten_var_correction(
     self: TReal,
     # FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -8779,7 +8644,7 @@ def aten_var_correction(
     return var
 
 
-@torch_op("aten::var", private=True, traceable=True)
+# var is decomposed by PyTorch
 def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TReal:
     mean = op.ReduceMean(self, keepdims=keepdim)
     sub_mean = op.Sub(self, mean)
@@ -8796,7 +8661,7 @@ def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TRe
     return var
 
 
-@torch_op("aten::var.dim", private=True, traceable=True)
+# var is decomposed by PyTorch
 def _aten_var_dim_onnx(
     self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False
 ) -> TReal:
@@ -8817,7 +8682,7 @@ def _aten_var_dim_onnx(
     return var
 
 
-@torch_op("aten::var_mean", trace_only=True)
+# var_mean is decomposed by PyTorch
 def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
     """var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""
 
@@ -8826,7 +8691,7 @@ def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
     return _aten_var_mean_onnx(self, correction=float(unbiased), keepdim=False)
 
 
-@torch_op("aten::var_mean.dim", trace_only=True)
+# var_mean is decomposed by PyTorch
 def aten_var_mean_dim(
     self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False
 ) -> Tuple[TReal, TReal]:
@@ -8837,7 +8702,7 @@ def aten_var_mean_dim(
     return _aten_var_mean_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim)
 
 
-@torch_op("aten::var_mean.correction", trace_only=True)
+# var_mean is decomposed by PyTorch
 def aten_var_mean_correction(
     self: TReal,
     # FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -8859,7 +8724,7 @@ def aten_var_mean_correction(
     return var, mean
 
 
-@torch_op("aten::var_mean", private=True)
+# var_mean is decomposed by PyTorch
 def _aten_var_mean_onnx(
     self: TReal, correction: float = 1.0, keepdim: bool = False
 ) -> Tuple[TReal, TReal]:
@@ -8879,7 +8744,7 @@ def _aten_var_mean_onnx(
     return var, mean
 
 
-@torch_op("aten::var_mean.dim", private=True)
+# var_mean is decomposed by PyTorch
 def _aten_var_mean_dim_onnx(
     self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False
 ) -> Tuple[TReal, TReal]:
@@ -8977,8 +8842,6 @@ def aten_view_copy(self: TTensor, size: IntType) -> TTensor:
 
 
 # Do not register vstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918
-
-
 def aten_vstack(tensors: Sequence[TTensor]) -> TTensor:
     """vstack(Tensor[] tensors) -> Tensor"""
 
@@ -8998,6 +8861,7 @@ def reshape_to_2d(tensor):
 
 @torch_op(
     (
+        "aten::where",
         "aten::where.Scalar",
        "aten::where.ScalarSelf",
         "aten::where.ScalarOther",
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 7a475c9ad..3fcb7802c 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -1497,33 +1497,6 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("stack", core_ops.aten_stack),
     TorchLibOpInfo("stack", core_ops.aten_stack_complex, complex=True),
-    TorchLibOpInfo(
-        "std_mean",
-        core_ops.aten_std_mean,
-    ).xfail(
-        # kwargs is empty
-        matcher=lambda sample: len(sample.kwargs) > 0,
-        reason="this Aten overload only support input[0]=tensor and input[1]=bool as input without any kwargs",
-    ),
-    TorchLibOpInfo(
-        "std_mean_dim",
-        core_ops.aten_std_mean_dim,
-    ).xfail(
-        # kwargs["dim"] must exist, kwargs["correction"] must not exist
-        matcher=lambda sample: not (
-            sample.kwargs.get("dim", None) is not None
-            and sample.kwargs.get("correction", None) is None
-        ),
-        reason="this Aten overload only support with 'dim' argument and without 'correction' argument",
-    ),
-    TorchLibOpInfo(
-        "std_mean_correction",
-        core_ops.aten_std_mean_correction,
-    ).skip(
-        # Don't accept input[1]=bool and 'correction' must be in kwargs
-        matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs,
-        reason="this Aten overload only support when correction attribute exists",
-    ),
     TorchLibOpInfo("sub", core_ops.aten_sub),
     TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True),
     # TorchLibOpInfo("sym_size", core_ops.aten_sym_size),  # no test case in OPS_DB
@@ -2183,41 +2156,6 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter),
     TorchLibOpInfo("slice", core_ops.aten_slice),
-    TorchLibOpInfo(
-        "ops.aten.stft",  # Custom from extra_opinfo
-        core_ops.aten_stft,
-        tolerance={torch.float32: (3.7e-5, 1.8e-4)},
-    ).xfail(
-        dtypes=(torch.float16,),
-        reason="RuntimeError: MKL FFT doesn't support tensors of type: Half",
-    ),
-    TorchLibOpInfo(
-        "std",
-        core_ops.aten_std,
-    ).xfail(
-        # kwargs must be empty
-        matcher=lambda sample: len(sample.kwargs) > 0,
-        reason="this Aten overload only support input[0]=tensor and input[1]=bool as input without any kwargs",
-    ),
-    TorchLibOpInfo(
-        "std_dim",
-        core_ops.aten_std_dim,
-    ).xfail(
-        # kwargs["dim"] must exist, kwargs["correction"] must not exist
-        matcher=lambda sample: not (
-            sample.kwargs.get("dim", None) is not None
-            and sample.kwargs.get("correction", None) is None
-        ),
-        reason="this Aten overload only support with 'dim' argument and without 'correction' argument",
-    ),
-    TorchLibOpInfo(
-        "std_correction",
-        core_ops.aten_std_correction,
-    ).skip(
-        # Don't accept input[1]=bool and 'correction' must be in kwargs
-        matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs,
-        reason="this Aten overload only support when correction attribute exists",
-    ),
     TorchLibOpInfo(
         "sum",
         core_ops.aten_sum_dim_IntList,
@@ -2238,60 +2176,6 @@ def _where_input_wrangler(
     ),  # Custom from extra_opinfo
     TorchLibOpInfo("transpose", core_ops.aten_transpose),
     TorchLibOpInfo("transpose", core_ops.aten_transpose_complex, complex=True),
-    TorchLibOpInfo(
-        "var_mean",
-        core_ops.aten_var_mean,
-    ).xfail(
-        # kwargs is empty
-        matcher=lambda sample: len(sample.kwargs) > 0,
-        reason="this Aten overload only support input[0]=tensor and input[1]=bool as input without any kwargs",
-    ),
-    TorchLibOpInfo(
-        "var_mean_dim",
-        core_ops.aten_var_mean_dim,
-    ).xfail(
-        # kwargs["dim"] must exist, kwargs["correction"] must not exist
-        matcher=lambda sample: not (
-            sample.kwargs.get("dim", None) is not None
-            and sample.kwargs.get("correction", None) is None
-        ),
-        reason="this Aten overload only support with 'dim' argument and without 'correction' argument",
-    ),
-    TorchLibOpInfo(
-        "var_mean_correction",
-        core_ops.aten_var_mean_correction,
-    ).skip(
-        # Don't accept input[1]=bool and 'correction' must be in kwargs
-        matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs,
-        reason="this Aten overload only support when correction attribute exists",
-    ),
-    TorchLibOpInfo(
-        "var",
-        core_ops.aten_var,
-    ).xfail(
-        # kwargs must be empty
-        matcher=lambda sample: len(sample.kwargs) > 0,
-        reason="this Aten overload only support input[0]=tensor and input[1]=bool as input without any kwargs",
-    ),
-    TorchLibOpInfo(
-        "var_dim",
-        core_ops.aten_var_dim,
-    ).xfail(
-        # kwargs["dim"] must exist, kwargs["correction"] must not exist
-        matcher=lambda sample: not (
-            sample.kwargs.get("dim", None) is not None
-            and sample.kwargs.get("correction", None) is None
-        ),
-        reason="this Aten overload only support with 'dim' argument and without 'correction' argument",
-    ),
-    TorchLibOpInfo(
-        "var_correction",
-        core_ops.aten_var_correction,
-    ).skip(
-        # Don't accept input[1]=bool and 'correction' must be in kwargs
-        matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs,
-        reason="this Aten overload only support when correction attribute exists",
-    ),
     TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like),
     TorchLibOpInfo("torchvision.ops.nms", vision_ops.torchvision_nms),
 )
@@ -2364,10 +2248,6 @@ def _where_input_wrangler(
 ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",))
 ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",))
 ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
-ops_test_common.duplicate_opinfo(OPS_DB, "std_mean", ("std_mean_dim", "std_mean_correction"))
-ops_test_common.duplicate_opinfo(OPS_DB, "std", ("std_dim", "std_correction"))
-ops_test_common.duplicate_opinfo(OPS_DB, "var_mean", ("var_mean_dim", "var_mean_correction"))
-ops_test_common.duplicate_opinfo(OPS_DB, "var", ("var_dim", "var_correction"))
 ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",))
 ops_test_common.duplicate_opinfo(OPS_DB, "view_as_real", ("view_as_real_copy",))
@@ -2510,7 +2390,6 @@ def _where_input_wrangler(
     "transpose",
     "trunc",
     "uniform",
-    "var",
     "where",
 )
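
Reviewer note (not part of the patch): every std/var/var_mean registration deleted above computes the same statistic, which PyTorch now decomposes into primitive ops before the ONNX exporter sees aten::std, aten::var, or aten::var_mean. Below is a minimal eager-mode sketch of that decomposition; the name var_then_std is illustrative only (not an API from this repo), and correction=1 mirrors torch.std's default unbiased estimator.

    import torch

    def var_then_std(x: torch.Tensor, correction: float = 1.0) -> torch.Tensor:
        # Same arithmetic the deleted _aten_var_onnx helper expressed with
        # ONNX ops (ReduceMean / Sub / Mul / Sqrt):
        #   var = sum((x - mean)^2) / (n - correction); std = sqrt(var)
        mean = x.mean()
        var = x.sub(mean).square().sum() / (x.numel() - correction)
        return var.sqrt()

    x = torch.randn(128)
    # torch.std defaults to the unbiased estimator (correction=1)
    torch.testing.assert_close(var_then_std(x), x.std())

That this check passes in plain PyTorch is exactly why the torch_lib implementations, their test matchers, and the duplicate_opinfo entries removed above are redundant once the upstream decomposition runs.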