diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 5a6c4a4201..6b2b8bdb2a 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -576,10 +576,10 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): sequence_length = dummy_input_gen.sequence_length if "sequence_length" in kwargs and kwargs["sequence_length"] != 1: logger.info( - f"Asked a sequence length of {kwargs['sequence_length']}, but a sequence length of 1 " + f"Asked a sequence length of {kwargs['sequence_length']}, but a sequence length of 2 " f"will be used with use_past == True for `{input_name}`." ) - dummy_input_gen.sequence_length = 1 + dummy_input_gen.sequence_length = 2 dummy_inputs[input_name] = dummy_input_gen.generate(input_name, framework=framework) dummy_input_gen.sequence_length = sequence_length else: @@ -601,7 +601,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): past_length = dummy_inputs["past_key_values"][0][0].shape[2] dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim( dummy_inputs["attention_mask"], - desired_length=past_length + 1, + desired_length=past_length + 2, dim=1, dtype=dummy_inputs["attention_mask"].dtype, ) @@ -610,7 +610,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): past_length = dummy_inputs["past_key_values"][0][0].shape[2] dummy_inputs["decoder_attention_mask"] = DummyInputGenerator.pad_input_on_dim( dummy_inputs["decoder_attention_mask"], - desired_length=past_length + 1, + desired_length=past_length + 2, dim=1, dtype=dummy_inputs["decoder_attention_mask"].dtype, ) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 5c20f1ac52..2b9b49a039 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -234,6 +234,31 @@ class MPTOnnxConfig(TextDecoderOnnxConfig): DEFAULT_ONNX_OPSET = 14 NORMALIZED_CONFIG_CLASS = 
NormalizedTextConfig.with_args(hidden_size="d_model", num_attention_heads="n_heads", num_layers="n_layers") + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + """ + Fills `inputs_or_outputs` mapping with past_key_values dynamic axes considering the direction. + + Args: + inputs_or_outputs (`Dict[str, Dict[int, str]]`): + The mapping to fill. + direction (`str`): + either "inputs" or "outputs", it specifies whether `inputs_or_outputs` is the input mapping or the + output mapping, this is important for axes naming. + """ + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_sequence_length + 1" + name = "present" + + for i in range(self._normalized_config.num_layers): + inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 3: decoder_sequence_name} + inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name} + class OPTOnnxConfig(TextDecoderOnnxConfig): DEFAULT_ONNX_OPSET = 13