Skip to content

Commit

Permalink
add phi3 support
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed May 23, 2024
1 parent e0f5812 commit 1f66b6a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Pegasus
- Perceiver
- Phi
- Phi3
- Pix2Struct
- PoolFormer
- Qwen2(Qwen1.5)
Expand Down
9 changes: 9 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
NormalizedSeq2SeqConfig,
NormalizedTextAndVisionConfig,
NormalizedTextConfig,
NormalizedTextConfigWithGQA,
NormalizedVisionConfig,
is_diffusers_available,
logging,
Expand Down Expand Up @@ -287,6 +288,14 @@ class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class Phi3OnnxConfig(PhiOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
MistralDummyPastKeyValuesGenerator,
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA


class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
# This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35
MIN_TRANSFORMERS_VERSION = version.parse("4.34.99")
Expand Down
8 changes: 8 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,14 @@ class TasksManager:
"text-classification",
onnx="PhiOnnxConfig",
),
"phi3": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
onnx="Phi3OnnxConfig",
),
"pix2struct": supported_tasks_mapping(
"image-to-text",
"image-to-text-with-past",
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,6 @@
NormalizedSeq2SeqConfig,
NormalizedTextAndVisionConfig,
NormalizedTextConfig,
NormalizedTextConfigWithGQA,
NormalizedVisionConfig,
)

0 comments on commit 1f66b6a

Please sign in to comment.