Added correct past key value preprocessing for MPT models.
andreyanufr committed Jul 4, 2023
1 parent c4a1866 commit 5cb208f
Showing 2 changed files with 29 additions and 4 deletions.
8 changes: 4 additions & 4 deletions optimum/exporters/onnx/base.py
@@ -576,10 +576,10 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
             sequence_length = dummy_input_gen.sequence_length
             if "sequence_length" in kwargs and kwargs["sequence_length"] != 1:
                 logger.info(
-                    f"Asked a sequence length of {kwargs['sequence_length']}, but a sequence length of 1 "
+                    f"Asked a sequence length of {kwargs['sequence_length']}, but a sequence length of 2 "
                     f"will be used with use_past == True for `{input_name}`."
                 )
-            dummy_input_gen.sequence_length = 1
+            dummy_input_gen.sequence_length = 2
             dummy_inputs[input_name] = dummy_input_gen.generate(input_name, framework=framework)
             dummy_input_gen.sequence_length = sequence_length
         else:
@@ -601,7 +601,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
             past_length = dummy_inputs["past_key_values"][0][0].shape[2]
             dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
                 dummy_inputs["attention_mask"],
-                desired_length=past_length + 1,
+                desired_length=past_length + 2,
                 dim=1,
                 dtype=dummy_inputs["attention_mask"].dtype,
             )
@@ -610,7 +610,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
             past_length = dummy_inputs["past_key_values"][0][0].shape[2]
             dummy_inputs["decoder_attention_mask"] = DummyInputGenerator.pad_input_on_dim(
                 dummy_inputs["decoder_attention_mask"],
-                desired_length=past_length + 1,
+                desired_length=past_length + 2,
                 dim=1,
                 dtype=dummy_inputs["decoder_attention_mask"].dtype,
            )
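The switch from `+ 1` to `+ 2` in the mask padding mirrors the new dummy sequence length above: with `use_past == True` the dummy decoder input now carries two tokens rather than one, so the padded attention mask must span `past_length + 2` positions (the cached tokens plus the new ones). A minimal standalone sketch of that shape arithmetic, with illustrative values and tensor names that are not part of this commit:

    import torch

    past_length = 16  # length of the cached past_key_values (illustrative)
    new_tokens = 2    # dummy input sequence length after this commit
    mask_new = torch.ones(1, new_tokens, dtype=torch.int64)    # mask over the new dummy tokens
    mask_past = torch.ones(1, past_length, dtype=torch.int64)  # mask over the cached tokens
    # extend the mask along dim=1, the effect pad_input_on_dim achieves here
    # (which side the padding lands on is illustrative)
    attention_mask = torch.cat([mask_past, mask_new], dim=1)
    assert attention_mask.shape == (1, past_length + new_tokens)  # i.e. past_length + 2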
25 changes: 25 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -234,6 +234,31 @@ class MPTOnnxConfig(TextDecoderOnnxConfig):
     DEFAULT_ONNX_OPSET = 14
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(hidden_size="d_model", num_attention_heads="n_heads", num_layers="n_layers")
 
+    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
+        """
+        Fills the `inputs_or_outputs` mapping with past_key_values dynamic axes, considering the direction.
+        Args:
+            inputs_or_outputs (`Dict[str, Dict[int, str]]`):
+                The mapping to fill.
+            direction (`str`):
+                Either "inputs" or "outputs". It specifies whether `inputs_or_outputs` is the input mapping or the
+                output mapping, which matters for axes naming.
+        """
+        if direction not in ["inputs", "outputs"]:
+            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
+
+        if direction == "inputs":
+            decoder_sequence_name = "past_sequence_length"
+            name = "past_key_values"
+        else:
+            decoder_sequence_name = "past_sequence_length + 1"
+            name = "present"
+
+        for i in range(self._normalized_config.num_layers):
+            inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 3: decoder_sequence_name}
+            inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name}
+
+
 class OPTOnnxConfig(TextDecoderOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
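The detail worth noting in this override is the axis numbering: the key entries declare their dynamic past-sequence axis at position 3, while the value entries use the usual position 2. This is consistent with MPT caching its keys in transposed form, roughly (batch, n_heads, head_dim, past_seq) for keys versus (batch, n_heads, past_seq, head_dim) for values, though that layout reading is an inference from the diff, not stated in the commit. A minimal standalone sketch of the mapping the method builds for a hypothetical 2-layer model, with illustrative names:

    num_layers = 2  # stands in for self._normalized_config.num_layers
    inputs = {}
    for i in range(num_layers):
        # key cache: dynamic past-sequence axis sits at index 3 (transposed layout)
        inputs[f"past_key_values.{i}.key"] = {0: "batch_size", 3: "past_sequence_length"}
        # value cache: dynamic past-sequence axis sits at index 2
        inputs[f"past_key_values.{i}.value"] = {0: "batch_size", 2: "past_sequence_length"}

    # inputs now maps e.g. "past_key_values.0.key" -> {0: "batch_size", 3: "past_sequence_length"}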
