From b610212fa9037d5af3096b49c176188cc7a20e64 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 25 Jul 2024 13:43:02 +0200 Subject: [PATCH] patch all clip variants --- optimum/exporters/onnx/model_configs.py | 26 ++++++++++++++++++++++--- optimum/exporters/onnx/model_patcher.py | 2 +- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 25e6e42dee..1bae34d895 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -71,7 +71,7 @@ ) from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME from .model_patcher import ( - CLIPTextModelPatcher, + CLIPModelPatcher, FalconModelPatcher, MistralModelPatcher, MusicgenModelPatcher, @@ -912,10 +912,16 @@ def outputs(self) -> Dict[str, Dict[int, str]]: return common_outputs + def patch_model_for_export( + self, + model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], + model_kwargs: Optional[Dict[str, Any]] = None, + ) -> "ModelPatcher": + return CLIPModelPatcher(self, model, model_kwargs=model_kwargs) + class CLIPOnnxConfig(TextAndVisionOnnxConfig): NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig - DEFAULT_ONNX_OPSET = 14 @property def inputs(self) -> Dict[str, Dict[int, str]]: @@ -934,6 +940,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]: "image_embeds": {0: "image_batch_size"}, } + def patch_model_for_export( + self, + model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], + model_kwargs: Optional[Dict[str, Any]] = None, + ) -> "ModelPatcher": + return CLIPModelPatcher(self, model, model_kwargs=model_kwargs) + class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig): @property @@ -984,7 +997,7 @@ def patch_model_for_export( model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], model_kwargs: Optional[Dict[str, Any]] = None, ) -> "ModelPatcher": - return CLIPTextModelPatcher(self, model, model_kwargs=model_kwargs) + return CLIPModelPatcher(self, model, model_kwargs=model_kwargs) class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig): @@ -1000,6 +1013,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]: return common_outputs + def patch_model_for_export( + self, + model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], + model_kwargs: Optional[Dict[str, Any]] = None, + ) -> "ModelPatcher": + return CLIPModelPatcher(self, model, model_kwargs=model_kwargs) + class UNetOnnxConfig(VisionOnnxConfig): ATOL_FOR_VALIDATION = 1e-3 diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index c49659ba9a..080ac4e8af 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -1133,7 +1133,7 @@ def __init__( self._update_causal_mask_original = self._model._update_causal_mask -class CLIPTextModelPatcher(ModelPatcher): +class CLIPModelPatcher(ModelPatcher): def __enter__(self): super().__enter__()