From 05112da833b68638bbc077f06ec5a5d61a0b666d Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 13 Mar 2024 19:47:23 +0000
Subject: [PATCH] Print fx graph size

---
 docs/examples/llama/llama_model.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/docs/examples/llama/llama_model.py b/docs/examples/llama/llama_model.py
index 124645675..f55e181d7 100644
--- a/docs/examples/llama/llama_model.py
+++ b/docs/examples/llama/llama_model.py
@@ -191,7 +191,8 @@ def __init__(self, config):
 
     def forward(self, input_ids, attention_mask):
         model_output = self.model(input_ids, attention_mask=attention_mask)
-        return model_output[0]
+        # Output 2, 3 are None
+        return model_output[0], model_output[1]
 
 def generate_example_inputs(batch: int, seq: int, vocab_size: int):
     input_ids = ids_tensor([batch, seq], vocab_size)
@@ -221,7 +222,9 @@ def display_model_stats(model: onnx.ModelProto):
 def export():
     model, example_args_collection = get_llama_model()
     exported = torch.export.export(model, example_args_collection[0])
+    print("===exported fx graph===")
     print(exported)
+    print("FX Node count:", len(exported.graph.nodes))
     exported_onnx = torch.onnx.dynamo_export(exported, *example_args_collection[0]).model_proto
     print("===exported_onnx===")
     display_model_stats(exported_onnx)