Skip to content

Commit

Permalink
Display values from constant nodes (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored May 21, 2024
1 parent b9473ff commit aac7b04
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 3 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,5 +160,4 @@ cython_debug/
#.idea/

*.onnx
*.textproto
*.pb
26 changes: 24 additions & 2 deletions src/model_explorer_onnx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,14 @@ def can_display_tensor_json(
return True


def display_tensor_json(tensor: ir.TensorProtocol | None, settings: Settings) -> str:
def display_tensor_json(
tensor: ir.TensorProtocol | np.ndarray, settings: Settings
) -> str:
try:
array = _tensor_to_numpy(tensor)
if isinstance(tensor, np.ndarray):
array = tensor
else:
array = _tensor_to_numpy(tensor)
size_limit = settings.const_element_count_limit
if size_limit < 0 or size_limit >= array.size:
# Use separators=(',', ':') to remove spaces
Expand Down Expand Up @@ -215,12 +220,27 @@ def add_node_attrs(
if isinstance(attr, ir.Attr):
if attr.type == ir.AttributeType.TENSOR:
if can_display_tensor_json(attr.value, settings=settings):
assert attr.value is not None
set_attr(
node,
"__value",
display_tensor_json(attr.value, settings=settings),
)
attr_value = display_tensor_repr(attr.value)
elif onnx_node.op_type == "Constant" and attr.name in {
"value_float",
"value_int",
"value_string",
"value_floats",
"value_ints",
"value_strings",
}:
set_attr(
node,
"__value",
display_tensor_json(np.array(attr.value), settings=settings),
)
attr_value = str(attr.value)
elif onnx_node.op_type == "Cast" and attr.name == "to":
attr_value = str(ir.DataType(attr.value))
else:
Expand Down Expand Up @@ -417,6 +437,7 @@ def add_initializers(
node = all_nodes[initializer_node_name]
metadata = node.outputsMetadata[0]
if can_display_tensor_json(initializer.const_value, settings=settings):
assert initializer.const_value is not None
set_attr(
node,
"__value",
Expand All @@ -440,6 +461,7 @@ def add_initializers(
set_attr(metadata, "__tensor_tag", initializer.name or "")
set_type_shape_metadata(metadata, initializer.const_value)
if can_display_tensor_json(initializer.const_value, settings=settings):
assert initializer.const_value is not None
set_attr(
node,
"__value",
Expand Down
107 changes: 107 additions & 0 deletions testdata/constant_node_tests.textproto
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
ir_version: 8
graph {
node {
output: "val_0"
name: "node_Constant_0"
op_type: "Constant"
attribute {
name: "value_int"
i: 42
type: INT
}
}
node {
input: "val_0"
output: "val_1"
name: "node_Identity_1"
op_type: "Identity"
}
node {
output: "val_2"
name: "node_Constant_2"
op_type: "Constant"
attribute {
name: "value_ints"
ints: 0
ints: 42
type: INTS
}
}
node {
input: "val_2"
output: "val_3"
name: "node_Identity_3"
op_type: "Identity"
}
node {
output: "val_4"
name: "node_Constant_4"
op_type: "Constant"
attribute {
name: "value_float"
f: 42.0
type: FLOAT
}
}
node {
input: "val_4"
output: "val_5"
name: "node_Identity_5"
op_type: "Identity"
}
node {
output: "val_6"
name: "node_Constant_6"
op_type: "Constant"
attribute {
name: "value_floats"
floats: 0.0
floats: 3.1415927
type: FLOATS
}
}
node {
input: "val_6"
output: "val_7"
name: "node_Identity_7"
op_type: "Identity"
}
node {
output: "val_8"
name: "node_Constant_8"
op_type: "Constant"
attribute {
name: "value_string"
s: "hello"
type: STRING
}
}
node {
input: "val_8"
output: "val_9"
name: "node_Identity_9"
op_type: "Identity"
}
node {
output: "val_10"
name: "node_Constant_10"
op_type: "Constant"
attribute {
name: "value_strings"
strings: "hello"
strings: "world"
type: STRINGS
}
}
node {
input: "val_10"
output: "val_11"
name: "node_Identity_11"
op_type: "Identity"
}
name: "constant_node_tests"
}
opset_import {
domain: ""
version: 20
}
58 changes: 58 additions & 0 deletions tools/test_generation/constant_node_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import math
import pathlib

import onnx
from onnxscript import ir


def main():
graph = ir.Graph(
[],
[],
nodes=[
c1 := ir.Node(
"", "Constant", [], attributes=[ir.AttrInt64("value_int", 42)]
),
ir.Node("", "Identity", c1.outputs),
c2 := ir.Node(
"", "Constant", [], attributes=[ir.AttrInt64s("value_ints", [0, 42])]
),
ir.Node("", "Identity", c2.outputs),
c3 := ir.Node(
"", "Constant", [], attributes=[ir.AttrFloat32("value_float", 42.0)]
),
ir.Node("", "Identity", c3.outputs),
c4 := ir.Node(
"",
"Constant",
[],
attributes=[ir.AttrFloat32s("value_floats", [0.0, math.pi])],
),
ir.Node("", "Identity", c4.outputs),
c5 := ir.Node(
"", "Constant", [], attributes=[ir.AttrString("value_string", "hello")]
),
ir.Node("", "Identity", c5.outputs),
c6 := ir.Node(
"",
"Constant",
[],
attributes=[ir.AttrStrings("value_strings", ["hello", "world"])],
),
ir.Node("", "Identity", c6.outputs),
],
opset_imports={"": 20},
name="constant_node_tests",
)
model = ir.Model(graph, ir_version=8)
model_proto = ir.serde.serialize_model(model)
onnx.save(
model_proto,
pathlib.Path(__file__).parent.parent.parent
/ "testdata"
/ "constant_node_tests.textproto",
)


if __name__ == "__main__":
main()

0 comments on commit aac7b04

Please sign in to comment.