Skip to content

Commit

Permalink
Updated configs for MPT models.
Browse files Browse the repository at this point in the history
  • Loading branch information
andreyanufr committed Jun 28, 2023
1 parent 94690e5 commit c4a1866
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 0 deletions.
1 change: 1 addition & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def generate(self, input_name: str, framework: str = "pt"):
for _ in range(self.num_layers)
]


class MPTOnnxConfig(TextDecoderOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MPTPastKeyValuesGenerator)
DEFAULT_ONNX_OPSET = 14
Expand Down
7 changes: 7 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,13 @@ class TasksManager:
onnx="MPNetOnnxConfig",
tflite="MPNetTFLiteConfig",
),
"mpt": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
onnx="MPTOnnxConfig",
),
"mt5": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ class NormalizedConfigManager:
"marian": BartLikeNormalizedTextConfig,
"mbart": BartLikeNormalizedTextConfig,
"mt5": T5LikeNormalizedTextConfig,
"mpt": NormalizedTextConfig.with_args(hidden_size="d_model", num_attention_heads="n_heads", num_layers="n_layers"),
"m2m_100": BartLikeNormalizedTextConfig,
"nystromformer": NormalizedTextConfig,
"opt": NormalizedTextConfig,
Expand Down

0 comments on commit c4a1866

Please sign in to comment.