From b96bb6184523d63d62bd5a776264edf452841753 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 29 Aug 2024 13:56:07 +0000 Subject: [PATCH] Add ONNX export support for SwinV2 --- docs/source/exporters/onnx/overview.mdx | 1 + optimum/exporters/onnx/model_configs.py | 4 ++++ optimum/exporters/tasks.py | 6 ++++++ optimum/onnxruntime/modeling_ort.py | 2 +- optimum/onnxruntime/utils.py | 1 + tests/exporters/exporters_utils.py | 2 ++ tests/onnxruntime/utils_onnxruntime_tests.py | 1 + 7 files changed, 16 insertions(+), 1 deletion(-) diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 11d0bc4a92..908d08b6f3 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -95,6 +95,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra - Splinter - SqueezeBert - Swin +- SwinV2 - T5 - Table Transformer - TROCR diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 81de33116f..7670f95b8e 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -817,6 +817,10 @@ class SwinOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 11 +class SwinV2OnnxConfig(SwinOnnxConfig): + pass + + class Swin2srOnnxConfig(SwinOnnxConfig): pass diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 01270f0b40..b771eb731f 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1054,6 +1054,12 @@ class TasksManager: "masked-im", onnx="SwinOnnxConfig", ), + "swinv2": supported_tasks_mapping( + "feature-extraction", + "image-classification", + "masked-im", + onnx="SwinV2OnnxConfig", + ), "swin2sr": supported_tasks_mapping( "feature-extraction", "image-to-image", diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 254b771e33..7e53005ed2 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -1682,7 +1682,7 @@ def forward( @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) class ORTModelForImageClassification(ORTModel): """ - ONNX Model for image-classification tasks. This class officially supports beit, convnext, convnextv2, data2vec_vision, deit, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, vit. + ONNX Model for image-classification tasks. This class officially supports beit, convnext, convnextv2, data2vec_vision, deit, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, swinv2, vit. """ auto_model_class = AutoModelForImageClassification diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index ad40af92b9..e4c16ae83a 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -175,6 +175,7 @@ def check_optimization_supported_model(cls, model_type: str, optimization_config "clip", "vit", "swin", + "swinv2", ] model_type = model_type.replace("_", "-") if (model_type not in cls._conf) or (cls._conf[model_type] not in supported_model_types_for_optimization): diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 2af51fc183..eec4bb8dd2 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -148,6 +148,7 @@ "splinter": "hf-internal-testing/tiny-random-SplinterModel", "squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel", "swin": "hf-internal-testing/tiny-random-SwinModel", + "swinv2": "hf-internal-testing/tiny-random-Swinv2Model", "swin2sr": "hf-internal-testing/tiny-random-Swin2SRModel", "t5": "hf-internal-testing/tiny-random-t5", "table-transformer": "hf-internal-testing/tiny-random-TableTransformerModel", @@ -268,6 +269,7 @@ "splinter": "hf-internal-testing/tiny-random-SplinterModel", "squeezebert": "squeezebert/squeezebert-uncased", "swin": "microsoft/swin-tiny-patch4-window7-224", + "swinv2": "microsoft/swinv2-tiny-patch4-window16-256", "t5": "t5-small", "table-transformer": "microsoft/table-transformer-detection", "vit": "google/vit-base-patch16-224", diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 3dc6be1909..3b52194a12 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -144,6 +144,7 @@ "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", "swin": "hf-internal-testing/tiny-random-SwinModel", + "swinv2": "hf-internal-testing/tiny-random-Swinv2Model", "swin-window": "yujiepan/tiny-random-swin-patch4-window7-224", "t5": "hf-internal-testing/tiny-random-t5", "table-transformer": "hf-internal-testing/tiny-random-TableTransformerModel",