Add ONNX export support for DinoV2, Hiera, Maskformer, PVT, SigLIP, SwinV2, VitMAE, and VitMSN models #2001

Open · wants to merge 17 commits into main

Changes from all commits
8 changes: 8 additions & 0 deletions docs/source/exporters/onnx/overview.mdx
@@ -38,6 +38,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Deberta-v2
- Deit
- Detr
- DINOv2
- DistilBert
- Donut-Swin
- Electra
@@ -52,6 +53,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- GPT-NeoX
- OPT
- GroupVit
- Hiera
- Hubert
- IBert
- LayoutLM
@@ -63,6 +65,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- M2M100
- Marian
- MarkupLM
- Maskformer
- MBart
- Mistral
- MobileBert
@@ -80,6 +83,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Phi3
- Pix2Struct
- PoolFormer
- PVT
- Qwen2(Qwen1.5)
- RegNet
- ResNet
@@ -90,17 +94,21 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- SEW
- SEW-D
- SigLIP
- Speech2Text
- SpeechT5
- Splinter
- SqueezeBert
- Swin
- SwinV2
- T5
- Table Transformer
- TROCR
- UniSpeech
- UniSpeech SAT
- Vision Encoder Decoder
- Vit
- VitMAE
- VitMSN
- Wav2Vec2
- Wav2Vec2 Conformer
- WavLM
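For context, every architecture in this list is exportable through the same programmatic entry point that backs `optimum-cli export onnx`; a minimal sketch, assuming the facebook/dinov2-base checkpoint (chosen purely for illustration):

```python
# Minimal export sketch -- checkpoint name and output directory are illustrative.
from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="facebook/dinov2-base",
    output="dinov2_onnx",  # directory that will receive model.onnx and the config files
    task="feature-extraction",
)
```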
106 changes: 106 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -772,6 +772,65 @@ class ConvNextV2OnnxConfig(ViTOnnxConfig):
    DEFAULT_ONNX_OPSET = 11


class HieraOnnxConfig(ViTOnnxConfig):
    DEFAULT_ONNX_OPSET = 11


class PvtOnnxConfig(ViTOnnxConfig):
    DEFAULT_ONNX_OPSET = 11


class VitMAEOnnxConfig(ViTOnnxConfig):
    # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 11 is not supported.
    # Support for this operator was added in version 14, try exporting with this version.
    DEFAULT_ONNX_OPSET = 14


class VitMSNOnnxConfig(ViTOnnxConfig):
    # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 11 is not supported.
    # Support for this operator was added in version 14, try exporting with this version.
    DEFAULT_ONNX_OPSET = 14


class Dinov2DummyInputGenerator(DummyVisionInputGenerator):
    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
        width: int = DEFAULT_DUMMY_SHAPES["width"],
        height: int = DEFAULT_DUMMY_SHAPES["height"],
        **kwargs,
    ):
        super().__init__(
            task=task,
            normalized_config=normalized_config,
            batch_size=batch_size,
            num_channels=num_channels,
            width=width,
            height=height,
            **kwargs,
        )

        from transformers.onnx.utils import get_preprocessor

        # DINOv2 checkpoints expect images at the preprocessor's crop_size, so prefer
        # those dimensions over the generic dummy-shape defaults when they are available.
        preprocessor = get_preprocessor(normalized_config._name_or_path)
        if preprocessor is not None and hasattr(preprocessor, "crop_size"):
            self.height = preprocessor.crop_size.get("height", self.height)
            self.width = preprocessor.crop_size.get("width", self.width)


class Dinov2OnnxConfig(ViTOnnxConfig):
    DUMMY_INPUT_GENERATOR_CLASSES = (Dinov2DummyInputGenerator,)


class MobileViTOnnxConfig(ViTOnnxConfig):
    ATOL_FOR_VALIDATION = 1e-4
    DEFAULT_ONNX_OPSET = 11
@@ -813,6 +872,10 @@ class SwinOnnxConfig(ViTOnnxConfig):
    DEFAULT_ONNX_OPSET = 11


class SwinV2OnnxConfig(SwinOnnxConfig):
    pass


class Swin2srOnnxConfig(SwinOnnxConfig):
    pass

@@ -848,6 +911,22 @@ class MobileNetV2OnnxConfig(MobileNetV1OnnxConfig):
    pass


class MaskformerOnnxConfig(ViTOnnxConfig):
    # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::einsum' to ONNX opset version 11 is not supported.
    # Support for this operator was added in version 12, try exporting with this version.
    DEFAULT_ONNX_OPSET = 12

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        if self.task == "image-segmentation":
            return {
                "class_queries_logits": {0: "batch_size", 1: "num_queries"},
                "masks_queries_logits": {0: "batch_size", 1: "num_queries", 2: "height", 3: "width"},
            }
        else:
            return super().outputs


class DonutSwinOnnxConfig(ViTOnnxConfig):
    DEFAULT_ONNX_OPSET = 11

@@ -1034,6 +1113,33 @@ def patch_model_for_export(
        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class SiglipNormalizedConfig(CLIPNormalizedConfig):
    pass


class SiglipOnnxConfig(CLIPOnnxConfig):
    NORMALIZED_CONFIG_CLASS = SiglipNormalizedConfig
    # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 13 is not supported.
    # Support for this operator was added in version 14, try exporting with this version.
    DEFAULT_ONNX_OPSET = 14

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "input_ids": {0: "text_batch_size", 1: "sequence_length"},
            "pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"},
            # NOTE: No attention_mask -- SigLIP's text tower is trained with fixed-length padding and does not take one.
        }


class SiglipTextWithProjectionOnnxConfig(CLIPTextWithProjectionOnnxConfig):
    pass


class SiglipTextOnnxConfig(CLIPTextOnnxConfig):
    pass


class UNetOnnxConfig(VisionOnnxConfig):
    ATOL_FOR_VALIDATION = 1e-3
    # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
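A note on the pattern above: each new config reuses ViTOnnxConfig's `pixel_values` input and only raises DEFAULT_ONNX_OPSET to the minimum opset its traced operators require (einsum needs 12, scaled_dot_product_attention needs 14), overriding `outputs` where the model's signature differs. A sketch of inspecting one of the new configs in isolation (the checkpoint is illustrative):

```python
# Sketch: instantiate the new MaskFormer ONNX config and inspect its export contract.
from transformers import AutoConfig
from optimum.exporters.onnx.model_configs import MaskformerOnnxConfig

config = AutoConfig.from_pretrained("facebook/maskformer-swin-base-ade")  # illustrative
onnx_config = MaskformerOnnxConfig(config, task="image-segmentation")

print(onnx_config.DEFAULT_ONNX_OPSET)  # 12, the minimum opset supporting aten::einsum
print(onnx_config.inputs)              # {'pixel_values': {0: 'batch_size', ...}}
print(onnx_config.outputs)             # class_queries_logits / masks_queries_logits
```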
58 changes: 56 additions & 2 deletions optimum/exporters/tasks.py
@@ -209,7 +209,7 @@ class TasksManager:
        "feature-extraction": "AutoModel",
        "fill-mask": "AutoModelForMaskedLM",
        "image-classification": "AutoModelForImageClassification",
        "image-segmentation": (
            "AutoModelForImageSegmentation",
            "AutoModelForSemanticSegmentation",
            "AutoModelForInstanceSegmentation",
            "AutoModelForUniversalSegmentation",
        ),
        "image-to-image": "AutoModelForImageToImage",
        "image-to-text": "AutoModelForVision2Seq",
        "mask-generation": "AutoModel",
@@ -223,6 +223,7 @@ class TasksManager:
        "text2text-generation": "AutoModelForSeq2SeqLM",
        "text-classification": "AutoModelForSequenceClassification",
        "token-classification": "AutoModelForTokenClassification",
        "visual-question-answering": "AutoModelForVisualQuestionAnswering",
        "zero-shot-image-classification": "AutoModelForZeroShotImageClassification",
        "zero-shot-object-detection": "AutoModelForZeroShotObjectDetection",
    }
@@ -574,6 +575,11 @@ class TasksManager:
            "image-segmentation",
            onnx="DetrOnnxConfig",
        ),
        "dinov2": supported_tasks_mapping(
            "feature-extraction",
            "image-classification",
            onnx="Dinov2OnnxConfig",
        ),
        "distilbert": supported_tasks_mapping(
            "feature-extraction",
            "fill-mask",
@@ -705,6 +711,11 @@ class TasksManager:
            "feature-extraction",
            onnx="GroupViTOnnxConfig",
        ),
        "hiera": supported_tasks_mapping(
            "feature-extraction",
            "image-classification",
            onnx="HieraOnnxConfig",
        ),
        "hubert": supported_tasks_mapping(
            "feature-extraction",
            "automatic-speech-recognition",
@@ -786,6 +797,11 @@ class TasksManager:
            "question-answering",
            onnx="MarkupLMOnnxConfig",
        ),
        "maskformer": supported_tasks_mapping(
            "feature-extraction",
            "image-segmentation",
            onnx="MaskformerOnnxConfig",
        ),
        "mbart": supported_tasks_mapping(
            "feature-extraction",
            "feature-extraction-with-past",
@@ -958,6 +974,11 @@ class TasksManager:
            "image-classification",
            onnx="PoolFormerOnnxConfig",
        ),
        "pvt": supported_tasks_mapping(
            "feature-extraction",
            "image-classification",
            onnx="PvtOnnxConfig",
        ),
        "regnet": supported_tasks_mapping(
            "feature-extraction",
            "image-classification",
@@ -1017,6 +1038,19 @@ class TasksManager:
            "audio-classification",
            onnx="SEWDOnnxConfig",
        ),
        "siglip": supported_tasks_mapping(
            "feature-extraction",
            "zero-shot-image-classification",
            onnx="SiglipOnnxConfig",
        ),
        "siglip-text-model": supported_tasks_mapping(
            "feature-extraction",
            onnx="SiglipTextOnnxConfig",
        ),
        "siglip-text-with-projection": supported_tasks_mapping(
            "feature-extraction",
            onnx="SiglipTextWithProjectionOnnxConfig",
        ),
        "speech-to-text": supported_tasks_mapping(
            "feature-extraction",
            "feature-extraction-with-past",
@@ -1049,6 +1083,12 @@ class TasksManager:
            "masked-im",
            onnx="SwinOnnxConfig",
        ),
        "swinv2": supported_tasks_mapping(
            "feature-extraction",
            "image-classification",
            "masked-im",
            onnx="SwinV2OnnxConfig",
        ),
        "swin2sr": supported_tasks_mapping(
            "feature-extraction",
            "image-to-image",
@@ -1095,7 +1135,21 @@ class TasksManager:
            onnx="VisionEncoderDecoderOnnxConfig",
        ),
        "vit": supported_tasks_mapping(
            "feature-extraction",
            "image-classification",
            "masked-im",
            onnx="ViTOnnxConfig",
        ),
        "vit-mae": supported_tasks_mapping(
            "feature-extraction",
            "masked-im",
            onnx="VitMAEOnnxConfig",
        ),
        "vit-msn": supported_tasks_mapping(
            "feature-extraction",
            "image-classification",
            "masked-im",
            onnx="VitMSNOnnxConfig",
        ),
        "vits": supported_tasks_mapping(
            "text-to-audio",
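With these registrations, TasksManager should resolve every new model type to its config; a quick check, sketched on the assumption that the public lookup API is unchanged:

```python
# Sketch: confirm each newly registered model type resolves to its ONNX config.
from optimum.exporters.tasks import TasksManager

for model_type in ("dinov2", "hiera", "maskformer", "pvt", "siglip", "swinv2", "vit-mae", "vit-msn"):
    tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx")
    print(f"{model_type}: {sorted(tasks)}")
```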
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_ort.py
@@ -1682,7 +1682,7 @@ def forward(
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForImageClassification(ORTModel):
    """
    ONNX Model for image-classification tasks. This class officially supports beit, convnext, convnextv2, data2vec_vision, deit, dinov2, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, swinv2, vit.
    """

    auto_model_class = AutoModelForImageClassification
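The docstring update reflects what now works end to end; a minimal usage sketch with one of the newly supported backbones (checkpoint and image path are illustrative):

```python
# Sketch: export a DINOv2 classifier on the fly and run it through ONNX Runtime.
from optimum.onnxruntime import ORTModelForImageClassification
from transformers import AutoImageProcessor
from PIL import Image

model_id = "facebook/dinov2-base-imagenet1k-1-layer"  # illustrative checkpoint
model = ORTModelForImageClassification.from_pretrained(model_id, export=True)
processor = AutoImageProcessor.from_pretrained(model_id)

inputs = processor(images=Image.open("cat.png"), return_tensors="pt")
logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])
```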
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
@@ -175,6 +175,7 @@ def check_optimization_supported_model(cls, model_type: str, optimization_config
            "clip",
            "vit",
            "swin",
            "swinv2",
        ]
        model_type = model_type.replace("_", "-")
        if (model_type not in cls._conf) or (cls._conf[model_type] not in supported_model_types_for_optimization):
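Adding "swinv2" to this allow-list is what lets ORTOptimizer accept SwinV2 graphs; a hedged sketch of the resulting flow (checkpoint and save directory are illustrative):

```python
# Sketch: graph-optimize an exported SwinV2 classifier with ONNX Runtime.
from optimum.onnxruntime import ORTModelForImageClassification, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

model = ORTModelForImageClassification.from_pretrained(
    "microsoft/swinv2-tiny-patch4-window8-256", export=True  # illustrative checkpoint
)
optimizer = ORTOptimizer.from_pretrained(model)
optimizer.optimize(
    save_dir="swinv2_optimized",
    optimization_config=OptimizationConfig(optimization_level=2),
)
```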
4 changes: 4 additions & 0 deletions optimum/utils/normalized_config.py
@@ -204,8 +204,10 @@ class NormalizedConfigManager:
        'data2vec-text',
        'data2vec-vision',
        'detr',
        'dinov2',
        'flaubert',
        'groupvit',
        'hiera',
        'ibert',
        'layoutlm',
        'layoutlmv3',
@@ -216,6 +218,8 @@ class NormalizedConfigManager:
        'owlvit',
        'perceiver',
        'roformer',
        'segformer',
        'siglip',
        'squeezebert',
        'table-transformer',
    """
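These entries make the normalized-config lookup work for the new vision types; a sketch (the returned class is an assumption, since the diff only shows the docstring list):

```python
# Sketch: look up the normalized config for a newly registered model type.
# Vision backbones are expected to map to a vision normalized config (assumption).
from optimum.utils.normalized_config import NormalizedConfigManager

cls = NormalizedConfigManager.get_normalized_config_class("dinov2")
print(cls)
```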