diff --git a/README.md b/README.md index fc4c248..59e0b02 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,10 @@ onnxvis onnxvis model.onnx ``` +## Notes on representation + +Graph input/output/initializers in ONNX are values (edges), not nodes. A node is displayed here for visualization. Graph inputs that are initialized by initializers are displayed as `InitializedInput`, and are displayed closer to nodes that use them. + ## Screenshots image diff --git a/src/model_explorer_onnx/main.py b/src/model_explorer_onnx/main.py index 086e3d4..87b47d2 100644 --- a/src/model_explorer_onnx/main.py +++ b/src/model_explorer_onnx/main.py @@ -45,10 +45,10 @@ def display_tensor_repr(tensor: ir.TensorProtocol | None) -> str: def can_display_tensor_json( tensor: ir.TensorProtocol | None, settings: Settings ) -> bool: + """Check if the tensor can be displayed as JSON.""" + del settings # Unused if tensor is None: return False - if tensor.size > settings.const_element_count_limit: - return False if isinstance(tensor, ir.ExternalTensor): return False return True @@ -79,8 +79,8 @@ def format_shape(shape: ir.ShapeProtocol | None) -> str: return str(shape) if shape is not None else "[?]" -def format_type(type: ir.TypeProtocol | None) -> str: - return str(type) if type is not None else "?" +def format_type(type_: ir.TypeProtocol | None) -> str: + return str(type_) if type_ is not None else "?" def format_tensor_shape(value: ir.Value | ir.TensorProtocol) -> str: @@ -89,12 +89,9 @@ def format_tensor_shape(value: ir.Value | ir.TensorProtocol) -> str: return f"{value.dtype or '?'}{format_shape(value.shape)}" -def get_graph_io_node_name(value: ir.Value) -> str: - return f"[io] {value.name}" - - -def get_initializer_node_name(value: ir.Value) -> str: - return f"[initializer] {value.name}" +def get_value_node_name(value: ir.Value) -> str: + """Create name for node that is created from a value, that is for visualization only. E.g. Input.""" + return f"[value] {value.name}" def get_function_graph_name(identifier: ir.OperatorIdentifier) -> str: @@ -271,7 +268,7 @@ def add_incoming_edges( continue if input_value in graph_inputs: # The input is a graph input. Create an input edge. - source_node_id = get_graph_io_node_name(input_value) + source_node_id = get_value_node_name(input_value) source_node_output_id = "0" else: input_node = input_value.producer() @@ -280,7 +277,7 @@ def add_incoming_edges( "Input value %s does not have a producer. Treating as initializer.", input_value, ) - source_node_id = get_initializer_node_name(input_value) + source_node_id = get_value_node_name(input_value) source_node_output_id = "0" else: assert input_node.name, "Bug: Node name is required" @@ -356,16 +353,9 @@ def add_graph_io( ) -> None: for i, value in enumerate(input_or_outputs): node = graph_builder.GraphNode( - id=get_graph_io_node_name(value), + id=get_value_node_name(value), label=type_, ) - set_attr( - node, - "explanation", - f"This is a graph {type_.lower()}. ONNX {type_.lower()}s are values (edges), not nodes. " - f"A node is displayed here for visualization.", - ) - producer = value.producer() if producer is not None: assert producer.name, "Bug: Node name is required" @@ -444,18 +434,14 @@ def add_initializers( "Initializer does not have a name. Skipping: %s", initializer ) continue - input_node_name = get_graph_io_node_name(initializer) - if input_node_name in all_nodes: + initializer_node_name = get_value_node_name(initializer) + if initializer_node_name in all_nodes: # The initializer is also a graph input. # Convert it into an InitializedInput and fill in the missing metadata - node = all_nodes[input_node_name] + node = all_nodes[initializer_node_name] node.label = "InitializedInput" - set_attr( - node, - "explanation", - "This is a graph input that is initialized by an initializer. " - "ONNX inputs are values (edges), not nodes. A node is displayed here for visualization.", - ) + # Push the initializer closer to the user node's namespace + node.namespace = get_constant_namespace(initializer, namespace) # Display the constant value if can_display_tensor_json(initializer.const_value, settings=settings): assert initializer.const_value is not None @@ -468,7 +454,6 @@ def add_initializers( metadata = node.outputsMetadata[0] set_attr(metadata, "value", display_tensor_repr(initializer.const_value)) continue - initializer_node_name = get_initializer_node_name(initializer) node = graph_builder.GraphNode( id=initializer_node_name, label="Initializer",