diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 4aaca160ca..5c20f1ac52 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -228,6 +228,7 @@ def generate(self, input_name: str, framework: str = "pt"): for _ in range(self.num_layers) ] + class MPTOnnxConfig(TextDecoderOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MPTPastKeyValuesGenerator) DEFAULT_ONNX_OPSET = 14 diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index e841a6e047..2a04c1552f 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -654,6 +654,13 @@ class TasksManager: onnx="MPNetOnnxConfig", tflite="MPNetTFLiteConfig", ), + "mpt": supported_tasks_mapping( + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + onnx="MPTOnnxConfig", + ), "mt5": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 3016dfa736..b9838988fe 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -222,6 +222,7 @@ class NormalizedConfigManager: "marian": BartLikeNormalizedTextConfig, "mbart": BartLikeNormalizedTextConfig, "mt5": T5LikeNormalizedTextConfig, + "mpt": NormalizedTextConfig.with_args(hidden_size="d_model", num_attention_heads="n_heads", num_layers="n_layers"), "m2m_100": BartLikeNormalizedTextConfig, "nystromformer": NormalizedTextConfig, "opt": NormalizedTextConfig,