Skip to content

Commit

Permalink
Whisper ONNX audio-classification (#1727)
Browse files Browse the repository at this point in the history
* support whisper onnx audio-classification

* fix tests

* remove unnecessary method
  • Loading branch information
fxmarty authored Feb 28, 2024
1 parent dfca3fd commit bb21ae7
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 5 deletions.
12 changes: 9 additions & 3 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,9 +1342,15 @@ class WhisperOnnxConfig(AudioToTextOnnxConfig):

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
    """ONNX input spec: maps each input name to its dynamic-axis labels.

    For ``audio-classification`` the model is exported as a bare encoder
    classifier, so the only input is the log-mel ``input_features`` with a
    dynamic batch axis. For seq2seq tasks we start from the parent config's
    inputs and then tighten the axes.

    NOTE: the scraped diff had merged the pre- and post-change bodies
    (duplicated ``super().inputs`` / encoder_outputs lines); this is the
    reconstructed post-change implementation.
    """
    if self.task == "audio-classification":
        # Encoder-only export: fixed feature/time axes, dynamic batch only.
        common_inputs = {"input_features": {0: "batch_size"}}
    else:
        common_inputs = super().inputs
        if self._behavior is not ConfigBehavior.DECODER:
            # Remove unnecessary dynamic axis: only the batch dim varies.
            common_inputs["input_features"] = {0: "batch_size"}

        if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False:
            # Whisper's encoder halves the time dimension (conv stride 2),
            # so the decoder's encoder_outputs sequence axis is half the
            # encoder input's.
            common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"
    return common_inputs

@property
Expand Down
1 change: 1 addition & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,7 @@ class TasksManager:
"whisper": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"audio-classification",
"automatic-speech-recognition",
"automatic-speech-recognition-with-past",
onnx="WhisperOnnxConfig",
Expand Down
30 changes: 28 additions & 2 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -1832,6 +1832,29 @@ class ORTModelForAudioClassification(ORTModel):

auto_model_class = AutoModelForAudioClassification

def __init__(
    self,
    model: ort.InferenceSession,
    config: "PretrainedConfig",
    use_io_binding: Optional[bool] = None,
    model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
    preprocessors: Optional[List] = None,
    **kwargs,
):
    """Wrap an ONNX Runtime session for audio classification.

    Delegates all session/config wiring to the parent class, then records
    which input name this architecture expects so ``forward`` can build the
    ONNX feed dict correctly.
    """
    super().__init__(
        model=model,
        config=config,
        use_io_binding=use_io_binding,
        model_save_dir=model_save_dir,
        preprocessors=preprocessors,
        **kwargs,
    )

    # Whisper checkpoints consume log-mel "input_features"; every other
    # supported audio architecture consumes raw-waveform "input_values".
    self.input_name = "input_features" if config.model_type == "whisper" else "input_values"

@add_start_docstrings_to_model_forward(
ONNX_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
+ AUDIO_CLASSIFICATION_EXAMPLE.format(
Expand All @@ -1846,6 +1869,9 @@ def forward(
attenton_mask: Optional[torch.Tensor] = None,
**kwargs,
):
if input_values is None:
# Whisper uses input_features and not input_values.
input_values = kwargs["input_features"]
use_torch = isinstance(input_values, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)
if self.device.type == "cuda" and self.use_io_binding:
Expand All @@ -1864,11 +1890,11 @@ def forward(
if use_torch:
# converts pytorch inputs into numpy inputs for onnx
onnx_inputs = {
"input_values": input_values.cpu().detach().numpy(),
self.input_name: input_values.cpu().detach().numpy(),
}
else:
onnx_inputs = {
"input_values": input_values,
self.input_name: input_values,
}

# run inference
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3127,6 +3127,7 @@ class ORTModelForAudioClassificationIntegrationTest(ORTModelTestMixin):
"wavlm",
"wav2vec2",
"wav2vec2-conformer",
"whisper",
]

FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES}
Expand Down

0 comments on commit bb21ae7

Please sign in to comment.