Skip to content

Commit

Permalink
Skip loading external tensors; move constants closer to their users (#21
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored May 18, 2024
1 parent 164d553 commit b2458c2
Showing 1 changed file with 36 additions and 31 deletions.
67 changes: 36 additions & 31 deletions src/model_explorer_onnx/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
import os
from typing import Any, Literal, Sequence

import ml_dtypes
Expand All @@ -20,23 +19,23 @@
def display_tensor(tensor: ir.TensorProtocol | None) -> str:
if tensor is None:
return "Data not available"
if tensor.size < _TENSOR_DISPLAY_LIMIT:
try:
array = tensor.numpy()
if tensor.dtype == ir.DataType.BFLOAT16:
array = array.view(ml_dtypes.bfloat16)
elif tensor.dtype == ir.DataType.FLOAT8E4M3FN:
array = array.view(ml_dtypes.float8_e4m3fn)
elif tensor.dtype == ir.DataType.FLOAT8E4M3FNUZ:
array = array.view(ml_dtypes.float8_e4m3fnuz)
elif tensor.dtype == ir.DataType.FLOAT8E5M2:
array = array.view(ml_dtypes.float8_e5m2)
elif tensor.dtype == ir.DataType.FLOAT8E5M2FNUZ:
array = array.view(ml_dtypes.float8_e5m2fnuz)
return np.array2string(array, separator=",")
except Exception as e:
logger.warning("Failed to display tensor: %s", e)
return str(tensor)
if tensor.size > _TENSOR_DISPLAY_LIMIT or isinstance(tensor, ir.ExternalTensor):
return str(tensor)
try:
array = tensor.numpy()
if tensor.dtype == ir.DataType.BFLOAT16:
array = array.view(ml_dtypes.bfloat16)
elif tensor.dtype == ir.DataType.FLOAT8E4M3FN:
array = array.view(ml_dtypes.float8_e4m3fn)
elif tensor.dtype == ir.DataType.FLOAT8E4M3FNUZ:
array = array.view(ml_dtypes.float8_e4m3fnuz)
elif tensor.dtype == ir.DataType.FLOAT8E5M2:
array = array.view(ml_dtypes.float8_e5m2)
elif tensor.dtype == ir.DataType.FLOAT8E5M2FNUZ:
array = array.view(ml_dtypes.float8_e5m2fnuz)
return np.array2string(array, separator=",")
except Exception as e:
logger.warning("Failed to display tensor: %s", e)
return str(tensor)


Expand Down Expand Up @@ -262,9 +261,13 @@ def create_node(
"""
assert onnx_node.name, "Bug: Node name is required"

embedded_namespace = parse_namespace(onnx_node.name)
if embedded_namespace:
namespace = namespace + "/" + "/".join(embedded_namespace)
if onnx_node.op_type == "Constant":
# Move the constant closer to the user node's namespace
namespace = get_constant_namespace(onnx_node.outputs[0], namespace)
else:
embedded_namespace = parse_namespace(onnx_node.name)
if embedded_namespace:
namespace = namespace + "/" + "/".join(embedded_namespace)
node = graph_builder.GraphNode(
id=onnx_node.name,
label=create_op_label(onnx_node.domain, onnx_node.op_type),
Expand Down Expand Up @@ -312,8 +315,8 @@ def add_graph_io(
all_nodes[node.id] = node


def get_initializer_namespace(initializer: ir.Value, root_namespace: str) -> str:
# If the initializer is used by a single node, move it to the same namespace as the node
def get_constant_namespace(initializer: ir.Value, root_namespace: str) -> str:
"""Move the constant/initializer closer to the user's namespace."""
initializer_namespace = root_namespace
# A single node can have multiple uses of the same value.
# Here we only count the unique nodes that use the initializer to push the
Expand All @@ -323,6 +326,7 @@ def get_initializer_namespace(initializer: ir.Value, root_namespace: str) -> str
# The initializer is not used by any node. Keep it in the root namespace.
return initializer_namespace
if len(user_nodes) == 1:
# If the initializer is used by a single node, move it to the same namespace as the node
user_node = user_nodes[0]
assert (
user_node.name
Expand Down Expand Up @@ -376,7 +380,7 @@ def add_initializers(
node = graph_builder.GraphNode(
id=initializer_node_name,
label="Initializer",
namespace=get_initializer_namespace(initializer, namespace),
namespace=get_constant_namespace(initializer, namespace),
)
# Add metadata for the output tensor
if initializer.const_value is None:
Expand Down Expand Up @@ -458,24 +462,25 @@ def convert(
) -> model_explorer.ModelExplorerGraphs:
del settings # Unused

# Do not load external data because the model file is copied to a temporary location
# and the external data paths are not valid anymore.
onnx_model = onnx.load(model_path, load_external_data=False)
try:
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model, data_prop=True)
except Exception as e:
logger.warning(
"Failed to infer shapes. Continue with the original model. Error: %s", e
)

# Load external data after shape inference
model_filepath = os.path.abspath(model_path)
base_dir = os.path.dirname(model_filepath)
onnx.load_external_data_for_model(onnx_model, base_dir)

# Convert to ONNX IR
model = ir.serde.deserialize_model(onnx_model)
all_function_ids = set(model.functions)
graphs = []
opset_version = model.opset_imports.get("", _DEFAULT_OPSET_VERSION)
opset_version = model.opset_imports.get("")
if opset_version is None:
opset_version = model.opset_imports.get("ai.onnx")
if opset_version is None:
opset_version = _DEFAULT_OPSET_VERSION
# TODO: Better support subgraphs in nodes
main_graph = create_graph(
model.graph, all_function_ids, opset_version=opset_version
Expand Down

0 comments on commit b2458c2

Please sign in to comment.