Skip to content

Commit

Permalink
Display initialized inputs properly (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored May 22, 2024
1 parent 6aea85b commit 41b6efb
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 30 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ onnxvis
onnxvis model.onnx
```

## Notes on representation

Graph input/output/initializers in ONNX are values (edges), not nodes. A node is displayed here for visualization. Graph inputs that are initialized by initializers are displayed as `InitializedInput`, and are displayed closer to nodes that use them.

## Screenshots

<img width="1294" alt="image" src="https://github.com/justinchuby/model-explorer-onnx/assets/11205048/ed7e1eee-a693-48bd-811d-b384f784ef9b">
Expand Down
45 changes: 15 additions & 30 deletions src/model_explorer_onnx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def display_tensor_repr(tensor: ir.TensorProtocol | None) -> str:
def can_display_tensor_json(
tensor: ir.TensorProtocol | None, settings: Settings
) -> bool:
"""Check if the tensor can be displayed as JSON."""
del settings # Unused
if tensor is None:
return False
if tensor.size > settings.const_element_count_limit:
return False
if isinstance(tensor, ir.ExternalTensor):
return False
return True
Expand Down Expand Up @@ -79,8 +79,8 @@ def format_shape(shape: ir.ShapeProtocol | None) -> str:
return str(shape) if shape is not None else "[?]"


def format_type(type: ir.TypeProtocol | None) -> str:
return str(type) if type is not None else "?"
def format_type(type_: ir.TypeProtocol | None) -> str:
return str(type_) if type_ is not None else "?"


def format_tensor_shape(value: ir.Value | ir.TensorProtocol) -> str:
Expand All @@ -89,12 +89,9 @@ def format_tensor_shape(value: ir.Value | ir.TensorProtocol) -> str:
return f"{value.dtype or '?'}{format_shape(value.shape)}"


def get_graph_io_node_name(value: ir.Value) -> str:
return f"[io] {value.name}"


def get_initializer_node_name(value: ir.Value) -> str:
return f"[initializer] {value.name}"
def get_value_node_name(value: ir.Value) -> str:
"""Create name for node that is created from a value, that is for visualization only. E.g. Input."""
return f"[value] {value.name}"


def get_function_graph_name(identifier: ir.OperatorIdentifier) -> str:
Expand Down Expand Up @@ -271,7 +268,7 @@ def add_incoming_edges(
continue
if input_value in graph_inputs:
# The input is a graph input. Create an input edge.
source_node_id = get_graph_io_node_name(input_value)
source_node_id = get_value_node_name(input_value)
source_node_output_id = "0"
else:
input_node = input_value.producer()
Expand All @@ -280,7 +277,7 @@ def add_incoming_edges(
"Input value %s does not have a producer. Treating as initializer.",
input_value,
)
source_node_id = get_initializer_node_name(input_value)
source_node_id = get_value_node_name(input_value)
source_node_output_id = "0"
else:
assert input_node.name, "Bug: Node name is required"
Expand Down Expand Up @@ -356,16 +353,9 @@ def add_graph_io(
) -> None:
for i, value in enumerate(input_or_outputs):
node = graph_builder.GraphNode(
id=get_graph_io_node_name(value),
id=get_value_node_name(value),
label=type_,
)
set_attr(
node,
"explanation",
f"This is a graph {type_.lower()}. ONNX {type_.lower()}s are values (edges), not nodes. "
f"A node is displayed here for visualization.",
)

producer = value.producer()
if producer is not None:
assert producer.name, "Bug: Node name is required"
Expand Down Expand Up @@ -444,18 +434,14 @@ def add_initializers(
"Initializer does not have a name. Skipping: %s", initializer
)
continue
input_node_name = get_graph_io_node_name(initializer)
if input_node_name in all_nodes:
initializer_node_name = get_value_node_name(initializer)
if initializer_node_name in all_nodes:
# The initializer is also a graph input.
# Convert it into an InitializedInput and fill in the missing metadata
node = all_nodes[input_node_name]
node = all_nodes[initializer_node_name]
node.label = "InitializedInput"
set_attr(
node,
"explanation",
"This is a graph input that is initialized by an initializer. "
"ONNX inputs are values (edges), not nodes. A node is displayed here for visualization.",
)
# Push the initializer closer to the user node's namespace
node.namespace = get_constant_namespace(initializer, namespace)
# Display the constant value
if can_display_tensor_json(initializer.const_value, settings=settings):
assert initializer.const_value is not None
Expand All @@ -468,7 +454,6 @@ def add_initializers(
metadata = node.outputsMetadata[0]
set_attr(metadata, "value", display_tensor_repr(initializer.const_value))
continue
initializer_node_name = get_initializer_node_name(initializer)
node = graph_builder.GraphNode(
id=initializer_node_name,
label="Initializer",
Expand Down

0 comments on commit 41b6efb

Please sign in to comment.