diff --git a/.gitignore b/.gitignore index cff7bce..89e38d1 100644 --- a/.gitignore +++ b/.gitignore @@ -160,5 +160,4 @@ cython_debug/ #.idea/ *.onnx -*.textproto *.pb diff --git a/src/model_explorer_onnx/main.py b/src/model_explorer_onnx/main.py index bf232a3..d191b30 100644 --- a/src/model_explorer_onnx/main.py +++ b/src/model_explorer_onnx/main.py @@ -54,9 +54,14 @@ def can_display_tensor_json( return True -def display_tensor_json(tensor: ir.TensorProtocol | None, settings: Settings) -> str: +def display_tensor_json( + tensor: ir.TensorProtocol | np.ndarray, settings: Settings +) -> str: try: - array = _tensor_to_numpy(tensor) + if isinstance(tensor, np.ndarray): + array = tensor + else: + array = _tensor_to_numpy(tensor) size_limit = settings.const_element_count_limit if size_limit < 0 or size_limit >= array.size: # Use separators=(',', ':') to remove spaces @@ -215,12 +220,27 @@ def add_node_attrs( if isinstance(attr, ir.Attr): if attr.type == ir.AttributeType.TENSOR: if can_display_tensor_json(attr.value, settings=settings): + assert attr.value is not None set_attr( node, "__value", display_tensor_json(attr.value, settings=settings), ) attr_value = display_tensor_repr(attr.value) + elif onnx_node.op_type == "Constant" and attr.name in { + "value_float", + "value_int", + "value_string", + "value_floats", + "value_ints", + "value_strings", + }: + set_attr( + node, + "__value", + display_tensor_json(np.array(attr.value), settings=settings), + ) + attr_value = str(attr.value) elif onnx_node.op_type == "Cast" and attr.name == "to": attr_value = str(ir.DataType(attr.value)) else: @@ -417,6 +437,7 @@ def add_initializers( node = all_nodes[initializer_node_name] metadata = node.outputsMetadata[0] if can_display_tensor_json(initializer.const_value, settings=settings): + assert initializer.const_value is not None set_attr( node, "__value", @@ -440,6 +461,7 @@ def add_initializers( set_attr(metadata, "__tensor_tag", initializer.name or "") set_type_shape_metadata(metadata, initializer.const_value) if can_display_tensor_json(initializer.const_value, settings=settings): + assert initializer.const_value is not None set_attr( node, "__value", diff --git a/testdata/constant_node_tests.textproto b/testdata/constant_node_tests.textproto new file mode 100644 index 0000000..a048b86 --- /dev/null +++ b/testdata/constant_node_tests.textproto @@ -0,0 +1,107 @@ +ir_version: 8 +graph { + node { + output: "val_0" + name: "node_Constant_0" + op_type: "Constant" + attribute { + name: "value_int" + i: 42 + type: INT + } + } + node { + input: "val_0" + output: "val_1" + name: "node_Identity_1" + op_type: "Identity" + } + node { + output: "val_2" + name: "node_Constant_2" + op_type: "Constant" + attribute { + name: "value_ints" + ints: 0 + ints: 42 + type: INTS + } + } + node { + input: "val_2" + output: "val_3" + name: "node_Identity_3" + op_type: "Identity" + } + node { + output: "val_4" + name: "node_Constant_4" + op_type: "Constant" + attribute { + name: "value_float" + f: 42.0 + type: FLOAT + } + } + node { + input: "val_4" + output: "val_5" + name: "node_Identity_5" + op_type: "Identity" + } + node { + output: "val_6" + name: "node_Constant_6" + op_type: "Constant" + attribute { + name: "value_floats" + floats: 0.0 + floats: 3.1415927 + type: FLOATS + } + } + node { + input: "val_6" + output: "val_7" + name: "node_Identity_7" + op_type: "Identity" + } + node { + output: "val_8" + name: "node_Constant_8" + op_type: "Constant" + attribute { + name: "value_string" + s: "hello" + type: STRING + } + } + node { + input: "val_8" + output: "val_9" + name: "node_Identity_9" + op_type: "Identity" + } + node { + output: "val_10" + name: "node_Constant_10" + op_type: "Constant" + attribute { + name: "value_strings" + strings: "hello" + strings: "world" + type: STRINGS + } + } + node { + input: "val_10" + output: "val_11" + name: "node_Identity_11" + op_type: "Identity" + } + name: "constant_node_tests" +} +opset_import { + domain: "" + version: 20 +} diff --git a/tools/test_generation/constant_node_tests.py b/tools/test_generation/constant_node_tests.py new file mode 100644 index 0000000..3ec7535 --- /dev/null +++ b/tools/test_generation/constant_node_tests.py @@ -0,0 +1,58 @@ +import math +import pathlib + +import onnx +from onnxscript import ir + + +def main(): + graph = ir.Graph( + [], + [], + nodes=[ + c1 := ir.Node( + "", "Constant", [], attributes=[ir.AttrInt64("value_int", 42)] + ), + ir.Node("", "Identity", c1.outputs), + c2 := ir.Node( + "", "Constant", [], attributes=[ir.AttrInt64s("value_ints", [0, 42])] + ), + ir.Node("", "Identity", c2.outputs), + c3 := ir.Node( + "", "Constant", [], attributes=[ir.AttrFloat32("value_float", 42.0)] + ), + ir.Node("", "Identity", c3.outputs), + c4 := ir.Node( + "", + "Constant", + [], + attributes=[ir.AttrFloat32s("value_floats", [0.0, math.pi])], + ), + ir.Node("", "Identity", c4.outputs), + c5 := ir.Node( + "", "Constant", [], attributes=[ir.AttrString("value_string", "hello")] + ), + ir.Node("", "Identity", c5.outputs), + c6 := ir.Node( + "", + "Constant", + [], + attributes=[ir.AttrStrings("value_strings", ["hello", "world"])], + ), + ir.Node("", "Identity", c6.outputs), + ], + opset_imports={"": 20}, + name="constant_node_tests", + ) + model = ir.Model(graph, ir_version=8) + model_proto = ir.serde.serialize_model(model) + onnx.save( + model_proto, + pathlib.Path(__file__).parent.parent.parent + / "testdata" + / "constant_node_tests.textproto", + ) + + +if __name__ == "__main__": + main()