Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchlib] Add the identity nodes back #1703

Merged
merged 8 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 25 additions & 25 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@
def aten_alias(self: TTensor) -> TTensor:
"""alias(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)

Check warning on line 311 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L311

Added line #L311 was not covered by tests


def aten_alias_copy(self: TensorType) -> TensorType:
Expand Down Expand Up @@ -374,7 +374,7 @@
self = aten_all_dim(self, d, keepdim=True)
if not keepdim:
self = op.Squeeze(self, list(dim))
return self
return op.Identity(self)


@torch_op("aten::all.dims", traceable=True)
Expand Down Expand Up @@ -499,7 +499,7 @@
self = aten_any_dim(self, d, keepdim=True)
if not keepdim:
self = op.Squeeze(self, list(dim))
return self
return op.Identity(self)


@torch_op("aten::any.dims", traceable=True)
Expand Down Expand Up @@ -940,7 +940,7 @@

if IsScalar(self):
self = op.Reshape(self, op.Constant(value_ints=[1]))
return self
return op.Identity(self)


@torch_op("aten::atleast_1d.Sequence")
Expand All @@ -964,7 +964,7 @@

if Rank(self) <= 1:
self = op.Reshape(self, op.Constant(value_ints=[1, -1]))
return self
return op.Identity(self)


@torch_op("aten::atleast_2d.Sequence")
Expand All @@ -991,7 +991,7 @@
self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1]))
elif rank == 2:
self = op.Unsqueeze(self, op.Constant(value_ints=[-1]))
return self
return op.Identity(self)


@torch_op("aten::atleast_3d.Sequence")
Expand Down Expand Up @@ -1691,7 +1691,7 @@
) -> TTensor:
"""clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor"""

return self
return op.Identity(self)


def aten_coalesce(self: TensorType) -> TensorType:
Expand Down Expand Up @@ -1749,7 +1749,7 @@
def aten_conj(self: TTensor) -> TTensor:
"""conj(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)


@torch_op("aten::conj", complex=True, private=True)
Expand Down Expand Up @@ -1825,7 +1825,7 @@
"""contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)"""

# ONNX does not have the notion of memory_format. It is always treated as a no-op.
return self
return op.Identity(self)


@torch_op("aten::conv1d", trace_only=True)
Expand Down Expand Up @@ -2168,7 +2168,7 @@
"""_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor"""

if dtype == -1:
return self
return op.Identity(self)

Check warning on line 2171 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L2171

Added line #L2171 was not covered by tests
else:
return common_ops.cast_to(self, dtype=dtype)

Expand Down Expand Up @@ -2493,7 +2493,7 @@
def aten_detach(self: TensorType) -> TensorType:
"""detach(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)


def aten_detach_copy(self: TensorType) -> TensorType:
Expand Down Expand Up @@ -4061,7 +4061,7 @@
if _has_none_in_middle(indices):
# If there is None in the middle, Advanced Indexing cannot decide where to put
# the new dimensions. So it places them in the front, like GatherND does.
return self
return op.Identity(self)

# When the indices are consecutive, Advanced Indexing will place the new dimensions
# (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
Expand Down Expand Up @@ -4227,7 +4227,7 @@
index = op.SequenceAt(indices, 0) # assume indices only have 1 element
# FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it
index_int = op.Cast(index, to=INT32.dtype)
# if all False, return self
# if all False, return op.Identity(self)
if op.ReduceSum(index_int) == 0:
result = self
else:
Expand Down Expand Up @@ -4700,7 +4700,7 @@
def aten_lift_fresh_copy(self: TensorType) -> TensorType:
"""lift_fresh_copy(Tensor self) -> Tensor"""

return self
return op.Identity(self)


def aten_linear_backward(
Expand Down Expand Up @@ -7082,14 +7082,14 @@
def aten_resolve_conj(self: TTensor) -> TTensor:
"""resolve_conj(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)


@torch_op("aten::resolve_neg", trace_only=True)
def aten_resolve_neg(self: TTensor) -> TTensor:
"""resolve_neg(Tensor(a) self) -> Tensor(a)"""

return self
return op.Identity(self)


def aten_result_type(tensor: TensorType, other: TensorType) -> int:
Expand Down Expand Up @@ -7142,9 +7142,9 @@

self_rank = len(self.shape)
if self_rank == 0:
return self
return op.Identity(self)
elif self.shape[0] == 0: # empty tensor
return self
return op.Identity(self)
else:
# NOTE: In pytorch, default value of dims is an empty list.
if len(dims) == 0: # Empty sequence
Expand All @@ -7166,10 +7166,10 @@

self_rank = len(self.shape)
if self_rank == 1:
return self
return op.Identity(self)

if self.shape[0] == 0: # empty tensor
return self
return op.Identity(self)

self_real = op.Slice(self, [0], [1], axes=[-1])
self_imag = op.Slice(self, [1], [2], axes=[-1])
Expand Down Expand Up @@ -7819,7 +7819,7 @@
if signal_rank == 1:
# Add a batch dimension
self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
return self, signal_rank
return op.Identity(self), signal_rank


@torch_op("aten::stft", private=True)
Expand Down Expand Up @@ -8768,7 +8768,7 @@

# We always operate on the real representation of a complex number in torchlib
# So this is a no-op
return self
return op.Identity(self)


@torch_op("aten::view_as_complex_copy", trace_only=True)
Expand All @@ -8777,7 +8777,7 @@

# We always operate on the real representation of a complex number in torchlib
# So this is a no-op
return self
return op.Identity(self)


@torch_op("aten::view_as_real", complex=True, trace_only=True)
Expand All @@ -8786,7 +8786,7 @@

# We always operate on the real representation of a complex number in torchlib
# So this is a no-op
return self
return op.Identity(self)


@torch_op("aten::view_as_real_copy", complex=True, trace_only=True)
Expand All @@ -8795,7 +8795,7 @@

# We always operate on the real representation of a complex number in torchlib
# So this is a no-op
return self
return op.Identity(self)


@torch_op("aten::view_copy")
Expand Down
3 changes: 2 additions & 1 deletion onnxscript/rewriter/llama_rule_sets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_llama_p0_rule_set_cast_cast(self):
rewritten_model = ir.serde.serialize_model(ir_model)

self.assertEqual(["Cast"], [n.op_type for n in rewritten_model.graph.node])
self._check_model(model_proto, rewritten_model, atol=1e-3)
self._check_model(model_proto, rewritten_model, atol=1e-2)

@classmethod
def _cast_identity_models(cls):
Expand Down Expand Up @@ -376,6 +376,7 @@ def _slides_split_models(cls):
]
return models

@unittest.skipIf(True, reason="see https://github.com/microsoft/onnxscript/issues/1642")
def test_llama_p0_rule_set_slice_split(self):
for model_proto in self._slides_split_models():
ir_model = ir.serde.deserialize_model(model_proto)
Expand Down
Loading