Skip to content

Commit

Permalink
Create a new node type for inputs that are initialized (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored May 22, 2024
1 parent dea54ba commit 6aea85b
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 5 deletions.
28 changes: 23 additions & 5 deletions src/model_explorer_onnx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,13 @@ def add_graph_io(
id=get_graph_io_node_name(value),
label=type_,
)
set_attr(
node,
"explanation",
f"This is a graph {type_.lower()}. ONNX {type_.lower()}s are values (edges), not nodes. "
f"A node is displayed here for visualization.",
)

producer = value.producer()
if producer is not None:
assert producer.name, "Bug: Node name is required"
Expand Down Expand Up @@ -437,20 +444,31 @@ def add_initializers(
"Initializer does not have a name. Skipping: %s", initializer
)
continue
initializer_node_name = get_initializer_node_name(initializer)
if initializer_node_name in all_nodes:
# The initializer is also a graph input. Fill in the missing metadata.
node = all_nodes[initializer_node_name]
metadata = node.outputsMetadata[0]
input_node_name = get_graph_io_node_name(initializer)
if input_node_name in all_nodes:
# The initializer is also a graph input.
# Convert it into an InitializedInput and fill in the missing metadata
node = all_nodes[input_node_name]
node.label = "InitializedInput"
set_attr(
node,
"explanation",
"This is a graph input that is initialized by an initializer. "
"ONNX inputs are values (edges), not nodes. A node is displayed here for visualization.",
)
# Display the constant value
if can_display_tensor_json(initializer.const_value, settings=settings):
assert initializer.const_value is not None
set_attr(
node,
"__value",
display_tensor_json(initializer.const_value, settings=settings),
)
# Set output metadata
metadata = node.outputsMetadata[0]
set_attr(metadata, "value", display_tensor_repr(initializer.const_value))
continue
initializer_node_name = get_initializer_node_name(initializer)
node = graph_builder.GraphNode(
id=initializer_node_name,
label="Initializer",
Expand Down
95 changes: 95 additions & 0 deletions testdata/initializer_node_tests.textproto
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
ir_version: 8
graph {
node {
input: "input"
output: "output"
name: "node_Identity_0"
op_type: "Identity"
doc_string: "An identity node 1"
}
node {
input: "input_initialized"
output: "output_initialized"
name: "node_Identity_1"
op_type: "Identity"
doc_string: "An identity node 2"
}
node {
input: "normal_initializer"
output: "val_0"
name: "node_Identity_2"
op_type: "Identity"
doc_string: "An identity node 3"
}
name: "initializer_node_tests"
initializer {
dims: 3
data_type: 1
name: "input_initialized"
raw_data: "\000\000\000\000\000\000\200?\000\000\000@"
}
initializer {
dims: 1
data_type: 7
name: "normal_initializer"
raw_data: "*\000\000\000\000\000\000\000"
}
input {
name: "input"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 42
}
}
}
}
doc_string: "An input"
}
input {
name: "input_initialized"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
}
}
}
doc_string: "An initialized input"
}
output {
name: "output"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 42
}
}
}
}
}
output {
name: "output_initialized"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
domain: ""
version: 20
}
File renamed without changes.
70 changes: 70 additions & 0 deletions tools/test_generation/initializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pathlib

import onnx
from onnxscript import ir
import numpy as np


def main():
input_ = ir.Input(
"input", ir.Shape([42]), ir.TensorType(ir.DataType.FLOAT), "An input"
)
input_initialized = ir.Input(
"input_initialized",
ir.Shape([3]),
ir.TensorType(ir.DataType.FLOAT),
"An initialized input",
)
input_initializer = ir.Tensor(
np.array([0.0, 1.0, 2.0], dtype=np.float32), name="input_initialized"
)
input_initialized.const_value = input_initializer
normal_value = ir.Value(name="normal_initializer")
normal_initializer = ir.Tensor(
np.array([42], dtype=np.int64), name="normal_initializer"
)
normal_value.const_value = normal_initializer
output = ir.Value(
shape=ir.Shape([42]), type=ir.TensorType(ir.DataType.FLOAT), name="output"
)
output_initialized = ir.Value(
shape=ir.Shape([3]),
type=ir.TensorType(ir.DataType.FLOAT),
name="output_initialized",
)
graph = ir.Graph(
[input_, input_initialized],
[output, output_initialized],
nodes=[
ir.Node(
"",
"Identity",
[input_],
outputs=[output],
doc_string="An identity node 1",
),
ir.Node(
"",
"Identity",
[input_initialized],
outputs=[output_initialized],
doc_string="An identity node 2",
),
ir.Node("", "Identity", [normal_value], doc_string="An identity node 3"),
],
initializers=[input_initialized, normal_value],
opset_imports={"": 20},
name="initializer_node_tests",
)
model = ir.Model(graph, ir_version=8)
model_proto = ir.serde.serialize_model(model)
onnx.save(
model_proto,
pathlib.Path(__file__).parent.parent.parent
/ "testdata"
/ "initializer_node_tests.textproto",
)


if __name__ == "__main__":
main()

0 comments on commit 6aea85b

Please sign in to comment.