Skip to content

Commit

Permalink
Add ONNX export support for SwinV2
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Aug 29, 2024
1 parent 8d4b09e commit b96bb61
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Splinter
- SqueezeBert
- Swin
- SwinV2
- T5
- Table Transformer
- TROCR
Expand Down
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,10 @@ class SwinOnnxConfig(ViTOnnxConfig):
DEFAULT_ONNX_OPSET = 11


class SwinV2OnnxConfig(SwinOnnxConfig):
pass


class Swin2srOnnxConfig(SwinOnnxConfig):
pass

Expand Down
6 changes: 6 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,12 @@ class TasksManager:
"masked-im",
onnx="SwinOnnxConfig",
),
"swinv2": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"masked-im",
onnx="SwinV2OnnxConfig",
),
"swin2sr": supported_tasks_mapping(
"feature-extraction",
"image-to-image",
Expand Down
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -1682,7 +1682,7 @@ def forward(
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForImageClassification(ORTModel):
"""
ONNX Model for image-classification tasks. This class officially supports beit, convnext, convnextv2, data2vec_vision, deit, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, vit.
ONNX Model for image-classification tasks. This class officially supports beit, convnext, convnextv2, data2vec_vision, deit, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, swinv2, vit.
"""

auto_model_class = AutoModelForImageClassification
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def check_optimization_supported_model(cls, model_type: str, optimization_config
"clip",
"vit",
"swin",
"swinv2",
]
model_type = model_type.replace("_", "-")
if (model_type not in cls._conf) or (cls._conf[model_type] not in supported_model_types_for_optimization):
Expand Down
2 changes: 2 additions & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
"splinter": "hf-internal-testing/tiny-random-SplinterModel",
"squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel",
"swin": "hf-internal-testing/tiny-random-SwinModel",
"swinv2": "hf-internal-testing/tiny-random-Swinv2Model",
"swin2sr": "hf-internal-testing/tiny-random-Swin2SRModel",
"t5": "hf-internal-testing/tiny-random-t5",
"table-transformer": "hf-internal-testing/tiny-random-TableTransformerModel",
Expand Down Expand Up @@ -268,6 +269,7 @@
"splinter": "hf-internal-testing/tiny-random-SplinterModel",
"squeezebert": "squeezebert/squeezebert-uncased",
"swin": "microsoft/swin-tiny-patch4-window7-224",
"swinv2": "microsoft/swinv2-tiny-patch4-window16-256",
"t5": "t5-small",
"table-transformer": "microsoft/table-transformer-detection",
"vit": "google/vit-base-patch16-224",
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
"stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
"stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
"swin": "hf-internal-testing/tiny-random-SwinModel",
"swinv2": "hf-internal-testing/tiny-random-Swinv2Model",
"swin-window": "yujiepan/tiny-random-swin-patch4-window7-224",
"t5": "hf-internal-testing/tiny-random-t5",
"table-transformer": "hf-internal-testing/tiny-random-TableTransformerModel",
Expand Down

0 comments on commit b96bb61

Please sign in to comment.