From a99e443edf3ff5c73e2df3330ca10f7cc1a6612b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 10 Sep 2024 08:33:20 -0700 Subject: [PATCH] [torchlib] Fix aten_empty_like (#1863) Fix https://github.com/pytorch/pytorch/issues/135532 --- .../function_libs/torch_lib/ops/core.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 2ca22c7e4..30e9b7d33 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3225,22 +3225,21 @@ def aten_empty( @torch_op("aten::empty_like", trace_only=True) -def aten_empty_like(self: TTensor, dtype: int = -1) -> TTensor: +def aten_empty_like( + self: TTensor, + dtype: int = -1, + layout: str = "", + device: str = "", + pin_memory: bool = False, + memory_format: str = "", +) -> TTensor: """empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" - # NOTE: trace_only because both if branches need to be the same type, but we have - # a cast in the if branch. - - if dtype == -1: + if dtype == -1 or dtype is None: zero = op.CastLike(0, self) else: zero = op.Cast(0, to=dtype) - return _aten_empty_like_onnx(self, zero) - - -@torch_op("aten::empty_like", private=True) -def _aten_empty_like_onnx(self: TTensor, zero) -> TTensor: shape = op.Shape(self) return op.Expand(zero, shape)