
Commit

onnx config for MPT
jiqing-feng committed Jul 28, 2023
1 parent 52efe82 · commit e47d95f
Showing 2 changed files with 2 additions and 31 deletions.
optimum/exporters/onnx/base.py (2 additions & 4 deletions)
@@ -564,8 +564,6 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
             input_was_inserted = False
             for dummy_input_gen in dummy_inputs_generators:
                 if dummy_input_gen.supports_input(input_name):
-                    if input_name == "past_key_values":
-                        sequence_length_pkv = dummy_input_gen.sequence_length
                     # models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name
                     # while models from TextDecoderOnnxConfig use input_ids, hence the check for both
                     if (
@@ -598,7 +596,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
             and self.use_cache_branch is not False
             and "attention_mask" in dummy_inputs
         ):
-            past_length = sequence_length_pkv
+            past_length = dummy_inputs["past_key_values"][0][0].shape[2]
             dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
                 dummy_inputs["attention_mask"],
                 desired_length=past_length + 1,
@@ -607,7 +605,7 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
             )

         if self.use_past_in_inputs and self.use_cache_branch is not False and "decoder_attention_mask" in dummy_inputs:
-            past_length = sequence_length_pkv
+            past_length = dummy_inputs["past_key_values"][0][0].shape[2]
             dummy_inputs["decoder_attention_mask"] = DummyInputGenerator.pad_input_on_dim(
                 dummy_inputs["decoder_attention_mask"],
                 desired_length=past_length + 1,
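With the sequence_length_pkv bookkeeping removed, generate_dummy_inputs reads the past length straight off the generated cache tensors. The snippet below is a minimal sketch of the assumption behind shape[2], using illustrative PyTorch tensors with made-up sizes and the conventional decoder cache layout (batch_size, num_attention_heads, past_sequence_length, head_dim); it is not code from this repository.

import torch

batch_size, num_heads, past_seq_len, head_dim = 2, 8, 16, 64

# Conventional decoder-only cache layout: key and value share
# (batch_size, num_attention_heads, past_sequence_length, head_dim).
past_key = torch.rand(batch_size, num_heads, past_seq_len, head_dim)
past_value = torch.rand(batch_size, num_heads, past_seq_len, head_dim)
past_key_values = [(past_key, past_value)]  # one (key, value) pair per layer

# Mirrors the added line in the diff: dim 2 of the first layer's key
# is the past length later used to pad the attention mask.
past_length = past_key_values[0][0].shape[2]
assert past_length == past_seq_len  # 16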
optimum/exporters/onnx/model_configs.py (0 additions & 27 deletions)
@@ -215,38 +215,11 @@ class LlamaOnnxConfig(TextDecoderOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


-class MPTDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
-    def generate(self, input_name: str, framework: str = "pt"):
-        past_key_shape = (
-            self.batch_size,
-            self.num_attention_heads,
-            self.hidden_size // self.num_attention_heads,
-            self.sequence_length,
-        )
-        past_value_shape = (
-            self.batch_size,
-            self.num_attention_heads,
-            self.sequence_length,
-            self.hidden_size // self.num_attention_heads,
-        )
-        return [
-            (
-                self.random_float_tensor(past_key_shape, framework=framework),
-                self.random_float_tensor(past_value_shape, framework=framework),
-            )
-            for _ in range(self.num_layers)
-        ]
-
-
 class MPTOnnxConfig(TextDecoderOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
         num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers"
     )
-    DUMMY_INPUT_GENERATOR_CLASSES = (
-        MPTDummyPastKeyValuesGenerator,
-    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
-    DUMMY_PKV_GENERATOR_CLASS = MPTDummyPastKeyValuesGenerator


 class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
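Dropping MPTDummyPastKeyValuesGenerator means MPTOnnxConfig now relies on the default past-key-values generator it inherits through TextDecoderOnnxConfig. The removed class built asymmetric key/value shapes; the comparison below is a hedged sketch with made-up sizes (batch_size=2, n_heads=8, d_model=512, sequence_length=16), and the symmetric layout shown for the default path is an assumption rather than code from this repository.

batch_size, n_heads, d_model, seq_len = 2, 8, 512, 16
head_dim = d_model // n_heads  # 64

# Shapes produced by the removed MPTDummyPastKeyValuesGenerator:
mpt_past_key_shape = (batch_size, n_heads, head_dim, seq_len)    # (2, 8, 64, 16)
mpt_past_value_shape = (batch_size, n_heads, seq_len, head_dim)  # (2, 8, 16, 64)

# Assumed symmetric layout on the default generator path, where dim 2
# is the past sequence length for both key and value tensors:
default_past_kv_shape = (batch_size, n_heads, seq_len, head_dim)  # (2, 8, 16, 64)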
