From 6772106b8e972958546656dc50989be7b5f073bb Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 11 Jul 2024 14:13:16 +0200 Subject: [PATCH] Clip vision model onnx export (#1920) clip vision model onnx export --- optimum/exporters/onnx/model_configs.py | 16 ++++++++++++++++ optimum/exporters/tasks.py | 4 ++++ tests/exporters/exporters_utils.py | 1 + 3 files changed, 21 insertions(+) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index c66e54b323..e2bcd7fe20 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -896,6 +896,22 @@ class CLIPNormalizedConfig(NormalizedTextAndVisionConfig): VISION_CONFIG = "vision_config" +class CLIPVisionModelOnnxConfig(VisionOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + return {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}} + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + common_outputs = super().outputs + common_outputs["last_hidden_state"] = {0: "batch_size"} + common_outputs["pooler_output"] = {0: "batch_size"} + + return common_outputs + + class CLIPOnnxConfig(TextAndVisionOnnxConfig): NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig DEFAULT_ONNX_OPSET = 14 diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 2896842f93..c0221f7bf6 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -448,6 +448,10 @@ class TasksManager: "zero-shot-image-classification", onnx="CLIPOnnxConfig", ), + "clip-vision-model": supported_tasks_mapping( + "feature-extraction", + onnx="CLIPVisionModelOnnxConfig", + ), "codegen": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 0c52754ff6..9c5d2c8991 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -56,6 +56,7 @@ "bloom": "hf-internal-testing/tiny-random-BloomModel", "camembert": "hf-internal-testing/tiny-random-camembert", "clip": "hf-internal-testing/tiny-random-CLIPModel", + "clip-vision-model": "fxmarty/clip-vision-model-tiny", "convbert": "hf-internal-testing/tiny-random-ConvBertModel", "convnext": "hf-internal-testing/tiny-random-convnext", "convnextv2": "hf-internal-testing/tiny-random-ConvNextV2Model",