diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 0e37259f7b..9bbfacf235 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -1342,9 +1342,15 @@ class WhisperOnnxConfig(AudioToTextOnnxConfig):

     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
-        common_inputs = super().inputs
-        if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False:
-            common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"
+        if self.task == "audio-classification":
+            common_inputs = {"input_features": {0: "batch_size"}}
+        else:
+            common_inputs = super().inputs
+            if self._behavior is not ConfigBehavior.DECODER:
+                common_inputs["input_features"] = {0: "batch_size"}  # Remove unnecessary dynamic axis.
+
+            if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False:
+                common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"
         return common_inputs

     @property
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 2f5794eecd..57ee54c2c7 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -1062,6 +1062,7 @@ class TasksManager:
         "whisper": supported_tasks_mapping(
             "feature-extraction",
             "feature-extraction-with-past",
+            "audio-classification",
             "automatic-speech-recognition",
             "automatic-speech-recognition-with-past",
             onnx="WhisperOnnxConfig",
diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py
index 413acbc233..4ebf318aa7 100644
--- a/optimum/onnxruntime/modeling_ort.py
+++ b/optimum/onnxruntime/modeling_ort.py
@@ -1832,6 +1832,29 @@ class ORTModelForAudioClassification(ORTModel):

     auto_model_class = AutoModelForAudioClassification

+    def __init__(
+        self,
+        model: ort.InferenceSession,
+        config: "PretrainedConfig",
+        use_io_binding: Optional[bool] = None,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        preprocessors: Optional[List] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            model=model,
+            config=config,
+            use_io_binding=use_io_binding,
+            model_save_dir=model_save_dir,
+            preprocessors=preprocessors,
+            **kwargs,
+        )
+
+        if config.model_type == "whisper":
+            self.input_name = "input_features"
+        else:
+            self.input_name = "input_values"
+
     @add_start_docstrings_to_model_forward(
         ONNX_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
         + AUDIO_CLASSIFICATION_EXAMPLE.format(
@@ -1846,6 +1869,9 @@ def forward(
         self,
         input_values: Optional[torch.Tensor] = None,
         attenton_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ):
+        if input_values is None:
+            # Whisper uses input_features and not input_values.
+            input_values = kwargs["input_features"]
         use_torch = isinstance(input_values, torch.Tensor)
         self.raise_on_numpy_input_io_binding(use_torch)
         if self.device.type == "cuda" and self.use_io_binding:
@@ -1864,11 +1890,11 @@ def forward(
         if use_torch:
             # converts pytorch inputs into numpy inputs for onnx
             onnx_inputs = {
-                "input_values": input_values.cpu().detach().numpy(),
+                self.input_name: input_values.cpu().detach().numpy(),
             }
         else:
             onnx_inputs = {
-                "input_values": input_values,
+                self.input_name: input_values,
             }

         # run inference
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index b9f7cdcd4d..6615a634df 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -3127,6 +3127,7 @@ class ORTModelForAudioClassificationIntegrationTest(ORTModelTestMixin):
         "wavlm",
         "wav2vec2",
         "wav2vec2-conformer",
+        "whisper",
     ]

     FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
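Usage sketch (not part of the diff): a rough illustration of the path this change enables, exporting a Whisper audio-classification checkpoint to ONNX on the fly and running it through ORTModelForAudioClassification. The checkpoint id below is a placeholder for any Whisper model fine-tuned for audio classification, and the waveform is random dummy data.

import numpy as np
from transformers import AutoFeatureExtractor
from optimum.onnxruntime import ORTModelForAudioClassification

# Placeholder id: substitute any Whisper checkpoint fine-tuned for audio classification.
model_id = "<whisper-audio-classification-checkpoint>"

feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
# export=True asks optimum to run the ONNX export (the "audio-classification" task
# is now accepted for whisper via exporters/tasks.py) before creating the session.
model = ORTModelForAudioClassification.from_pretrained(model_id, export=True)

# Whisper's feature extractor yields `input_features` (log-mel spectrograms) rather than
# `input_values`; the forward() change above routes them to the right ONNX input name.
waveform = np.random.randn(16000).astype(np.float32)  # 1 s of dummy audio at 16 kHz
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")

logits = model(**inputs).logits
print(model.config.id2label[int(logits.argmax(-1))])

The export should also be doable ahead of time with the CLI, e.g. `optimum-cli export onnx --model <checkpoint> --task audio-classification <output_dir>`, now that the task is registered for whisper.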