Skip to content

Commit

Permalink
Fix compatibility with transformers v4.41.0 for ONNX (#1860)
Browse files Browse the repository at this point in the history
* bump transformers

* update default onnx opset

* style

* save export for model with invalid generation config

* set minimum onnx opset

* update setup
  • Loading branch information
echarlaix authored May 23, 2024
1 parent e0f5812 commit cc9889b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 28 deletions.
7 changes: 6 additions & 1 deletion optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,12 @@ def onnx_export_from_model(
model.config.save_pretrained(output)
generation_config = getattr(model, "generation_config", None)
if generation_config is not None:
generation_config.save_pretrained(output)
# since v4.41.0 an exceptions will be raised when saving a generation config considered invalid
# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/generation/configuration_utils.py#L697
try:
generation_config.save_pretrained(output)
except Exception as exception:
logger.warning(f"The generation config is invalid and will not be saved : {exception}")

model_name_or_path = model.config._name_or_path
maybe_save_preprocessors(model_name_or_path, output)
Expand Down
66 changes: 39 additions & 27 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
class BertOnnxConfig(TextEncoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
ATOL_FOR_VALIDATION = 1e-4
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
Expand All @@ -114,42 +115,44 @@ def inputs(self) -> Dict[str, Dict[int, str]]:


class AlbertOnnxConfig(BertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class ConvBertOnnxConfig(BertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class ElectraOnnxConfig(BertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class RoFormerOnnxConfig(BertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class SqueezeBertOnnxConfig(BertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class MobileBertOnnxConfig(BertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class NystromformerOnnxConfig(BertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class XLMOnnxConfig(BertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class SplinterOnnxConfig(BertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class DistilBertOnnxConfig(BertOnnxConfig):
DEFAULT_ONNX_OPSET = 11

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
if self.task == "multiple-choice":
Expand All @@ -172,7 +175,7 @@ class CamembertOnnxConfig(DistilBertOnnxConfig):


class FlaubertOnnxConfig(BertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class IBertOnnxConfig(DistilBertOnnxConfig):
Expand All @@ -195,6 +198,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:


class MarkupLMOnnxConfig(BertOnnxConfig):
DEFAULT_ONNX_OPSET = 11
DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTextInputGenerator,
DummyXPathSeqInputGenerator,
Expand Down Expand Up @@ -706,6 +710,7 @@ class MarianOnnxConfig(BartOnnxConfig):
class ViTOnnxConfig(VisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
MIN_TORCH_VERSION = version.parse("1.11")
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
Expand All @@ -725,36 +730,38 @@ class CvTOnnxConfig(ViTOnnxConfig):


class LevitOnnxConfig(ViTOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class DeiTOnnxConfig(ViTOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class BeitOnnxConfig(ViTOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class ConvNextOnnxConfig(ViTOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class ConvNextV2OnnxConfig(ViTOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class MobileViTOnnxConfig(ViTOnnxConfig):
ATOL_FOR_VALIDATION = 1e-4
DEFAULT_ONNX_OPSET = 11


class RegNetOnnxConfig(ViTOnnxConfig):
# This config has the same inputs as ViTOnnxConfig
pass
DEFAULT_ONNX_OPSET = 11


class ResNetOnnxConfig(ViTOnnxConfig):
ATOL_FOR_VALIDATION = 1e-3
DEFAULT_ONNX_OPSET = 11


class DetrOnnxConfig(ViTOnnxConfig):
Expand All @@ -776,28 +783,29 @@ class TableTransformerOnnxConfig(DetrOnnxConfig):


class YolosOnnxConfig(ViTOnnxConfig):
DEFAULT_ONNX_OPSET = 12
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class SwinOnnxConfig(ViTOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class Swin2srOnnxConfig(SwinOnnxConfig):
pass


class DptOnnxConfig(ViTOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class GlpnOnnxConfig(ViTOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class PoolFormerOnnxConfig(ViTOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
ATOL_FOR_VALIDATION = 2e-3
DEFAULT_ONNX_OPSET = 11


class SegformerOnnxConfig(YolosOnnxConfig):
Expand All @@ -806,6 +814,7 @@ class SegformerOnnxConfig(YolosOnnxConfig):

class MobileNetV1OnnxConfig(ViTOnnxConfig):
ATOL_FOR_VALIDATION = 1e-4
DEFAULT_ONNX_OPSET = 11

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
Expand All @@ -817,7 +826,7 @@ class MobileNetV2OnnxConfig(MobileNetV1OnnxConfig):


class DonutSwinOnnxConfig(ViTOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class TimmDefaultOnnxConfig(ViTOnnxConfig):
Expand Down Expand Up @@ -1191,12 +1200,13 @@ class Data2VecTextOnnxConfig(DistilBertOnnxConfig):


class Data2VecVisionOnnxConfig(ViTOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class Data2VecAudioOnnxConfig(AudioOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedConfig
ATOL_FOR_VALIDATION = 1e-4
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class PerceiverDummyInputGenerator(DummyVisionInputGenerator):
Expand Down Expand Up @@ -1292,30 +1302,31 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):

class HubertOnnxConfig(AudioOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedConfig
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class Wav2Vec2OnnxConfig(HubertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class Wav2Vec2ConformerOnnxConfig(HubertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 11


class SEWOnnxConfig(HubertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class SEWDOnnxConfig(HubertOnnxConfig):
DEFAULT_ONNX_OPSET = 12


class UniSpeechOnnxConfig(HubertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class UniSpeechSATOnnxConfig(HubertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class WavLMOnnxConfig(HubertOnnxConfig):
Expand Down Expand Up @@ -1344,6 +1355,7 @@ class ASTOnnxConfig(OnnxConfig):
)
DUMMY_INPUT_GENERATOR_CLASSES = (ASTDummyAudioInputGenerator,)
ATOL_FOR_VALIDATION = 1e-4
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
Expand Down

0 comments on commit cc9889b

Please sign in to comment.