Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show input/output names of the nodes with information from OpSchema #11

Merged
merged 6 commits into from
May 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 101 additions & 7 deletions src/model_explorer_onnx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
logger = logging.getLogger(__name__)

_TENSOR_DISPLAY_LIMIT = 1024
_DEFAULT_OPSET_VERSION = 18


def display_tensor(tensor: ir.TensorProtocol | None) -> str:
Expand Down Expand Up @@ -68,7 +69,58 @@ def get_function_graph_name(identifier: ir.OperatorIdentifier) -> str:
return name


def add_inputs_metadata(onnx_node: ir.Node, node: graph_builder.GraphNode):
def get_node_input_param_name(
schema: onnx.defs.OpSchema, input_index: int
) -> str | None:
"""Get the name of the input parameter of the node from OpSchema."""
try:
if len(schema.inputs) == 0:
# Invalid schema.
return None
if input_index < len(schema.inputs):
return schema.inputs[input_index].name
if (
schema.inputs[-1].option
== onnx.defs.OpSchema.FormalParameterOption.Variadic
):
# The last input is variadic. Return the name of the last input.
return schema.inputs[-1].name
return None
except Exception as e:
logger.warning("Failed to get input schema name: %s", e)
return None


def get_node_output_param_name(
schema: onnx.defs.OpSchema, output_index: int
) -> str | None:
"""Get the name of the output parameter of the node from OpSchema."""
try:
if len(schema.outputs) == 0:
# Invalid schema. Return the output index as a fallback.
return None
if output_index < len(schema.outputs):
return schema.outputs[output_index].name
if (
schema.outputs[-1].option
== onnx.defs.OpSchema.FormalParameterOption.Variadic
):
# The last input is variadic. Return the name of the last input.
return schema.outputs[-1].name
return None
except Exception as e:
logger.warning("Failed to get output schema name: %s", e)


def add_inputs_metadata(
onnx_node: ir.Node, node: graph_builder.GraphNode, opset_version: int
):
if onnx.defs.has(onnx_node.op_type, max_inclusive_version=opset_version):
schema = onnx.defs.get_schema(
onnx_node.op_type, max_inclusive_version=opset_version
)
else:
schema = None
for i, input_value in enumerate(onnx_node.inputs):
metadata = graph_builder.MetadataItem(id=str(i), attrs=[])
if input_value is None:
Expand All @@ -83,10 +135,23 @@ def add_inputs_metadata(onnx_node: ir.Node, node: graph_builder.GraphNode):
key="tensor_shape", value=format_tensor_shape(input_value)
)
)
if schema is not None:
if (param_name := get_node_input_param_name(schema, i)) is not None:
metadata.attrs.append(
graph_builder.KeyValue(key="param_name", value=param_name)
)
node.inputsMetadata.append(metadata)


def add_outputs_metadata(onnx_node: ir.Node, node: graph_builder.GraphNode):
def add_outputs_metadata(
onnx_node: ir.Node, node: graph_builder.GraphNode, opset_version: int
):
if onnx.defs.has(onnx_node.op_type, max_inclusive_version=opset_version):
schema = onnx.defs.get_schema(
onnx_node.op_type, max_inclusive_version=opset_version
)
else:
schema = None
for output in onnx_node.outputs:
metadata = graph_builder.MetadataItem(id=str(output.index()), attrs=[])
metadata.attrs.append(
Expand All @@ -98,6 +163,15 @@ def add_outputs_metadata(onnx_node: ir.Node, node: graph_builder.GraphNode):
key="tensor_shape", value=format_tensor_shape(output)
)
)
if schema is not None:
output_index = output.index()
assert output_index is not None
if (
param_name := get_node_output_param_name(schema, output_index)
) is not None:
metadata.attrs.append(
graph_builder.KeyValue(key="param_name", value=param_name)
)
node.outputsMetadata.append(metadata)


Expand All @@ -106,6 +180,8 @@ def add_node_attrs(onnx_node: ir.Node, node: graph_builder.GraphNode):
if isinstance(attr, ir.Attr):
if attr.type == ir.AttributeType.TENSOR:
attr_value = display_tensor(attr.value)
elif onnx_node.op_type == "Cast" and attr.name == "to":
attr_value = str(ir.DataType(attr.value))
else:
attr_value = str(attr.value)
node.attrs.append(graph_builder.KeyValue(key=attr.name, value=attr_value))
Expand Down Expand Up @@ -167,7 +243,17 @@ def create_node(
graph_inputs: set[ir.Value],
namespace: str,
all_function_ids: set[ir.OperatorIdentifier],
opset_version: int,
) -> graph_builder.GraphNode | None:
"""Create a GraphNode from an ONNX node.

Args:
onnx_node: The ONNX node to convert.
graph_inputs: The set of graph inputs.
namespace: The namespace of the node.
all_function_ids: The set of all function identifiers.
opset_version: The current ONNX opset version.
"""
if onnx_node.name is None:
logger.warning("Node does not have a name. Skipping node %s.", onnx_node)
return None
Expand All @@ -183,8 +269,8 @@ def create_node(
)
add_incoming_edges(onnx_node, node, graph_inputs)
add_node_attrs(onnx_node, node)
add_inputs_metadata(onnx_node, node)
add_outputs_metadata(onnx_node, node)
add_inputs_metadata(onnx_node, node, opset_version=opset_version)
add_outputs_metadata(onnx_node, node, opset_version=opset_version)
if onnx_node.op_identifier() in all_function_ids:
node.subgraphIds.append(get_function_graph_name(onnx_node.op_identifier()))
return node
Expand Down Expand Up @@ -281,7 +367,9 @@ def add_initializers(


def create_graph(
onnx_graph: ir.Graph | ir.Function, all_function_ids: set[ir.OperatorIdentifier]
onnx_graph: ir.Graph | ir.Function,
all_function_ids: set[ir.OperatorIdentifier],
opset_version: int,
) -> graph_builder.Graph | None:
if isinstance(onnx_graph, ir.Function):
graph_name = get_function_graph_name(onnx_graph.identifier())
Expand All @@ -301,6 +389,7 @@ def create_graph(
graph_inputs, # type: ignore
namespace=graph_name,
all_function_ids=all_function_ids,
opset_version=opset_version,
) # type: ignore
if node is None:
continue
Expand Down Expand Up @@ -354,13 +443,18 @@ def convert(
model = ir.serde.deserialize_model(onnx_model)
all_function_ids = set(model.functions)
graphs = []
opset_version = model.opset_imports.get("", _DEFAULT_OPSET_VERSION)
# TODO: Better support subgraphs in nodes
main_graph = create_graph(model.graph, all_function_ids)
main_graph = create_graph(
model.graph, all_function_ids, opset_version=opset_version
)
assert main_graph is not None
graphs.append(main_graph)

for function in model.functions.values():
function_graph = create_graph(function, all_function_ids)
function_graph = create_graph(
function, all_function_ids, opset_version=opset_version
)
assert function_graph is not None
graphs.append(function_graph)
return {"graphs": graphs}
Loading