ONNX: disable text-generation models for sequence classification & fixes for transformers 4.32 (#1308)

* fix

* disable more

* fix
fxmarty authored Aug 22, 2023
1 parent 2c1eaf6 commit f600bc6
Showing 3 changed files with 15 additions and 12 deletions.
12 changes: 6 additions & 6 deletions optimum/exporters/tasks.py
@@ -512,7 +512,7 @@ class TasksManager:
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
# "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
"token-classification",
onnx="GPT2OnnxConfig",
),
@@ -521,7 +521,7 @@ class TasksManager:
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
# "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
"token-classification",
onnx="GPTBigCodeOnnxConfig",
),
@@ -531,15 +531,15 @@ class TasksManager:
"text-generation",
"text-generation-with-past",
"question-answering",
"text-classification",
# "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
onnx="GPTJOnnxConfig",
),
"gpt-neo": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
# "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
onnx="GPTNeoOnnxConfig",
),
"gpt-neox": supported_tasks_mapping(
@@ -714,15 +714,15 @@ class TasksManager:
"text-generation",
"text-generation-with-past",
"question-answering",
"text-classification",
# "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
onnx="OPTOnnxConfig",
),
"llama": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
# "text-classification", # TODO: maybe reenable once fixed. See: https://github.com/huggingface/optimum/pull/1308
onnx="LlamaOnnxConfig",
),
"pegasus": supported_tasks_mapping(
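To illustrate the effect of the tasks.py change, a minimal sketch (assuming the public TasksManager.get_supported_tasks_for_model_type helper; the query API itself is not part of this diff): with "text-classification" commented out, the task is no longer advertised for these decoder-only architectures, so an export request fails fast with an unsupported-task error instead of producing a graph that is broken under transformers 4.32.

from optimum.exporters.tasks import TasksManager

# After this commit, "text-classification" is commented out of the gpt2
# entry above, so it is absent from the supported ONNX tasks.
supported = TasksManager.get_supported_tasks_for_model_type("gpt2", exporter="onnx")
print("text-generation" in supported)      # True
print("text-classification" in supported)  # False after this change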
12 changes: 6 additions & 6 deletions optimum/pipelines/pipelines_base.py
@@ -171,7 +171,7 @@ def load_bettertransformer(
load_feature_extractor=None,
SUPPORTED_TASKS=None,
subfolder: str = "",
- use_auth_token: Optional[Union[bool, str]] = None,
+ token: Optional[Union[bool, str]] = None,
revision: str = "main",
model_kwargs: Optional[Dict[str, Any]] = None,
config: AutoConfig = None,
@@ -218,7 +218,7 @@ def load_ort_pipeline(
load_feature_extractor,
SUPPORTED_TASKS,
subfolder: str = "",
- use_auth_token: Optional[Union[bool, str]] = None,
+ token: Optional[Union[bool, str]] = None,
revision: str = "main",
model_kwargs: Optional[Dict[str, Any]] = None,
config: AutoConfig = None,
@@ -246,7 +246,7 @@ def load_ort_pipeline(
pattern,
glob_pattern="**/*.onnx",
subfolder=subfolder,
- use_auth_token=use_auth_token,
+ use_auth_token=token,
revision=revision,
)
export = len(onnx_files) == 0
@@ -292,7 +292,7 @@ def pipeline(
tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None,
use_fast: bool = True,
- use_auth_token: Optional[Union[str, bool]] = None,
+ token: Optional[Union[str, bool]] = None,
accelerator: Optional[str] = "ort",
revision: Optional[str] = None,
trust_remote_code: Optional[bool] = None,
@@ -315,7 +315,7 @@ def pipeline(
# copied from transformers.pipelines.__init__.py
hub_kwargs = {
"revision": revision,
"use_auth_token": use_auth_token,
"token": token,
"trust_remote_code": trust_remote_code,
"_commit_hash": None,
}
@@ -364,6 +364,7 @@ def pipeline(
SUPPORTED_TASKS=supported_tasks,
config=config,
hub_kwargs=hub_kwargs,
+ token=token,
*model_kwargs,
**kwargs,
)
@@ -379,6 +379,5 @@ def pipeline(
tokenizer=tokenizer,
feature_extractor=feature_extractor,
use_fast=use_fast,
- use_auth_token=use_auth_token,
**kwargs,
)
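A usage sketch of the renamed argument (the checkpoint name is only illustrative): optimum.pipelines.pipeline now accepts token, matching the name transformers 4.32 expects, and forwards it both through hub_kwargs and to the ORT loader, which still hands it to the ONNX file lookup under the old use_auth_token keyword.

from optimum.pipelines import pipeline

# The Hub credential is now passed as `token` rather than `use_auth_token`.
pipe = pipeline(
    "text-generation",
    model="gpt2",  # illustrative checkpoint; any ORT-exportable model works
    accelerator="ort",
    token=True,    # True reuses the locally saved Hub token; a string token also works
)
print(pipe("Hello,", max_new_tokens=5))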
3 changes: 3 additions & 0 deletions tests/exporters/onnx/test_onnx_config_loss.py
@@ -123,6 +123,9 @@ def test_onnx_config_with_loss(self):
gc.collect()

def test_onnx_decoder_model_with_config_with_loss(self):
+ self.skipTest(
+     "Skipping due to a bug introduced in transformers with https://github.com/huggingface/transformers/pull/24979, argmax on int64 is not supported by ONNX"
+ )
with tempfile.TemporaryDirectory() as tmp_dir:
# Prepare model and dataset
model_checkpoint = "hf-internal-testing/tiny-random-gpt2"
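For context on the skip, a standalone repro sketch of the limitation it cites (not part of the commit; shapes and opset are arbitrary): transformers#24979 made the sequence-classification heads compute sequence lengths via an argmax over integer input_ids, and ONNX Runtime builds of that era registered no ArgMax kernel for int64, so the exported graph fails when the session is created.

import torch
import onnxruntime  # assumed available

class Int64ArgMax(torch.nn.Module):
    def forward(self, input_ids):
        # Stand-in for the pad-token lookup added by transformers#24979.
        return input_ids.argmax(-1)

torch.onnx.export(
    Int64ArgMax(),
    (torch.zeros(2, 8, dtype=torch.int64),),
    "argmax_int64.onnx",
    input_names=["input_ids"],
    opset_version=13,
)

# Expected to raise NOT_IMPLEMENTED ("Could not find an implementation for
# ArgMax") on onnxruntime builds lacking an int64 ArgMax kernel.
session = onnxruntime.InferenceSession("argmax_int64.onnx", providers=["CPUExecutionProvider"])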
