Skip to content

Commit

Permalink
fix pkv and audio
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Jun 3, 2024
1 parent 949743e commit 13a6650
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
7 changes: 5 additions & 2 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,15 @@ def forward(
else:
model_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past_key_values,
"attention_mask": attention_mask,
"use_cache_branch": use_cache_branch,
"labels": labels,
}
if past_key_values is not None:
model_inputs.update(
zip(self.key_value_input_names, past_key_values),
)

onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
Expand Down
25 changes: 10 additions & 15 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -1843,7 +1843,16 @@ def forward(
input_features: Optional[Union[torch.Tensor, np.ndarray]] = None,
**kwargs,
):
use_torch = isinstance(input_values, torch.Tensor)
if self.input_name == "input_features":
assert input_features is not None, "input_features must be provided for this model"
main_input = input_features
elif self.input_name == "input_values":
assert input_values is not None, "input_values must be provided for this model"
main_input = input_values
else:
raise ValueError(f"Input {self.input_name} not supported for Audio Classification")

use_torch = isinstance(main_input, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

if attention_mask is None:
Expand All @@ -1852,20 +1861,6 @@ def forward(
else:
attention_mask = np.ones_like(input_values)

if self.input_name == "input_features":
if input_features is None:
raise ValueError("input_features must be provided for whisper model")

main_input = input_features
elif self.input_name == "input_values":
if input_values is None:
raise ValueError("input_values must be provided for this model")

main_input = input_values

else:
raise ValueError(f"Input {self.input_name} not supported for Audio Classification")

if self.device.type == "cuda" and self.use_io_binding:
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
main_input,
Expand Down

0 comments on commit 13a6650

Please sign in to comment.