From 99ae64e9befd77706e62f547b4eeb1d395496668 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 24 Sep 2024 09:40:09 -0700 Subject: [PATCH] [torchlib] Implement upsample_nearest{nd}.vec (#1874) --- onnxscript/function_libs/torch_lib/ops/nn.py | 106 +++++------ tests/function_libs/torch_lib/extra_opinfo.py | 166 +++++++++++++++--- .../function_libs/torch_lib/ops_test_data.py | 13 +- 3 files changed, 210 insertions(+), 75 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index d5abcac71..c9c030f0c 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2355,17 +2355,6 @@ def _get_upsample_align_corners_mode(align_corners: bool) -> str: return "align_corners" if align_corners else "pytorch_half_pixel" -@torch_op( - ( - "aten::upsample_bicubic2d", - "aten::upsample_bilinear2d", - "aten::upsample_nearest1d", - "aten::upsample_nearest2d", - "aten::upsample_nearest3d", - "aten::upsample_trilinear3d", - ), - private=True, -) def _aten_upsample_output_size( self: TReal, output_size: INT64, @@ -2388,22 +2377,22 @@ def _aten_upsample_output_size( ) -@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bilinear2d"), private=True) def _aten_upsample_scales( self: TReal, - scale_factors: TFloat, + scale_factors: Sequence[float], mode: str, coordinate_transformation_mode: str, ) -> TReal: - scale_factors = op.Cast(scale_factors, to=FLOAT.dtype) - scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0) return op.Resize( self, None, - scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] + op.Constant( + value_floats=[1.0, 1.0, *scale_factors] + ), # format should be: [1.0, 1.0, scale_h, scale_w] None, mode=mode, coordinate_transformation_mode=coordinate_transformation_mode, + nearest_mode="floor", ) @@ -2441,7 +2430,7 @@ def aten_upsample_bicubic2d_vec( if scale_factors is not None: result = _aten_upsample_scales( self, - op.Constant(value_floats=scale_factors), + scale_factors, mode="cubic", coordinate_transformation_mode=coordinate_transformation_mode, ) @@ -2503,11 +2492,12 @@ def aten_upsample_bilinear2d_vec( if scale_factors is not None: result = _aten_upsample_scales( self, - op.Constant(value_floats=scale_factors), + scale_factors, mode="linear", coordinate_transformation_mode=coordinate_transformation_mode, ) else: + assert output_size is not None result = _aten_upsample_output_size( self, output_size, @@ -2536,9 +2526,8 @@ def aten_upsample_linear1d( self: TReal, output_size: INT64, align_corners: bool, scales: Optional[float] = None ) -> TReal: """upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor""" - # FIXME(justinchuby): Support when scales is provided and align_corners is False - del scales coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + # scales is ignored in PyTorch return _aten_upsample_output_size( self, output_size, @@ -2561,31 +2550,35 @@ def aten_upsample_linear1d_backward( @torch_op("aten::upsample_nearest1d", trace_only=True) def aten_upsample_nearest1d( - self: TReal, size: INT64, scale_factor: Optional[float] = None + self: TReal, output_size: INT64, scales: Optional[float] = None ) -> TReal: """upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor""" - if size is not None: - return _aten_upsample_output_size(self, size, "nearest", "asymmetric") + if scales is not None: + return _aten_upsample_scales(self, [scales], "nearest", "asymmetric") else: - return _aten_upsample_nearest1d_scales(self, scale_factor) + return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric") -@torch_op("aten::upsample_nearest1d", private=True) -def _aten_upsample_nearest1d_scales( - self: TReal, - scale_factors: TFloat, +@torch_op( + ( + "aten::upsample_nearest1d.vec", + "aten::upsample_nearest2d.vec", + "aten::upsample_nearest3d.vec", + ), + trace_only=True, +) +def aten_upsample_nearestnd_vec( + input: TReal, + output_size: Optional[INT64] = None, + scale_factors: Optional[Sequence[float]] = None, ) -> TReal: - scale_factors = op.Cast(scale_factors, to=FLOAT.dtype) - scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0) - return op.Resize( - self, - None, - scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w] - None, - mode="nearest", - coordinate_transformation_mode="asymmetric", - nearest_mode="floor", - ) + """upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor""" + + if scale_factors is not None: + return _aten_upsample_scales(input, scale_factors, "nearest", "asymmetric") + else: + assert output_size is not None + return _aten_upsample_output_size(input, output_size, "nearest", "asymmetric") def aten_upsample_nearest1d_backward( @@ -2602,18 +2595,21 @@ def aten_upsample_nearest1d_backward( @torch_op("aten::upsample_nearest2d", trace_only=True) def aten_upsample_nearest2d( self: TReal, - size: INT64, + output_size: INT64, scales_h: Optional[float] = None, scales_w: Optional[float] = None, ) -> TReal: """upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor""" - # NOTE: trace_only because optional attributes are not supported by ONNX - # TODO(justinchuby): Conditionally use scales - del scales_h - del scales_w - - return _aten_upsample_output_size(self, size, "nearest", "asymmetric") + if scales_h is not None and scales_w is not None: + return _aten_upsample_scales( + self, + [scales_h, scales_w], + "nearest", + "asymmetric", + ) + else: + return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric") def aten_upsample_nearest2d_backward( @@ -2631,18 +2627,22 @@ def aten_upsample_nearest2d_backward( @torch_op("aten::upsample_nearest3d", trace_only=True) def aten_upsample_nearest3d( self: TReal, - size: INT64, + output_size: INT64, scales_d: Optional[float] = None, scales_h: Optional[float] = None, scales_w: Optional[float] = None, ) -> TReal: """upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor""" - del scales_h - del scales_w - del scales_d - - return _aten_upsample_output_size(self, size, "nearest", "asymmetric") + if scales_d is not None and scales_h is not None and scales_w is not None: + return _aten_upsample_scales( + self, + [scales_d, scales_h, scales_w], + "nearest", + "asymmetric", + ) + else: + return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric") def aten_upsample_nearest3d_backward( @@ -2695,7 +2695,7 @@ def aten_upsample_trilinear3d_vec( if scale_factors is not None: result = _aten_upsample_scales( self, - op.Constant(value_floats=scale_factors), + scale_factors, mode="linear", coordinate_transformation_mode=coordinate_transformation_mode, ) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 756a74027..91f1df916 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -1539,7 +1539,7 @@ def shape(size, rank, with_batch_channel=True): None, # output_size align_corners, ), - kwargs=dict(scale_factors=(1.7, 1.7)), + kwargs=dict(scale_factors=[1.7, 1.7]), ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), @@ -1547,7 +1547,7 @@ def shape(size, rank, with_batch_channel=True): None, # if this is None, the scalar must be list align_corners, ), - kwargs=dict(scale_factors=(0.6, 0.6)), + kwargs=dict(scale_factors=[0.6, 0.6]), ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), @@ -1555,7 +1555,7 @@ def shape(size, rank, with_batch_channel=True): None, # if this is None, the scalar must be list align_corners, ), - kwargs=dict(scale_factors=(0.6, 4.2)), + kwargs=dict(scale_factors=[0.6, 4.2]), ) @@ -1605,7 +1605,6 @@ def sample_inputs_upsample_nearest1d(op_info, device, dtype, requires_grad, **kw N, C = 2, 3 D = 4 - SS = 3 L = 5 rank = 1 @@ -1624,8 +1623,6 @@ def shape(size, rank, with_batch_channel=True): high=1, ) - yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) - yield opinfo_core.SampleInput( make_arg(shape(D, rank)), shape(S, rank, False), @@ -1634,15 +1631,53 @@ def shape(size, rank, with_batch_channel=True): make_arg(shape(D, rank)), shape(L, rank, False), ) + # yield opinfo_core.SampleInput( + # make_arg(shape(D, rank)), + # shape(S, rank, False), # output_size + # [1.7], # scaler + # ) + # yield opinfo_core.SampleInput( + # make_arg(shape(D, rank)), + # shape(S, rank, False), # if this is None, the scalar must be list + # [0.6], + # ) + + +def sample_inputs_upsample_nearest1d_vec(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + L = 5 + + rank = 1 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), None) + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(L, rank, False), None) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), None, # output_size - (1.7,), # scaler + scale_factors=(1.7,), ) yield opinfo_core.SampleInput( make_arg(shape(D, rank)), - None, # if this is None, the scalar must be list - (0.6,), + None, + scale_factors=(0.6,), ) @@ -1652,7 +1687,6 @@ def sample_inputs_upsample_nearest2d(op_info, device, dtype, requires_grad, **kw N, C = 2, 3 D = 4 - SS = 3 L = 5 rank = 2 @@ -1671,8 +1705,6 @@ def shape(size, rank, with_batch_channel=True): high=1, ) - yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) - yield opinfo_core.SampleInput( make_arg(shape(D, rank)), shape(S, rank, False), @@ -1681,26 +1713,62 @@ def shape(size, rank, with_batch_channel=True): make_arg(shape(D, rank)), shape(L, rank, False), ) - # ONNX don't support below cases: both output_size and scaler are not None # yield opinfo_core.SampleInput( # make_arg(shape(D, rank)), # shape(L, rank, False), - # 1.7, # scaler + # 1.7, 2.0, # scaler # ) # yield opinfo_core.SampleInput( # make_arg(shape(D, rank)), # shape(L, rank, False), - # 0.6, + # 0.6, 0.4, # ) +def sample_inputs_upsample_nearest2d_vec(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + L = 5 + + rank = 2 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), None) + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(L, rank, False), None) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, + scale_factors=(1.7, 2.0), + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, + scale_factors=(0.6, 0.4), + ) + + def sample_inputs_upsample_nearest3d(op_info, device, dtype, requires_grad, **kwargs): del op_info del kwargs N, C = 2, 3 D = 4 - SS = 3 L = 5 rank = 3 @@ -1719,8 +1787,6 @@ def shape(size, rank, with_batch_channel=True): high=1, ) - yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(SS, rank, False), True) - yield opinfo_core.SampleInput( make_arg(shape(D, rank)), shape(S, rank, False), @@ -1729,19 +1795,56 @@ def shape(size, rank, with_batch_channel=True): make_arg(shape(D, rank)), shape(L, rank, False), ) - # ONNX don't support below cases: both output_size and scaler are not None # yield opinfo_core.SampleInput( # make_arg(shape(D, rank)), # shape(L, rank, False), - # 1.7, # scaler + # 1.7, 1.5, 2.0, # scaler # ) # yield opinfo_core.SampleInput( # make_arg(shape(D, rank)), # shape(L, rank, False), - # 0.6, + # 0.6, 0.3, 0.5, # ) +def sample_inputs_upsample_nearest3d_vec(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + N, C = 2, 3 + D = 4 + L = 5 + + rank = 3 + + def shape(size, rank, with_batch_channel=True): + if with_batch_channel: + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) + + make_arg = functools.partial( + torch_testing.make_tensor, + device=device, + dtype=dtype, + requires_grad=requires_grad, + low=-1, + high=1, + ) + + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(S, rank, False), None) + yield opinfo_core.SampleInput(make_arg(shape(D, rank)), shape(L, rank, False), None) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, + scale_factors=(1.7, 1.5, 2.0), # scaler + ) + yield opinfo_core.SampleInput( + make_arg(shape(D, rank)), + None, + scale_factors=(0.6, 0.3, 0.5), + ) + + def sample_inputs_upsample_trilinear3d(op_info, device, dtype, requires_grad, **kwargs): del op_info del kwargs @@ -2345,6 +2448,13 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_nearest1d, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_nearest1d.vec", + aten_name="upsample_nearest1d.vec", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_nearest1d_vec, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.upsample_nearest2d", aten_name="upsample_nearest2d", @@ -2352,6 +2462,13 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_nearest2d, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_nearest2d.vec", + aten_name="upsample_nearest2d.vec", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_nearest2d_vec, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.upsample_nearest3d", aten_name="upsample_nearest3d", @@ -2359,6 +2476,13 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_nearest3d, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.upsample_nearest3d.vec", + aten_name="upsample_nearest3d.vec", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_upsample_nearest3d_vec, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.upsample_trilinear3d.default", aten_name="upsample_trilinear3d", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 3fcb7802c..3f6be88e8 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -2096,14 +2096,26 @@ def _where_input_wrangler( "ops.aten.upsample_nearest1d", nn_ops.aten_upsample_nearest1d, ), + TorchLibOpInfo( + "ops.aten.upsample_nearest1d.vec", + nn_ops.aten_upsample_nearestnd_vec, + ), TorchLibOpInfo( "ops.aten.upsample_nearest2d", nn_ops.aten_upsample_nearest2d, ), + TorchLibOpInfo( + "ops.aten.upsample_nearest2d.vec", + nn_ops.aten_upsample_nearestnd_vec, + ), TorchLibOpInfo( "ops.aten.upsample_nearest3d", nn_ops.aten_upsample_nearest3d, ), + TorchLibOpInfo( + "ops.aten.upsample_nearest3d.vec", + nn_ops.aten_upsample_nearestnd_vec, + ), TorchLibOpInfo( "ops.aten.upsample_trilinear3d.default", nn_ops.aten_upsample_trilinear3d, @@ -2379,7 +2391,6 @@ def _where_input_wrangler( "signbit", "sin", "sinh", - "slice", "sqrt", "squeeze", "sub",