Skip to content

Commit

Permalink
[torchlib] Implement upsample_nearest{nd}.vec (#1874)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored Sep 24, 2024
1 parent 65bc496 commit 99ae64e
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 75 deletions.
106 changes: 53 additions & 53 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2355,17 +2355,6 @@ def _get_upsample_align_corners_mode(align_corners: bool) -> str:
return "align_corners" if align_corners else "pytorch_half_pixel"


@torch_op(
(
"aten::upsample_bicubic2d",
"aten::upsample_bilinear2d",
"aten::upsample_nearest1d",
"aten::upsample_nearest2d",
"aten::upsample_nearest3d",
"aten::upsample_trilinear3d",
),
private=True,
)
def _aten_upsample_output_size(
self: TReal,
output_size: INT64,
Expand All @@ -2388,22 +2377,22 @@ def _aten_upsample_output_size(
)


@torch_op(("aten::upsample_bicubic2d", "aten::upsample_bilinear2d"), private=True)
def _aten_upsample_scales(
self: TReal,
scale_factors: TFloat,
scale_factors: Sequence[float],
mode: str,
coordinate_transformation_mode: str,
) -> TReal:
scale_factors = op.Cast(scale_factors, to=FLOAT.dtype)
scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0)
return op.Resize(
self,
None,
scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w]
op.Constant(
value_floats=[1.0, 1.0, *scale_factors]
), # format should be: [1.0, 1.0, scale_h, scale_w]
None,
mode=mode,
coordinate_transformation_mode=coordinate_transformation_mode,
nearest_mode="floor",
)


Expand Down Expand Up @@ -2441,7 +2430,7 @@ def aten_upsample_bicubic2d_vec(
if scale_factors is not None:
result = _aten_upsample_scales(
self,
op.Constant(value_floats=scale_factors),
scale_factors,
mode="cubic",
coordinate_transformation_mode=coordinate_transformation_mode,
)
Expand Down Expand Up @@ -2503,11 +2492,12 @@ def aten_upsample_bilinear2d_vec(
if scale_factors is not None:
result = _aten_upsample_scales(
self,
op.Constant(value_floats=scale_factors),
scale_factors,
mode="linear",
coordinate_transformation_mode=coordinate_transformation_mode,
)
else:
assert output_size is not None
result = _aten_upsample_output_size(
self,
output_size,
Expand Down Expand Up @@ -2536,9 +2526,8 @@ def aten_upsample_linear1d(
self: TReal, output_size: INT64, align_corners: bool, scales: Optional[float] = None
) -> TReal:
"""upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor"""
# FIXME(justinchuby): Support when scales is provided and align_corners is False
del scales
coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners)
# scales is ignored in PyTorch
return _aten_upsample_output_size(
self,
output_size,
Expand All @@ -2561,31 +2550,35 @@ def aten_upsample_linear1d_backward(

@torch_op("aten::upsample_nearest1d", trace_only=True)
def aten_upsample_nearest1d(
self: TReal, size: INT64, scale_factor: Optional[float] = None
self: TReal, output_size: INT64, scales: Optional[float] = None
) -> TReal:
"""upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor"""
if size is not None:
return _aten_upsample_output_size(self, size, "nearest", "asymmetric")
if scales is not None:
return _aten_upsample_scales(self, [scales], "nearest", "asymmetric")
else:
return _aten_upsample_nearest1d_scales(self, scale_factor)
return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric")


@torch_op("aten::upsample_nearest1d", private=True)
def _aten_upsample_nearest1d_scales(
self: TReal,
scale_factors: TFloat,
@torch_op(
(
"aten::upsample_nearest1d.vec",
"aten::upsample_nearest2d.vec",
"aten::upsample_nearest3d.vec",
),
trace_only=True,
)
def aten_upsample_nearestnd_vec(
input: TReal,
output_size: Optional[INT64] = None,
scale_factors: Optional[Sequence[float]] = None,
) -> TReal:
scale_factors = op.Cast(scale_factors, to=FLOAT.dtype)
scale_factors = op.Concat(op.Constant(value_floats=[1.0, 1.0]), scale_factors, axis=0)
return op.Resize(
self,
None,
scale_factors, # format should be: [1.0, 1.0, scale_h, scale_w]
None,
mode="nearest",
coordinate_transformation_mode="asymmetric",
nearest_mode="floor",
)
"""upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor"""

if scale_factors is not None:
return _aten_upsample_scales(input, scale_factors, "nearest", "asymmetric")
else:
assert output_size is not None
return _aten_upsample_output_size(input, output_size, "nearest", "asymmetric")


def aten_upsample_nearest1d_backward(
Expand All @@ -2602,18 +2595,21 @@ def aten_upsample_nearest1d_backward(
@torch_op("aten::upsample_nearest2d", trace_only=True)
def aten_upsample_nearest2d(
self: TReal,
size: INT64,
output_size: INT64,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> TReal:
"""upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor"""

# NOTE: trace_only because optional attributes are not supported by ONNX
# TODO(justinchuby): Conditionally use scales
del scales_h
del scales_w

return _aten_upsample_output_size(self, size, "nearest", "asymmetric")
if scales_h is not None and scales_w is not None:
return _aten_upsample_scales(
self,
[scales_h, scales_w],
"nearest",
"asymmetric",
)
else:
return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric")


def aten_upsample_nearest2d_backward(
Expand All @@ -2631,18 +2627,22 @@ def aten_upsample_nearest2d_backward(
@torch_op("aten::upsample_nearest3d", trace_only=True)
def aten_upsample_nearest3d(
self: TReal,
size: INT64,
output_size: INT64,
scales_d: Optional[float] = None,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> TReal:
"""upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor"""

del scales_h
del scales_w
del scales_d

return _aten_upsample_output_size(self, size, "nearest", "asymmetric")
if scales_d is not None and scales_h is not None and scales_w is not None:
return _aten_upsample_scales(
self,
[scales_d, scales_h, scales_w],
"nearest",
"asymmetric",
)
else:
return _aten_upsample_output_size(self, output_size, "nearest", "asymmetric")


def aten_upsample_nearest3d_backward(
Expand Down Expand Up @@ -2695,7 +2695,7 @@ def aten_upsample_trilinear3d_vec(
if scale_factors is not None:
result = _aten_upsample_scales(
self,
op.Constant(value_floats=scale_factors),
scale_factors,
mode="linear",
coordinate_transformation_mode=coordinate_transformation_mode,
)
Expand Down
Loading

0 comments on commit 99ae64e

Please sign in to comment.