diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index b45499744..41571bcd3 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -1036,7 +1036,12 @@ def _should_create_value_info_for_value(value: _protocols.ValueProtocol) -> bool True if value info should be created for the value. """ # No need to serialize value info if it is not set - return not (value.shape is None and value.type is None) + if value.shape is None and value.type is None: + return False + if not value.name: + logger.debug("Did not serialize '%s' because its name is empty", value) + return False + return True def _serialize_experimental_value_info_for_function_ir9_into( @@ -1269,6 +1274,23 @@ def serialize_node(node: _protocols.NodeProtocol) -> onnx.NodeProto: return node_proto +def _remove_trailing_outputs( + outputs: Sequence[_protocols.ValueProtocol], +) -> Sequence[_protocols.ValueProtocol]: + """Remove trailing outputs that have empty names. + + Args: + outputs: The outputs to remove trailing outputs from. + + Returns: + The outputs with trailing outputs removed. + """ + for i, output in enumerate(reversed(outputs)): + if output.name: + return outputs[: len(outputs) - i] + return [] + + @_capture_errors(lambda node_proto, from_: repr(from_)) def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtocol) -> None: node_proto.op_type = from_.op_type @@ -1288,8 +1310,11 @@ def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtoc node_proto.input.append("") else: node_proto.input.append(input_.name) - for output in from_.outputs: + + # Do not include the trailing outputs that have empty names + for output in _remove_trailing_outputs(from_.outputs): node_proto.output.append(output.name) + for attr in from_.attributes.values(): if isinstance(attr, _core.Attr): serialize_attribute_into(node_proto.attribute.add(), from_=attr) diff --git a/onnxscript/optimizer/_remove_unused_test.py b/onnxscript/optimizer/_remove_unused_test.py index b87a176f6..425a00a44 100644 --- a/onnxscript/optimizer/_remove_unused_test.py +++ b/onnxscript/optimizer/_remove_unused_test.py @@ -11,6 +11,8 @@ @parameterized.parameterized_class(("using_ir",), [(False,), (True,)]) class RemoveUnusedTest(unittest.TestCase): + using_ir: bool + def remove_unused_nodes(self, model: onnx.ModelProto): if self.using_ir: model_ir = ir.serde.deserialize_model(model) @@ -81,11 +83,7 @@ def test_remove_unused_optional_outputs_maxpool(self): model = self.remove_unused_nodes(model) self.assertEqual(len(model.graph.node), 1) self.assertEqual(model.graph.node[0].op_type, "MaxPool") - if self.using_ir: - expected_outputs = ["z", ""] - else: - expected_outputs = ["z"] - self.assertEqual(model.graph.node[0].output, expected_outputs) + self.assertEqual(model.graph.node[0].output, ["z"]) def test_remove_unused_optional_outputs_dropout_in_function(self): model = onnx.parser.parse_model( @@ -110,11 +108,7 @@ def test_remove_unused_optional_outputs_dropout_in_function(self): self.assertEqual(len(model.functions), 1) self.assertEqual(len(model.functions[0].node), 1) self.assertEqual(model.functions[0].node[0].op_type, "MaxPool") - if self.using_ir: - expected_outputs = ["z", ""] - else: - expected_outputs = ["z"] - self.assertEqual(model.functions[0].node[0].output, expected_outputs) + self.assertEqual(model.functions[0].node[0].output, ["z"]) def test_remove_used_optional_outputs_maxpool(self): model = onnx.parser.parse_model( @@ -150,11 +144,7 @@ def test_remove_multiple_unused_optional_outputs_layernorm(self): model = self.remove_unused_nodes(model) self.assertEqual(len(model.graph.node), 3) self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") - if self.using_ir: - expected_outputs = ["z", "", ""] - else: - expected_outputs = ["z"] - self.assertEqual(list(model.graph.node[2].output), expected_outputs) + self.assertEqual(list(model.graph.node[2].output), ["z"]) def test_remove_trailing_unused_optional_outputs_layernorm(self): model = onnx.parser.parse_model( @@ -173,11 +163,7 @@ def test_remove_trailing_unused_optional_outputs_layernorm(self): model = self.remove_unused_nodes(model) self.assertEqual(len(model.graph.node), 3) self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") - if self.using_ir: - expected_outputs = ["z", "mean", ""] - else: - expected_outputs = ["z", "mean"] - self.assertEqual(list(model.graph.node[2].output), expected_outputs) + self.assertEqual(list(model.graph.node[2].output), ["z", "mean"]) def test_avoid_remove_non_trailing_unused_optional_outputs_layernorm(self): model = onnx.parser.parse_model( @@ -212,11 +198,7 @@ def test_remove_trailing_unused_optional_outputs_batchnorm(self): self.assertEqual(len(model.graph.node), 1) self.assertEqual(model.graph.node[0].op_type, "BatchNormalization") # Check that both the mean/var outputs are removed, and training_mode attribute is removed. - if self.using_ir: - expected_outputs = ["z", "", ""] - else: - expected_outputs = ["z"] - self.assertEqual(list(model.graph.node[0].output), expected_outputs) + self.assertEqual(list(model.graph.node[0].output), ["z"]) self.assertEqual(len(model.graph.node[0].attribute), 0) def test_avoid_remove_used_optional_outputs_batchnorm(self):