Skip to content

Commit

Permalink
Update flatten_past_key_values
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Jul 27, 2023
1 parent 674639a commit dabfaa8
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
1: decoder_sequence_name,
}

def flatten_past_key_values(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.key_value"] = t

class T5DummySeq2SeqPastKeyValuesGenerator(DummySeq2SeqPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt"):
encoder_shape = (
Expand Down

0 comments on commit dabfaa8

Please sign in to comment.