Skip to content

Commit

Permalink
Fix incorrect names for usage blenderbot for causallm (#1887)
Browse files Browse the repository at this point in the history
* Fix incorrect names for usage blenderbot for causallm

* fix input dynamic shapes as dummy input seq len  != 1

* apply code style
  • Loading branch information
eaidova authored Jul 1, 2024
1 parent d0a84a9 commit d82d4c6
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ def inputs_for_default_and_seq2seq_lm(self):
def inputs_for_causal_lm(self):
if self.use_past_in_inputs:
common_inputs = {
"input_ids": {0: "batch_size"},
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "past_sequence_length + 1"},
}
for i in range(self._normalized_config.decoder_num_layers):
Expand Down Expand Up @@ -645,7 +645,11 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = super(OnnxConfigWithPast, self).outputs
if self.use_past:
# When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
for i in range(self._normalized_config.encoder_num_layers):
for i in range(
self._normalized_config.encoder_num_layers
if self.task != "text-generation"
else self._normalized_config.decoder_num_layers
):
common_outputs[f"present.{i}.key"] = {0: "batch_size", 2: "past_sequence_length + sequence_length"}
common_outputs[f"present.{i}.value"] = {
0: "batch_size",
Expand Down

0 comments on commit d82d4c6

Please sign in to comment.