class Dinov2DummyInputGenerator(DummyVisionInputGenerator):
    """Dummy pixel-value generator for DINOv2 ONNX export.

    Behaves like `DummyVisionInputGenerator`, except that when the model's
    preprocessor declares a `crop_size`, the dummy image height/width are
    overridden to match it, so the exported graph is traced with the shapes
    the model actually receives at inference time.
    """

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
        width: int = DEFAULT_DUMMY_SHAPES["width"],
        height: int = DEFAULT_DUMMY_SHAPES["height"],
        **kwargs,
    ):
        super().__init__(
            task=task,
            normalized_config=normalized_config,
            batch_size=batch_size,
            num_channels=num_channels,
            width=width,
            height=height,
            **kwargs,
        )

        # Local import to avoid a hard dependency on transformers.onnx at module
        # import time.
        from transformers.onnx.utils import get_preprocessor

        # NOTE(review): this may hit the Hugging Face Hub (or local cache) to
        # resolve the preprocessor from the model id — a side effect in a
        # constructor; assumes `normalized_config` carries `_name_or_path`.
        preprocessor = get_preprocessor(normalized_config._name_or_path)
        if preprocessor is not None and hasattr(preprocessor, "crop_size"):
            # Prefer the preprocessor's crop size; fall back to the defaults
            # passed in when either dimension is missing.
            self.height = preprocessor.crop_size.get("height", self.height)
            self.width = preprocessor.crop_size.get("width", self.width)

    # Defect fixed: the previous `generate` override only delegated to
    # `super().generate(...)` with identical arguments and returned the result
    # unchanged — a pure no-op. It has been removed; the inherited
    # implementation is used directly.


class Dinov2OnnxConfig(ViTOnnxConfig):
    """ONNX export config for DINOv2; identical to ViT's except that dummy
    inputs honor the preprocessor's crop size (see Dinov2DummyInputGenerator)."""

    DUMMY_INPUT_GENERATOR_CLASSES = (Dinov2DummyInputGenerator,)
1e-4 DEFAULT_ONNX_OPSET = 11 diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index b771eb731f..c164cc7bb8 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -574,6 +574,11 @@ class TasksManager: "image-segmentation", onnx="DetrOnnxConfig", ), + "dinov2": supported_tasks_mapping( + "feature-extraction", + "image-classification", + onnx="Dinov2OnnxConfig", + ), "distilbert": supported_tasks_mapping( "feature-extraction", "fill-mask", diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 7e53005ed2..595bc9d0a3 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -1682,7 +1682,7 @@ def forward( @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) class ORTModelForImageClassification(ORTModel): """ - ONNX Model for image-classification tasks. This class officially supports beit, convnext, convnextv2, data2vec_vision, deit, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, swinv2, vit. + ONNX Model for image-classification tasks. This class officially supports beit, convnext, convnextv2, data2vec_vision, deit, dinov2, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, swinv2, vit. 
""" auto_model_class = AutoModelForImageClassification diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 64dc0d7cc9..8ab5a304d1 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -204,6 +204,7 @@ class NormalizedConfigManager: 'data2vec-text', 'data2vec-vision', 'detr', + 'dinov2', 'flaubert', 'groupvit', 'hiera', diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index eec4bb8dd2..577f020c5b 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -68,6 +68,7 @@ "deberta": "hf-internal-testing/tiny-random-DebertaModel", "deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model", "deit": "hf-internal-testing/tiny-random-DeiTModel", + "dinov2": "hf-internal-testing/tiny-random-Dinov2Model", "donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder", "donut-swin": "hf-internal-testing/tiny-random-DonutSwinModel", "detr": "hf-internal-testing/tiny-random-DetrModel", # hf-internal-testing/tiny-random-detr is larger diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 4b44acb38a..3e68de3195 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -2727,6 +2727,7 @@ class ORTModelForImageClassificationIntegrationTest(ORTModelTestMixin): "convnextv2", "data2vec_vision", "deit", + "dinov2", "levit", "mobilenet_v1", "mobilenet_v2", diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 3b52194a12..0e8d42fbcc 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -87,8 +87,9 @@ "deit": "hf-internal-testing/tiny-random-DeiTModel", "donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder", "detr": "hf-internal-testing/tiny-random-detr", - "dpt": "hf-internal-testing/tiny-random-DPTModel", + "dinov2": "hf-internal-testing/tiny-random-Dinov2Model", 
"distilbert": "hf-internal-testing/tiny-random-DistilBertModel", + "dpt": "hf-internal-testing/tiny-random-DPTModel", "electra": "hf-internal-testing/tiny-random-ElectraModel", "encoder-decoder": { "hf-internal-testing/tiny-random-EncoderDecoderModel-bert-bert": [