diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 22471c297a..747e1396fb 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -77,6 +77,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra - Pegasus - Perceiver - Phi +- Phi3 - Pix2Struct - PoolFormer - Qwen2(Qwen1.5) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index d4c4ac934b..97f3f7b999 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -51,6 +51,7 @@ NormalizedSeq2SeqConfig, NormalizedTextAndVisionConfig, NormalizedTextConfig, + NormalizedTextConfigWithGQA, NormalizedVisionConfig, is_diffusers_available, logging, @@ -287,6 +288,14 @@ class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig +class Phi3OnnxConfig(PhiOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = ( + MistralDummyPastKeyValuesGenerator, + ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA + + class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35 MIN_TRANSFORMERS_VERSION = version.parse("4.34.99") diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index efa782353b..608b3df0d7 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -893,6 +893,14 @@ class TasksManager: "text-classification", onnx="PhiOnnxConfig", ), + "phi3": supported_tasks_mapping( + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text-classification", + onnx="Phi3OnnxConfig", + ), "pix2struct": supported_tasks_mapping( "image-to-text", "image-to-text-with-past", diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index 07be3f7e1a..5d5044e63e 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -80,5 +80,6 @@ NormalizedSeq2SeqConfig, NormalizedTextAndVisionConfig, NormalizedTextConfig, + NormalizedTextConfigWithGQA, NormalizedVisionConfig, )