From 6aea85ba571bac892b6afcdfdce8ecc67187a818 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 21 May 2024 19:25:19 -0700 Subject: [PATCH] Create a new node type for inputs that are initialized (#29) image --- src/model_explorer_onnx/main.py | 28 +++++- testdata/initializer_node_tests.textproto | 95 +++++++++++++++++++ ...nstant_node_tests.py => constant_nodes.py} | 0 tools/test_generation/initializers.py | 70 ++++++++++++++ 4 files changed, 188 insertions(+), 5 deletions(-) create mode 100644 testdata/initializer_node_tests.textproto rename tools/test_generation/{constant_node_tests.py => constant_nodes.py} (100%) create mode 100644 tools/test_generation/initializers.py diff --git a/src/model_explorer_onnx/main.py b/src/model_explorer_onnx/main.py index 2a11074..086e3d4 100644 --- a/src/model_explorer_onnx/main.py +++ b/src/model_explorer_onnx/main.py @@ -359,6 +359,13 @@ def add_graph_io( id=get_graph_io_node_name(value), label=type_, ) + set_attr( + node, + "explanation", + f"This is a graph {type_.lower()}. ONNX {type_.lower()}s are values (edges), not nodes. " + f"A node is displayed here for visualization.", + ) + producer = value.producer() if producer is not None: assert producer.name, "Bug: Node name is required" @@ -437,11 +444,19 @@ def add_initializers( "Initializer does not have a name. Skipping: %s", initializer ) continue - initializer_node_name = get_initializer_node_name(initializer) - if initializer_node_name in all_nodes: - # The initializer is also a graph input. Fill in the missing metadata. - node = all_nodes[initializer_node_name] - metadata = node.outputsMetadata[0] + input_node_name = get_graph_io_node_name(initializer) + if input_node_name in all_nodes: + # The initializer is also a graph input. + # Convert it into an InitializedInput and fill in the missing metadata + node = all_nodes[input_node_name] + node.label = "InitializedInput" + set_attr( + node, + "explanation", + "This is a graph input that is initialized by an initializer. " + "ONNX inputs are values (edges), not nodes. A node is displayed here for visualization.", + ) + # Display the constant value if can_display_tensor_json(initializer.const_value, settings=settings): assert initializer.const_value is not None set_attr( @@ -449,8 +464,11 @@ def add_initializers( "__value", display_tensor_json(initializer.const_value, settings=settings), ) + # Set output metadata + metadata = node.outputsMetadata[0] set_attr(metadata, "value", display_tensor_repr(initializer.const_value)) continue + initializer_node_name = get_initializer_node_name(initializer) node = graph_builder.GraphNode( id=initializer_node_name, label="Initializer", diff --git a/testdata/initializer_node_tests.textproto b/testdata/initializer_node_tests.textproto new file mode 100644 index 0000000..3c0c54f --- /dev/null +++ b/testdata/initializer_node_tests.textproto @@ -0,0 +1,95 @@ +ir_version: 8 +graph { + node { + input: "input" + output: "output" + name: "node_Identity_0" + op_type: "Identity" + doc_string: "An identity node 1" + } + node { + input: "input_initialized" + output: "output_initialized" + name: "node_Identity_1" + op_type: "Identity" + doc_string: "An identity node 2" + } + node { + input: "normal_initializer" + output: "val_0" + name: "node_Identity_2" + op_type: "Identity" + doc_string: "An identity node 3" + } + name: "initializer_node_tests" + initializer { + dims: 3 + data_type: 1 + name: "input_initialized" + raw_data: "\000\000\000\000\000\000\200?\000\000\000@" + } + initializer { + dims: 1 + data_type: 7 + name: "normal_initializer" + raw_data: "*\000\000\000\000\000\000\000" + } + input { + name: "input" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 42 + } + } + } + } + doc_string: "An input" + } + input { + name: "input_initialized" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + } + } + } + doc_string: "An initialized input" + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 42 + } + } + } + } + } + output { + name: "output_initialized" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + } + } + } + } +} +opset_import { + domain: "" + version: 20 +} diff --git a/tools/test_generation/constant_node_tests.py b/tools/test_generation/constant_nodes.py similarity index 100% rename from tools/test_generation/constant_node_tests.py rename to tools/test_generation/constant_nodes.py diff --git a/tools/test_generation/initializers.py b/tools/test_generation/initializers.py new file mode 100644 index 0000000..108428e --- /dev/null +++ b/tools/test_generation/initializers.py @@ -0,0 +1,70 @@ +import pathlib + +import onnx +from onnxscript import ir +import numpy as np + + +def main(): + input_ = ir.Input( + "input", ir.Shape([42]), ir.TensorType(ir.DataType.FLOAT), "An input" + ) + input_initialized = ir.Input( + "input_initialized", + ir.Shape([3]), + ir.TensorType(ir.DataType.FLOAT), + "An initialized input", + ) + input_initializer = ir.Tensor( + np.array([0.0, 1.0, 2.0], dtype=np.float32), name="input_initialized" + ) + input_initialized.const_value = input_initializer + normal_value = ir.Value(name="normal_initializer") + normal_initializer = ir.Tensor( + np.array([42], dtype=np.int64), name="normal_initializer" + ) + normal_value.const_value = normal_initializer + output = ir.Value( + shape=ir.Shape([42]), type=ir.TensorType(ir.DataType.FLOAT), name="output" + ) + output_initialized = ir.Value( + shape=ir.Shape([3]), + type=ir.TensorType(ir.DataType.FLOAT), + name="output_initialized", + ) + graph = ir.Graph( + [input_, input_initialized], + [output, output_initialized], + nodes=[ + ir.Node( + "", + "Identity", + [input_], + outputs=[output], + doc_string="An identity node 1", + ), + ir.Node( + "", + "Identity", + [input_initialized], + outputs=[output_initialized], + doc_string="An identity node 2", + ), + ir.Node("", "Identity", [normal_value], doc_string="An identity node 3"), + ], + initializers=[input_initialized, normal_value], + opset_imports={"": 20}, + name="initializer_node_tests", + ) + model = ir.Model(graph, ir_version=8) + model_proto = ir.serde.serialize_model(model) + onnx.save( + model_proto, + pathlib.Path(__file__).parent.parent.parent + / "testdata" + / "initializer_node_tests.textproto", + ) + + +if __name__ == "__main__": + main()