Skip to content

Commit

Permalink
Create groupNodeAttributes for models
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Aug 27, 2024
1 parent aa985f9 commit 6fcedc3
Showing 1 changed file with 51 additions and 31 deletions.
82 changes: 51 additions & 31 deletions src/model_explorer_onnx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import model_explorer
import numpy as np
import onnx
from model_explorer import graph_builder
from model_explorer import graph_builder as gb
from onnxscript import ir

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -144,27 +144,25 @@ def get_node_output_param_name(
logger.warning("Failed to get output schema name: %s", e)


def set_attr(
obj: graph_builder.GraphNode | graph_builder.MetadataItem, key: str, value: str
) -> None:
def set_attr(obj: gb.GraphNode | gb.MetadataItem, key: str, value: str) -> None:
"""Set an attribute on a GraphNode or MetadataItem."""
obj.attrs.append(graph_builder.KeyValue(key=key, value=value))
obj.attrs.append(gb.KeyValue(key=key, value=value))


def set_type_shape_metadata(
metadata: graph_builder.MetadataItem, value: ir.Value | ir.TensorProtocol
metadata: gb.MetadataItem, value: ir.Value | ir.TensorProtocol
) -> None:
# tensor_shape is a special key that is used to display the type and shape of the tensor
set_attr(metadata, "tensor_shape", format_tensor_shape(value))


def set_metadata_props(metadata: graph_builder.MetadataItem, value: ir.Value) -> None:
def set_metadata_props(metadata: gb.MetadataItem, value: ir.Value) -> None:
for prop_key, prop_value in value.metadata_props.items():
set_attr(metadata, f"[metadata] {prop_key}", prop_value)


def add_inputs_metadata(
onnx_node: ir.Node, node: graph_builder.GraphNode, opset_version: int
onnx_node: ir.Node, node: gb.GraphNode, opset_version: int
) -> None:
if onnx.defs.has(onnx_node.op_type, max_inclusive_version=opset_version):
schema = onnx.defs.get_schema(
Expand All @@ -173,7 +171,7 @@ def add_inputs_metadata(
else:
schema = None
for i, input_value in enumerate(onnx_node.inputs):
metadata = graph_builder.MetadataItem(id=str(i), attrs=[])
metadata = gb.MetadataItem(id=str(i), attrs=[])
if input_value is None:
set_attr(metadata, "__tensor_tag", "None")
else:
Expand All @@ -186,17 +184,15 @@ def add_inputs_metadata(
node.inputsMetadata.append(metadata)


def add_outputs_metadata(
onnx_node: ir.Node, node: graph_builder.GraphNode, opset_version: int
):
def add_outputs_metadata(onnx_node: ir.Node, node: gb.GraphNode, opset_version: int):
if onnx.defs.has(onnx_node.op_type, max_inclusive_version=opset_version):
schema = onnx.defs.get_schema(
onnx_node.op_type, max_inclusive_version=opset_version
)
else:
schema = None
for output_value in onnx_node.outputs:
metadata = graph_builder.MetadataItem(id=str(output_value.index()), attrs=[])
metadata = gb.MetadataItem(id=str(output_value.index()), attrs=[])
set_attr(metadata, "__tensor_tag", output_value.name or "None")
set_type_shape_metadata(metadata, output_value)
set_metadata_props(metadata, output_value)
Expand All @@ -213,9 +209,7 @@ def add_outputs_metadata(
node.outputsMetadata.append(metadata)


def add_node_attrs(
onnx_node: ir.Node, node: graph_builder.GraphNode, settings: Settings
) -> None:
def add_node_attrs(onnx_node: ir.Node, node: gb.GraphNode, settings: Settings) -> None:
for attr in onnx_node.attributes.values():
if isinstance(attr, ir.Attr):
if attr.type == ir.AttributeType.TENSOR:
Expand Down Expand Up @@ -263,7 +257,7 @@ def add_node_attrs(

def add_incoming_edges(
onnx_node: ir.Node,
node: graph_builder.GraphNode,
node: gb.GraphNode,
graph_inputs: set[ir.Value],
) -> None:
for target_input_id, input_value in enumerate(onnx_node.inputs):
Expand All @@ -288,7 +282,7 @@ def add_incoming_edges(
source_node_output_id = str(input_value.index())
assert source_node_id is not None
node.incomingEdges.append(
graph_builder.IncomingEdge(
gb.IncomingEdge(
sourceNodeId=source_node_id,
sourceNodeOutputId=source_node_output_id,
targetNodeInputId=str(target_input_id),
Expand Down Expand Up @@ -325,7 +319,7 @@ def create_node(
all_function_ids: set[ir.OperatorIdentifier],
opset_version: int,
settings: Settings,
) -> graph_builder.GraphNode | None:
) -> gb.GraphNode | None:
"""Create a GraphNode from an ONNX node.
Args:
Expand All @@ -344,7 +338,7 @@ def create_node(
embedded_namespace = get_node_namespace(onnx_node)
if embedded_namespace:
namespace = namespace + "/" + "/".join(embedded_namespace)
node = graph_builder.GraphNode(
node = gb.GraphNode(
id=onnx_node.name,
label=create_op_label(onnx_node.domain, onnx_node.op_type),
namespace=namespace,
Expand All @@ -359,28 +353,28 @@ def create_node(


def add_graph_io(
graph: graph_builder.Graph,
graph: gb.Graph,
input_or_outputs: Sequence[ir.Value],
type_: Literal["Input", "Output"],
all_nodes: dict[str, graph_builder.GraphNode],
all_nodes: dict[str, gb.GraphNode],
) -> None:
for i, value in enumerate(input_or_outputs):
node = graph_builder.GraphNode(
node = gb.GraphNode(
id=get_value_node_name(value),
label=type_,
)
producer = value.producer()
if producer is not None:
assert producer.name, "Bug: Node name is required"
node.incomingEdges.append(
graph_builder.IncomingEdge(
gb.IncomingEdge(
sourceNodeId=producer.name,
sourceNodeOutputId=str(value.index()),
targetNodeInputId="0",
)
)
if type_ == "Input":
metadata = graph_builder.MetadataItem(id="0", attrs=[])
metadata = gb.MetadataItem(id="0", attrs=[])
set_attr(metadata, "__tensor_tag", value.name or "")
set_type_shape_metadata(metadata, value)
set_metadata_props(metadata, value)
Expand Down Expand Up @@ -430,10 +424,10 @@ def get_constant_namespace(initializer: ir.Value, root_namespace: str) -> str:


def add_initializers(
graph: graph_builder.Graph,
graph: gb.Graph,
initializers: Sequence[ir.Value],
namespace: str,
all_nodes: dict[str, graph_builder.GraphNode],
all_nodes: dict[str, gb.GraphNode],
settings: Settings,
) -> None:
for initializer in initializers:
Expand Down Expand Up @@ -462,7 +456,7 @@ def add_initializers(
metadata = node.outputsMetadata[0]
set_attr(metadata, "value", display_tensor_repr(initializer.const_value))
continue
node = graph_builder.GraphNode(
node = gb.GraphNode(
id=initializer_node_name,
label="Initializer",
namespace=get_constant_namespace(initializer, namespace),
Expand All @@ -474,7 +468,7 @@ def add_initializers(
)
graph.nodes.append(node)
continue
metadata = graph_builder.MetadataItem(id="0", attrs=[])
metadata = gb.MetadataItem(id="0", attrs=[])
set_attr(metadata, "__tensor_tag", initializer.name or "")
set_type_shape_metadata(metadata, initializer.const_value)
if can_display_tensor_json(initializer.const_value, settings=settings):
Expand All @@ -498,15 +492,20 @@ def create_graph(
all_function_ids: set[ir.OperatorIdentifier],
opset_version: int,
settings: Settings,
) -> graph_builder.Graph | None:
attrs: dict[str, Any],
) -> gb.Graph | None:
if isinstance(onnx_graph, ir.Function):
graph_name = get_function_graph_name(onnx_graph.identifier())
elif onnx_graph.name is None:
logger.warning("Graph does not have a name. skipping graph: %s", onnx_graph)
return None
else:
graph_name = onnx_graph.name
graph = graph_builder.Graph(id=graph_name, nodes=[])
graph = gb.Graph(
id=graph_name,
nodes=[],
groupNodeAttributes={"": {key: str(value) for key, value in attrs.items()}},
)
graph_inputs = set(onnx_graph.inputs)
all_nodes = {}
add_graph_io(graph, onnx_graph.inputs, type_="Input", all_nodes=all_nodes)
Expand Down Expand Up @@ -589,6 +588,18 @@ def convert(
all_function_ids,
opset_version=opset_version,
settings=parsed_settings,
attrs={
"opset_imports": model.opset_imports,
"producer_name": model.producer_name,
"producer_version": model.producer_version,
"domain": model.domain,
"model_version": model.model_version,
"doc_string": model.doc_string,
**{
f"[metadata] {key}": value
for key, value in model.metadata_props.items()
},
},
)
assert main_graph is not None, "Bug: Main graph should not be None"
graphs.append(main_graph)
Expand All @@ -599,6 +610,15 @@ def convert(
all_function_ids,
opset_version=opset_version,
settings=parsed_settings,
attrs={
"opset_imports": model.opset_imports,
"attributes": function.attributes,
"doc_string": function.doc_string,
**{
f"[metadata] {key}": value
for key, value in function.metadata_props.items()
},
},
)
assert function_graph is not None
graphs.append(function_graph)
Expand Down

0 comments on commit 6fcedc3

Please sign in to comment.