Skip to content

Commit

Permalink
[torchlib] Fix aten_empty_like (#1863)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored Sep 10, 2024
1 parent e6dabeb commit a99e443
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3225,22 +3225,21 @@ def aten_empty(


@torch_op("aten::empty_like", trace_only=True)
def aten_empty_like(self: TTensor, dtype: int = -1) -> TTensor:
def aten_empty_like(
self: TTensor,
dtype: int = -1,
layout: str = "",
device: str = "",
pin_memory: bool = False,
memory_format: str = "",
) -> TTensor:
"""empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

# NOTE: trace_only because both if branches need to be the same type, but we have
# a cast in the if branch.

if dtype == -1:
if dtype == -1 or dtype is None:
zero = op.CastLike(0, self)
else:
zero = op.Cast(0, to=dtype)

return _aten_empty_like_onnx(self, zero)


@torch_op("aten::empty_like", private=True)
def _aten_empty_like_onnx(self: TTensor, zero) -> TTensor:
shape = op.Shape(self)
return op.Expand(zero, shape)

Expand Down

0 comments on commit a99e443

Please sign in to comment.