Skip to content

Commit

Permalink
patch all clip variants
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jul 25, 2024
1 parent c2a5c03 commit b610212
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
26 changes: 23 additions & 3 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
)
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import (
CLIPTextModelPatcher,
CLIPModelPatcher,
FalconModelPatcher,
MistralModelPatcher,
MusicgenModelPatcher,
Expand Down Expand Up @@ -912,10 +912,16 @@ def outputs(self) -> Dict[str, Dict[int, str]]:

return common_outputs

def patch_model_for_export(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class CLIPOnnxConfig(TextAndVisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig
DEFAULT_ONNX_OPSET = 14

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
Expand All @@ -934,6 +940,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
"image_embeds": {0: "image_batch_size"},
}

def patch_model_for_export(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig):
@property
Expand Down Expand Up @@ -984,7 +997,7 @@ def patch_model_for_export(
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
return CLIPTextModelPatcher(self, model, model_kwargs=model_kwargs)
return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig):
Expand All @@ -1000,6 +1013,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:

return common_outputs

def patch_model_for_export(
self,
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class UNetOnnxConfig(VisionOnnxConfig):
ATOL_FOR_VALIDATION = 1e-3
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,7 @@ def __init__(
self._update_causal_mask_original = self._model._update_causal_mask


class CLIPTextModelPatcher(ModelPatcher):
class CLIPModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()

Expand Down

0 comments on commit b610212

Please sign in to comment.