From 6e4f116c0f03ebe40acc4667257898474fb52939 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 27 Feb 2024 16:10:08 +0100 Subject: [PATCH] support whisper onnx audio-classification --- optimum/exporters/onnx/model_configs.py | 12 +++++++++--- optimum/exporters/tasks.py | 1 + tests/onnxruntime/test_modeling.py | 1 + 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 0e37259f7b..9bbfacf235 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -1342,9 +1342,15 @@ class WhisperOnnxConfig(AudioToTextOnnxConfig): @property def inputs(self) -> Dict[str, Dict[int, str]]: - common_inputs = super().inputs - if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False: - common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2" + if self.task == "audio-classification": + common_inputs = {"input_features": {0: "batch_size"}} + else: + common_inputs = super().inputs + if self._behavior is not ConfigBehavior.DECODER: + common_inputs["input_features"] = {0: "batch_size"} # Remove unnecessary dynamic axis. + + if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False: + common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2" return common_inputs @property diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 8c87da6cc6..42cae6cc63 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1036,6 +1036,7 @@ class TasksManager: "whisper": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", + "audio-classification", "automatic-speech-recognition", "automatic-speech-recognition-with-past", onnx="WhisperOnnxConfig", diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index b9f7cdcd4d..6615a634df 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -3127,6 +3127,7 @@ class ORTModelForAudioClassificationIntegrationTest(ORTModelTestMixin): "wavlm", "wav2vec2", "wav2vec2-conformer", + "whisper", ] FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}