diff --git a/src/model_explorer_onnx/main.py b/src/model_explorer_onnx/main.py
index 0f0dffb..7d9dc40 100644
--- a/src/model_explorer_onnx/main.py
+++ b/src/model_explorer_onnx/main.py
@@ -1,7 +1,6 @@
 from __future__ import annotations

 import logging
-import os
 from typing import Any, Literal, Sequence

 import ml_dtypes
@@ -20,23 +19,23 @@ def display_tensor(tensor: ir.TensorProtocol | None) -> str:
     if tensor is None:
         return "Data not available"
-    if tensor.size < _TENSOR_DISPLAY_LIMIT:
-        try:
-            array = tensor.numpy()
-            if tensor.dtype == ir.DataType.BFLOAT16:
-                array = array.view(ml_dtypes.bfloat16)
-            elif tensor.dtype == ir.DataType.FLOAT8E4M3FN:
-                array = array.view(ml_dtypes.float8_e4m3fn)
-            elif tensor.dtype == ir.DataType.FLOAT8E4M3FNUZ:
-                array = array.view(ml_dtypes.float8_e4m3fnuz)
-            elif tensor.dtype == ir.DataType.FLOAT8E5M2:
-                array = array.view(ml_dtypes.float8_e5m2)
-            elif tensor.dtype == ir.DataType.FLOAT8E5M2FNUZ:
-                array = array.view(ml_dtypes.float8_e5m2fnuz)
-            return np.array2string(array, separator=",")
-        except Exception as e:
-            logger.warning("Failed to display tensor: %s", e)
-            return str(tensor)
+    if tensor.size > _TENSOR_DISPLAY_LIMIT or isinstance(tensor, ir.ExternalTensor):
+        return str(tensor)
+    try:
+        array = tensor.numpy()
+        if tensor.dtype == ir.DataType.BFLOAT16:
+            array = array.view(ml_dtypes.bfloat16)
+        elif tensor.dtype == ir.DataType.FLOAT8E4M3FN:
+            array = array.view(ml_dtypes.float8_e4m3fn)
+        elif tensor.dtype == ir.DataType.FLOAT8E4M3FNUZ:
+            array = array.view(ml_dtypes.float8_e4m3fnuz)
+        elif tensor.dtype == ir.DataType.FLOAT8E5M2:
+            array = array.view(ml_dtypes.float8_e5m2)
+        elif tensor.dtype == ir.DataType.FLOAT8E5M2FNUZ:
+            array = array.view(ml_dtypes.float8_e5m2fnuz)
+        return np.array2string(array, separator=",")
+    except Exception as e:
+        logger.warning("Failed to display tensor: %s", e)
     return str(tensor)
@@ -262,9 +261,13 @@ def create_node(
     """
     assert onnx_node.name, "Bug: Node name is required"
-    embedded_namespace = parse_namespace(onnx_node.name)
-    if embedded_namespace:
-        namespace = namespace + "/" + "/".join(embedded_namespace)
+    if onnx_node.op_type == "Constant":
+        # Move the constant closer to the user node's namespace
+        namespace = get_constant_namespace(onnx_node.outputs[0], namespace)
+    else:
+        embedded_namespace = parse_namespace(onnx_node.name)
+        if embedded_namespace:
+            namespace = namespace + "/" + "/".join(embedded_namespace)
     node = graph_builder.GraphNode(
         id=onnx_node.name,
         label=create_op_label(onnx_node.domain, onnx_node.op_type),
@@ -312,8 +315,8 @@ def add_graph_io(
         all_nodes[node.id] = node


-def get_initializer_namespace(initializer: ir.Value, root_namespace: str) -> str:
-    # If the initializer is used by a single node, move it to the same namespace as the node
+def get_constant_namespace(initializer: ir.Value, root_namespace: str) -> str:
+    """Move the constant/initializer closer to the user's namespace."""
     initializer_namespace = root_namespace
     # A single node can have multiple uses of the same value.
     # Here we only count the unique nodes that use the initializer to push the
@@ -323,6 +326,7 @@ def get_constant_namespace(initializer: ir.Value, root_namespace: str) -> str:
         # The initializer is not used by any node. Keep it in the root namespace.
         return initializer_namespace
     if len(user_nodes) == 1:
+        # If the initializer is used by a single node, move it to the same namespace as the node
         user_node = user_nodes[0]
         assert (
             user_node.name
@@ -376,7 +380,7 @@ def add_initializers(
         node = graph_builder.GraphNode(
             id=initializer_node_name,
             label="Initializer",
-            namespace=get_initializer_namespace(initializer, namespace),
+            namespace=get_constant_namespace(initializer, namespace),
         )
         # Add metadata for the output tensor
         if initializer.const_value is None:
@@ -458,24 +462,25 @@ def convert(
 ) -> model_explorer.ModelExplorerGraphs:
     del settings  # Unused

+    # Do not load external data because the model file is copied to a temporary location
+    # and the external data paths are not valid anymore.
     onnx_model = onnx.load(model_path, load_external_data=False)
     try:
-        onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
+        onnx_model = onnx.shape_inference.infer_shapes(onnx_model, data_prop=True)
     except Exception as e:
         logger.warning(
             "Failed to infer shapes. Continue with the original model. Error: %s", e
         )
-    # Load external data after shape inference
-    model_filepath = os.path.abspath(model_path)
-    base_dir = os.path.dirname(model_filepath)
-    onnx.load_external_data_for_model(onnx_model, base_dir)
-
     # Convert to ONNX IR
     model = ir.serde.deserialize_model(onnx_model)
     all_function_ids = set(model.functions)
     graphs = []
-    opset_version = model.opset_imports.get("", _DEFAULT_OPSET_VERSION)
+    opset_version = model.opset_imports.get("")
+    if opset_version is None:
+        opset_version = model.opset_imports.get("ai.onnx")
+    if opset_version is None:
+        opset_version = _DEFAULT_OPSET_VERSION
     # TODO: Better support subgraphs in nodes
     main_graph = create_graph(
         model.graph, all_function_ids, opset_version=opset_version