Skip to content

Commit

Permalink
[torchlib] Unregister stft, var, var_mean, std, std_mean (#1867)
Browse files Browse the repository at this point in the history
Following pytorch/pytorch#136153, we remove
stft, var, var_mean, std, std_mean ops. They were never used even before
because the ops were always decomposed.
  • Loading branch information
justinchuby authored Sep 16, 2024
1 parent 377869a commit 1eef633
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 274 deletions.
170 changes: 17 additions & 153 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3974,8 +3974,6 @@ def aten_hspmm(mat1: TensorType, mat2: TensorType) -> TensorType:


# Do not register hstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918


def aten_hstack(tensors: Sequence[TTensor]) -> TTensor:
"""hstack(Tensor[] tensors) -> Tensor"""

Expand Down Expand Up @@ -7887,14 +7885,14 @@ def aten_stack(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrStr
return op.ConcatFromSequence(tensors, axis=dim, new_axis=1)


@torch_op("aten::std", trace_only=True)
# std is decomposed by PyTroch
def aten_std(self: TReal, unbiased: bool = True) -> TReal:
"""std(Tensor self, bool unbiased=True) -> Tensor"""
var = _aten_var_onnx(self, correction=float(unbiased), keepdim=False)
return op.Sqrt(var)


@torch_op("aten::std.dim", trace_only=True)
# std_dim is decomposed by PyTroch
def aten_std_dim(
self: TReal,
dim: Sequence[int],
Expand All @@ -7907,7 +7905,7 @@ def aten_std_dim(
return op.Sqrt(var)


@torch_op("aten::var.correction", trace_only=True)
# std is decomposed by PyTroch
def aten_std_correction(
self: TReal,
# FIXME(justinchuby): Make dim Optional[Sequence[int]]
Expand All @@ -7927,7 +7925,7 @@ def aten_std_correction(
return op.Sqrt(var)


@torch_op("aten::std_mean", trace_only=True)
# std_mean is decomposed by PyTroch
def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
"""std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""

Expand All @@ -7937,7 +7935,7 @@ def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
return op.Sqrt(var), mean


@torch_op("aten::std_mean.dim", trace_only=True)
# std_mean is decomposed by PyTroch
def aten_std_mean_dim(
self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False
) -> Tuple[TReal, TReal]:
Expand All @@ -7951,7 +7949,7 @@ def aten_std_mean_dim(
return op.Sqrt(var), mean


@torch_op("aten::std_mean.correction", trace_only=True)
# std_mean is decomposed by PyTroch
def aten_std_mean_correction(
self: TReal,
# FIXME(justinchuby): Make dim Optional[Sequence[int]]
Expand All @@ -7973,139 +7971,6 @@ def aten_std_mean_correction(
return op.Sqrt(var), mean


@torch_op("aten::stft", private=True)
def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]:
signal_rank = Rank(self)
if signal_rank == 1:
# Add a batch dimension
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
return op.Identity(self), signal_rank


@torch_op("aten::stft", private=True)
def _center_window_around_zeros_if_needed(
window: TFloatOrBFloat16, n_fft: int
) -> TFloatOrBFloat16:
# first dimension
n_win = op.Shape(window, start=0, end=1)
# Center window around zeros if needed (required by ONNX's STFT)
if n_win < n_fft:
left = (n_fft - n_win) / 2

right = n_fft - left - n_win
left = op.Reshape(left, op.Constant(value_ints=[1]))
right = op.Reshape(right, op.Constant(value_ints=[1]))

left_win = op.Expand(op.Constant(value_ints=[0]), left)
right_win = op.Expand(op.Constant(value_ints=[0]), right)
right_win = op.CastLike(right_win, window)
left_win = op.CastLike(left_win, window)
window = op.Concat(left_win, window, right_win, axis=0)
return window


@torch_op("aten::stft", private=True)
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16:
left = (n_fft - win_length) / 2

right = n_fft - left - win_length
left = op.Reshape(left, op.Constant(value_ints=[1]))
right = op.Reshape(right, op.Constant(value_ints=[1]))
win_length = op.Reshape(win_length, op.Constant(value_ints=[1]))

left_win = op.Expand(op.Constant(value_ints=[0]), left)
right_win = op.Expand(op.Constant(value_ints=[0]), right)
window_list = op.Expand(op.Constant(value_ints=[1]), win_length)
return op.Concat(left_win, window_list, right_win, axis=0)


@torch_op("aten::stft", private=True)
def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16:
n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor)
return window


@torch_op("aten::stft", private=True)
def _normalize_fft_result(
signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int
) -> TFloatOrBFloat16:
n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
result = result / sqrt_nfft
return result


@torch_op("aten::stft", private=True)
def _aten_stft_onnx(
signal: TFloatOrBFloat16,
frame_step_const: INT64,
window: Union[TFloatOrBFloat16, INT64],
frame_length_const: INT64,
signal_rank: INT64,
onesided: int,
) -> TFloatOrBFloat16:
window = op.CastLike(window, signal)
result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided)
result = op.Transpose(result, perm=[0, 2, 1, 3])
# Remove batch dimension, if needed
if signal_rank == 1:
result = op.Squeeze(result, op.Constant(value_ints=[0]))
return result


@torch_op("aten::stft", trace_only=True)
def aten_stft(
self: TFloatOrBFloat16,
n_fft: int,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
window: Optional[TFloatOrBFloat16] = None,
normalized: bool = False,
onesided: Optional[bool] = None,
return_complex: Optional[bool] = None,
) -> TFloatOrBFloat16:
"""stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""

# NOTE: regarless of the value of return_complex, we always return a real representation.
del return_complex

# Get STFT sizes
if hop_length is None:
# core dump
# hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
hop_length = n_fft // 4
frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))

# Pre-process input if needed
self, signal_rank = _add_batch_dimension(self)

# Get window and make sure it's the same size as `win_length` or `n_fft`
if window is not None and window.shape[0] is not None:
window = _center_window_around_zeros_if_needed(window, n_fft)
elif window is None:
if win_length is not None:
window = _create_window_from_win_length(win_length, n_fft)
else:
window = _create_window_from_n_fft(n_fft)

if onesided is None or onesided:
onesided = 1
else:
onesided = 0
# remove batch dimension included
result = _aten_stft_onnx(
self, frame_step_const, window, frame_length_const, signal_rank, onesided
)

# Normalize, if needed
if normalized:
result = _normalize_fft_result(self, result, n_fft)

return result


@torch_op(
(
"aten::sub.Tensor",
Expand Down Expand Up @@ -8738,7 +8603,7 @@ def aten_vander(
raise NotImplementedError()


@torch_op("aten::var", trace_only=True)
# var is decomposed by PyTroch
def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal:
"""var(Tensor self, bool unbiased=True) -> Tensor"""

Expand All @@ -8747,7 +8612,7 @@ def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal:
return _aten_var_onnx(self, correction=float(unbiased), keepdim=False)


@torch_op("aten::var.dim", trace_only=True)
# var is decomposed by PyTroch
def aten_var_dim(
self: TReal,
dim: Sequence[int],
Expand All @@ -8759,7 +8624,7 @@ def aten_var_dim(
return _aten_var_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim)


@torch_op("aten::var.correction", trace_only=True)
# var is decomposed by PyTroch
def aten_var_correction(
self: TReal,
# FIXME(justinchuby): Make dim Optional[Sequence[int]]
Expand All @@ -8779,7 +8644,7 @@ def aten_var_correction(
return var


@torch_op("aten::var", private=True, traceable=True)
# var is decomposed by PyTroch
def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TReal:
mean = op.ReduceMean(self, keepdims=keepdim)
sub_mean = op.Sub(self, mean)
Expand All @@ -8796,7 +8661,7 @@ def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TRe
return var


@torch_op("aten::var.dim", private=True, traceable=True)
# var is decomposed by PyTroch
def _aten_var_dim_onnx(
self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False
) -> TReal:
Expand All @@ -8817,7 +8682,7 @@ def _aten_var_dim_onnx(
return var


@torch_op("aten::var_mean", trace_only=True)
# var_mean is decomposed by PyTroch
def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
"""var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""

Expand All @@ -8826,7 +8691,7 @@ def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
return _aten_var_mean_onnx(self, correction=float(unbiased), keepdim=False)


@torch_op("aten::var_mean.dim", trace_only=True)
# var_mean is decomposed by PyTroch
def aten_var_mean_dim(
self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False
) -> Tuple[TReal, TReal]:
Expand All @@ -8837,7 +8702,7 @@ def aten_var_mean_dim(
return _aten_var_mean_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim)


@torch_op("aten::var_mean.correction", trace_only=True)
# var_mean is decomposed by PyTroch
def aten_var_mean_correction(
self: TReal,
# FIXME(justinchuby): Make dim Optional[Sequence[int]]
Expand All @@ -8859,7 +8724,7 @@ def aten_var_mean_correction(
return var, mean


@torch_op("aten::var_mean", private=True)
# var_mean is decomposed by PyTroch
def _aten_var_mean_onnx(
self: TReal, correction: float = 1.0, keepdim: bool = False
) -> Tuple[TReal, TReal]:
Expand All @@ -8879,7 +8744,7 @@ def _aten_var_mean_onnx(
return var, mean


@torch_op("aten::var_mean.dim", private=True)
# var_mean is decomposed by PyTroch
def _aten_var_mean_dim_onnx(
self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False
) -> Tuple[TReal, TReal]:
Expand Down Expand Up @@ -8977,8 +8842,6 @@ def aten_view_copy(self: TTensor, size: IntType) -> TTensor:


# Do not register vstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918


def aten_vstack(tensors: Sequence[TTensor]) -> TTensor:
"""vstack(Tensor[] tensors) -> Tensor"""

Expand All @@ -8998,6 +8861,7 @@ def reshape_to_2d(tensor):

@torch_op(
(
"aten::where",
"aten::where.Scalar",
"aten::where.ScalarSelf",
"aten::where.ScalarOther",
Expand Down
Loading

0 comments on commit 1eef633

Please sign in to comment.