From dabfaa8ca51767ef5c2fd7d20fe85439e0db2880 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 27 Jul 2023 19:50:33 +0000 Subject: [PATCH] Update `flatten_past_key_values` --- optimum/exporters/onnx/model_configs.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index a4df2591f3..2c1d464948 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -306,6 +306,9 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire 1: decoder_sequence_name, } + def flatten_past_key_values(self, flattened_output, name, idx, t): + flattened_output[f"{name}.{idx}.key_value"] = t + class T5DummySeq2SeqPastKeyValuesGenerator(DummySeq2SeqPastKeyValuesGenerator): def generate(self, input_name: str, framework: str = "pt"): encoder_shape = (