
Commit

Enable ONNX export of CLIP models with sdpa
echarlaix committed Oct 16, 2024
1 parent 9fd9ca5 commit 0409468
Showing 2 changed files with 2 additions and 46 deletions.
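
At a glance: CLIP ONNX export now targets opset 14, where PyTorch's exporter supports F.scaled_dot_product_attention, so the CLIPModelPatcher that swapped SDPA attention back to the eager implementation during tracing can be dropped. A minimal usage sketch (not part of this commit; the model id is only an example, and any CLIP checkpoint should behave the same way):

    from optimum.exporters.onnx import main_export

    # Export a CLIP checkpoint to ONNX; the task is auto-detected.
    main_export("openai/clip-vit-base-patch32", output="clip_onnx")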
31 changes: 2 additions & 29 deletions optimum/exporters/onnx/model_configs.py
@@ -72,7 +72,6 @@
 )
 from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
 from .model_patcher import (
-    CLIPModelPatcher,
     FalconModelPatcher,
     MistralModelPatcher,
     MusicgenModelPatcher,
@@ -907,6 +906,7 @@ class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):

 class CLIPVisionModelOnnxConfig(VisionOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
+    DEFAULT_ONNX_OPSET = 14  # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.

     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -920,16 +920,10 @@ def outputs(self) -> Dict[str, Dict[int, str]]:

         return common_outputs

-    def patch_model_for_export(
-        self,
-        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> "ModelPatcher":
-        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
-

 class CLIPOnnxConfig(TextAndVisionOnnxConfig):
     NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig
+    DEFAULT_ONNX_OPSET = 14  # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.

     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -948,13 +942,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             "image_embeds": {0: "image_batch_size"},
         }

-    def patch_model_for_export(
-        self,
-        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> "ModelPatcher":
-        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
-

 class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig):
     @property
@@ -1000,13 +987,6 @@ def outputs(self) -> Dict[str, Dict[int, str]]:

         return common_outputs

-    def patch_model_for_export(
-        self,
-        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> "ModelPatcher":
-        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
-

 class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig):
     @property
@@ -1031,13 +1011,6 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
         dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32)
         return dummy_inputs

-    def patch_model_for_export(
-        self,
-        model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
-        model_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> "ModelPatcher":
-        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
-

 class UNetOnnxConfig(VisionOnnxConfig):
     ATOL_FOR_VALIDATION = 1e-3
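Why opset 14: transformers dispatches CLIP attention to F.scaled_dot_product_attention by default on torch>=2.1.1, and torch.onnx only supports exporting aten::scaled_dot_product_attention from opset 14 onwards. A standalone sketch of that constraint (toy module and shapes, not code from this repository):

    import torch
    import torch.nn.functional as F

    class TinySdpa(torch.nn.Module):
        def forward(self, q, k, v):
            return F.scaled_dot_product_attention(q, k, v)

    # (batch, num_heads, seq_len, head_dim)
    q = k = v = torch.randn(1, 8, 16, 64)

    # Below opset 14 this raises "Exporting the operator
    # 'aten::scaled_dot_product_attention' ... is not supported"; at 14 the
    # exporter decomposes it into standard ops (MatMul, Softmax, ...).
    torch.onnx.export(TinySdpa(), (q, k, v), "sdpa.onnx", opset_version=14)
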
17 changes: 0 additions & 17 deletions optimum/exporters/onnx/model_patcher.py
@@ -1138,20 +1138,3 @@ def __init__(
             self._update_causal_mask_original = self._model.model._update_causal_mask
         else:
             self._update_causal_mask_original = self._model._update_causal_mask
-
-
-class CLIPModelPatcher(ModelPatcher):
-    def __enter__(self):
-        super().__enter__()
-
-        if _transformers_version >= version.parse("4.43"):
-            from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention
-
-            self.original_sdpa_forward, CLIPSdpaAttention.forward = CLIPSdpaAttention.forward, CLIPAttention.forward
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-        if _transformers_version >= version.parse("4.43"):
-            from transformers.models.clip.modeling_clip import CLIPSdpaAttention
-
-            CLIPSdpaAttention.forward = self.original_sdpa_forward
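
For context: the deleted CLIPModelPatcher was a context manager that re-pointed CLIPSdpaAttention.forward at the eager CLIPAttention.forward for the duration of the export (the eager/SDPA attention split landed in transformers 4.43) and restored it on exit. With the opset bump, the SDPA path traces as-is and no patching is needed. A quick sanity check on an exported model (a hedged sketch, assuming the clip_onnx/ directory from the example above and optimum's default model.onnx file name):

    import onnxruntime as ort

    session = ort.InferenceSession("clip_onnx/model.onnx")
    # For the combined CLIP config, expect inputs along the lines of:
    # input_ids, pixel_values, attention_mask.
    print([inp.name for inp in session.get_inputs()])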
