[IR] Do not serialize the trailing outputs that have empty names (#1905)
justinchuby authored Oct 15, 2024
1 parent 1544ee1 commit d4b81dc
Showing 2 changed files with 34 additions and 27 deletions.
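In ONNX, an unused trailing optional output may either be recorded as an empty-string placeholder or omitted from the node entirely; both describe the same operation. The proto-based optimizer path already drops such placeholders (as the old test expectations below show), while the IR serializer used to keep them, so the two paths produced NodeProtos with different output lists. A minimal sketch of the two equivalent shapes, using only onnx.helper; the MaxPool node here is a hypothetical illustration, not taken from the diff:

from onnx import helper

# A MaxPool whose optional "Indices" output is unused: it can be kept as a
# trailing "" placeholder or omitted entirely. Both protos denote the same op,
# but they compare unequal field-by-field, which is what the tests tripped on.
with_placeholder = helper.make_node("MaxPool", inputs=["x"], outputs=["z", ""])
without_placeholder = helper.make_node("MaxPool", inputs=["x"], outputs=["z"])

print(list(with_placeholder.output))     # ['z', '']
print(list(without_placeholder.output))  # ['z']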
29 changes: 27 additions & 2 deletions onnxscript/ir/serde.py
@@ -1036,7 +1036,12 @@ def _should_create_value_info_for_value(value: _protocols.ValueProtocol) -> bool
         True if value info should be created for the value.
     """
     # No need to serialize value info if it is not set
-    return not (value.shape is None and value.type is None)
+    if value.shape is None and value.type is None:
+        return False
+    if not value.name:
+        logger.debug("Did not serialize '%s' because its name is empty", value)
+        return False
+    return True
 
 
 def _serialize_experimental_value_info_for_function_ir9_into(
@@ -1269,6 +1274,23 @@ def serialize_node(node: _protocols.NodeProtocol) -> onnx.NodeProto:
     return node_proto
 
 
+def _remove_trailing_outputs(
+    outputs: Sequence[_protocols.ValueProtocol],
+) -> Sequence[_protocols.ValueProtocol]:
+    """Remove trailing outputs that have empty names.
+
+    Args:
+        outputs: The outputs to remove trailing outputs from.
+
+    Returns:
+        The outputs with trailing outputs removed.
+    """
+    for i, output in enumerate(reversed(outputs)):
+        if output.name:
+            return outputs[: len(outputs) - i]
+    return []
+
+
 @_capture_errors(lambda node_proto, from_: repr(from_))
 def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtocol) -> None:
     node_proto.op_type = from_.op_type
@@ -1288,8 +1310,11 @@ def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtoc
             node_proto.input.append("")
         else:
             node_proto.input.append(input_.name)
-    for output in from_.outputs:
+
+    # Do not include the trailing outputs that have empty names
+    for output in _remove_trailing_outputs(from_.outputs):
         node_proto.output.append(output.name)
+
     for attr in from_.attributes.values():
         if isinstance(attr, _core.Attr):
             serialize_attribute_into(node_proto.attribute.add(), from_=attr)
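The helper trims only from the right: it scans the outputs in reverse and cuts at the first named one, so an unused output sitting between two named outputs keeps its empty-name placeholder and its position. A standalone sketch of the same scan, with a hypothetical FakeValue standing in for _protocols.ValueProtocol:

from dataclasses import dataclass
from typing import Sequence


@dataclass
class FakeValue:
    """Hypothetical stand-in for _protocols.ValueProtocol; only .name matters here."""
    name: str


def remove_trailing_outputs(outputs: Sequence[FakeValue]) -> Sequence[FakeValue]:
    # Scan from the right and stop at the first output with a non-empty name.
    for i, output in enumerate(reversed(outputs)):
        if output.name:
            return outputs[: len(outputs) - i]
    return []


outs = [FakeValue("z"), FakeValue(""), FakeValue("mean"), FakeValue("")]
print([v.name for v in remove_trailing_outputs(outs)])             # ['z', '', 'mean']
print([v.name for v in remove_trailing_outputs([FakeValue("")])])  # []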
32 changes: 7 additions & 25 deletions onnxscript/optimizer/_remove_unused_test.py
@@ -11,6 +11,8 @@
 
 @parameterized.parameterized_class(("using_ir",), [(False,), (True,)])
 class RemoveUnusedTest(unittest.TestCase):
+    using_ir: bool
+
     def remove_unused_nodes(self, model: onnx.ModelProto):
         if self.using_ir:
             model_ir = ir.serde.deserialize_model(model)
@@ -81,11 +83,7 @@ def test_remove_unused_optional_outputs_maxpool(self):
         model = self.remove_unused_nodes(model)
         self.assertEqual(len(model.graph.node), 1)
         self.assertEqual(model.graph.node[0].op_type, "MaxPool")
-        if self.using_ir:
-            expected_outputs = ["z", ""]
-        else:
-            expected_outputs = ["z"]
-        self.assertEqual(model.graph.node[0].output, expected_outputs)
+        self.assertEqual(model.graph.node[0].output, ["z"])
 
     def test_remove_unused_optional_outputs_dropout_in_function(self):
         model = onnx.parser.parse_model(
@@ -110,11 +108,7 @@ def test_remove_unused_optional_outputs_dropout_in_function(self):
         self.assertEqual(len(model.functions), 1)
         self.assertEqual(len(model.functions[0].node), 1)
         self.assertEqual(model.functions[0].node[0].op_type, "MaxPool")
-        if self.using_ir:
-            expected_outputs = ["z", ""]
-        else:
-            expected_outputs = ["z"]
-        self.assertEqual(model.functions[0].node[0].output, expected_outputs)
+        self.assertEqual(model.functions[0].node[0].output, ["z"])
 
     def test_remove_used_optional_outputs_maxpool(self):
         model = onnx.parser.parse_model(
@@ -150,11 +144,7 @@ def test_remove_multiple_unused_optional_outputs_layernorm(self):
         model = self.remove_unused_nodes(model)
         self.assertEqual(len(model.graph.node), 3)
         self.assertEqual(model.graph.node[2].op_type, "LayerNormalization")
-        if self.using_ir:
-            expected_outputs = ["z", "", ""]
-        else:
-            expected_outputs = ["z"]
-        self.assertEqual(list(model.graph.node[2].output), expected_outputs)
+        self.assertEqual(list(model.graph.node[2].output), ["z"])
 
     def test_remove_trailing_unused_optional_outputs_layernorm(self):
         model = onnx.parser.parse_model(
@@ -173,11 +163,7 @@ def test_remove_trailing_unused_optional_outputs_layernorm(self):
         model = self.remove_unused_nodes(model)
         self.assertEqual(len(model.graph.node), 3)
         self.assertEqual(model.graph.node[2].op_type, "LayerNormalization")
-        if self.using_ir:
-            expected_outputs = ["z", "mean", ""]
-        else:
-            expected_outputs = ["z", "mean"]
-        self.assertEqual(list(model.graph.node[2].output), expected_outputs)
+        self.assertEqual(list(model.graph.node[2].output), ["z", "mean"])
 
     def test_avoid_remove_non_trailing_unused_optional_outputs_layernorm(self):
         model = onnx.parser.parse_model(
@@ -212,11 +198,7 @@ def test_remove_trailing_unused_optional_outputs_batchnorm(self):
         self.assertEqual(len(model.graph.node), 1)
         self.assertEqual(model.graph.node[0].op_type, "BatchNormalization")
         # Check that both the mean/var outputs are removed, and training_mode attribute is removed.
-        if self.using_ir:
-            expected_outputs = ["z", "", ""]
-        else:
-            expected_outputs = ["z"]
-        self.assertEqual(list(model.graph.node[0].output), expected_outputs)
+        self.assertEqual(list(model.graph.node[0].output), ["z"])
         self.assertEqual(len(model.graph.node[0].attribute), 0)
 
     def test_avoid_remove_used_optional_outputs_batchnorm(self):
