Fix Windows and onnx dtype compatibility (#1886)
* fix pkv and audio

* add t5 test

* fix seq2seq

* fix vision2seq tests, as it seems to have always output the kv cache in torch format before

* fix folder deletion on windows

* fix temporary directory removal on windows

* remove attention_mask creation as ORTModelForxxx's corresponding processors will create it

* remove_directory utility function
IlyasMoutawwakil authored Jun 24, 2024
1 parent 8b43dd2 commit aad4b8b
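The Windows-related items in the message ("fix folder deletion on windows", "fix temporary directory removal on windows", "remove_directory utility function") point at a helper added in a file that is not part of this excerpt. A minimal sketch of the idea, assuming the goal is to clear the read-only attribute that makes `shutil.rmtree` raise `PermissionError` on Windows; the name, signature, and error handling here are illustrative, not the code from this commit:

```python
import os
import shutil
import stat


def remove_directory(dir_path: str) -> None:
    """Recursively delete a directory, tolerating the read-only files that make
    shutil.rmtree raise PermissionError on Windows (illustrative sketch only)."""

    def _make_writable_and_retry(func, path, _exc_info):
        # On Windows, files can carry the read-only attribute (e.g. cached ONNX
        # weights); clear it and retry the operation that just failed.
        os.chmod(path, stat.S_IWRITE)
        func(path)

    if os.path.isdir(dir_path):
        shutil.rmtree(dir_path, onerror=_make_writable_and_retry)
```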
Showing 5 changed files with 299 additions and 485 deletions.
124 changes: 30 additions & 94 deletions optimum/onnxruntime/base.py
@@ -14,7 +14,7 @@
"""Defines the base classes that are used to perform inference with ONNX Runtime of Transformers models."""

from abc import abstractmethod
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, Union
from typing import Dict, Optional, Set, Tuple, Union

import numpy as np
import torch
@@ -24,22 +24,22 @@

from ..utils import NormalizedConfigManager
from ..utils.logging import warn_once
from .modeling_ort import ORTModel
from .utils import get_ordered_input_names, logging


logger = logging.get_logger(__name__)


if TYPE_CHECKING:
from .modeling_ort import ORTModel


class ORTModelPart:
"""
For multi-file ONNX models, such as encoder-decoder models, represents a part of the model.
It has its own `onnxruntime.InferenceSession`, and can perform a forward pass.
"""

_prepare_onnx_inputs = ORTModel._prepare_onnx_inputs
_prepare_onnx_outputs = ORTModel._prepare_onnx_outputs

def __init__(
self,
session: InferenceSession,
Expand All @@ -53,6 +53,8 @@ def __init__(
self.main_input_name = self.parent_model.main_input_name
self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}
self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()}
self.output_dtypes = {output_key.name: output_key.type for output_key in session.get_outputs()}

self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward)

@@ -98,25 +100,13 @@ def forward(

last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
if use_torch:
onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()}

# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names:
onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy()
else:
onnx_inputs = {"input_ids": input_ids}
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names:
onnx_inputs["attention_mask"] = attention_mask
onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

# Run inference
outputs = self.session.run(None, onnx_inputs)

last_hidden_state = outputs[self.output_names["last_hidden_state"]]
if use_torch:
last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device)
last_hidden_state = model_outputs["last_hidden_state"]

return BaseModelOutput(last_hidden_state=last_hidden_state)

@@ -350,83 +340,29 @@ def forward(
else:
raise ValueError("Unsupported num_pkv")
else:
if use_torch:
onnx_inputs = {
"input_ids": input_ids.cpu().detach().numpy(),
}

# Add the encoder_hidden_states inputs when needed
if "encoder_hidden_states" in self.input_names:
onnx_inputs["encoder_hidden_states"] = encoder_hidden_states.cpu().detach().numpy()

# Add the decoder_attention_mask inputs when needed
if "decoder_attention_mask" in self.input_names:
onnx_inputs["decoder_attention_mask"] = decoder_attention_mask.cpu().detach().numpy()

# Add the encoder_attention_mask inputs when needed
if "encoder_attention_mask" in self.input_names:
onnx_inputs["encoder_attention_mask"] = encoder_attention_mask.cpu().detach().numpy()

if past_key_values is not None:
# Add the past_key_values to the decoder inputs
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()

if "labels" in self.input_names:
# TODO: Any preprocessing like `self._shift_right(labels)`?
onnx_inputs["labels"] = labels.cpu().detach().numpy()

if self.parent_model.use_merged is True:
onnx_inputs["use_cache_branch"] = use_cache_branch_tensor.cpu().detach().numpy()
else:
onnx_inputs = {
"input_ids": input_ids,
}

# Add the encoder_hidden_states inputs when needed
if "encoder_hidden_states" in self.input_names:
onnx_inputs["encoder_hidden_states"] = encoder_hidden_states

# Add the decoder_attention_mask inputs when needed
if "decoder_attention_mask" in self.input_names:
onnx_inputs["decoder_attention_mask"] = decoder_attention_mask

# Add the encoder_attention_mask inputs when needed
if "encoder_attention_mask" in self.input_names:
onnx_inputs["encoder_attention_mask"] = encoder_attention_mask

if past_key_values is not None:
# Add the past_key_values to the decoder inputs
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value

if "labels" in self.input_names:
# TODO: Any preprocessing like `self._shift_right(labels)`?
onnx_inputs["labels"] = labels

if self.parent_model.use_merged is True:
onnx_inputs["use_cache_branch"] = use_cache_branch_tensor
model_inputs = {
"input_ids": input_ids,
"encoder_hidden_states": encoder_hidden_states,
"decoder_attention_mask": decoder_attention_mask,
"encoder_attention_mask": encoder_attention_mask,
"use_cache_branch": use_cache_branch_tensor,
"labels": labels,
}
if past_key_values is not None:
model_inputs.update(zip(self.key_value_input_names, past_key_values))

# Run inference
outputs = self.session.run(None, onnx_inputs)
onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

# TODO: using two loops here is probably unefficient
# TODO: using a new variable out_past_key_values is memory inefficient,
# past_key_values is not used anymore at this point
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
# self-attention layer and 2 to the cross-attention layer)
out_past_key_values = tuple(
torch.from_numpy(outputs[self.output_names[key]]).to(self.device)
for key in self.key_value_output_names
)

logits = outputs[self.output_names["logits"]]
if use_torch:
logits = torch.from_numpy(logits).to(self.device)
out_past_key_values = tuple(model_outputs[output_name] for output_name in self.key_value_output_names)

loss = None
if "loss" in self.output_names:
loss = outputs[self.output_names["loss"]]
if use_torch:
loss = torch.from_numpy(loss).to(self.device)
loss = model_outputs.get("loss", None)
logits = model_outputs["logits"]

# TODO: this is extremely ugly and unreadable. What if cross-attention k/v change?
# Tuple of tuple of length `n_layers`, with each tuple of length equal to:
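The branching on `use_torch` that used to duplicate every input conversion is replaced above by the `_prepare_onnx_inputs` / `_prepare_onnx_outputs` helpers shared with `ORTModel`; their bodies are not part of this diff. The sketch below shows the kind of conversion such helpers are expected to perform with the new `input_dtypes` / `output_dtypes` attributes, written as methods of a class that exposes `input_names`, `input_dtypes`, `output_names`, and `device` (as `ORTModelPart` does). The dtype table and method bodies are assumptions, not the actual optimum implementation:

```python
import numpy as np
import torch

# Hypothetical mapping from ONNX Runtime type strings to numpy dtypes.
ORT_TO_NUMPY_DTYPE = {
    "tensor(float)": np.float32,
    "tensor(float16)": np.float16,
    "tensor(int64)": np.int64,
    "tensor(int32)": np.int32,
    "tensor(bool)": np.bool_,
}


def _prepare_onnx_inputs(self, use_torch: bool, **inputs):
    """Keep only the inputs the session declares and cast them to its dtypes."""
    onnx_inputs = {}
    for name in self.input_names:
        tensor = inputs.get(name)
        if tensor is None:
            continue
        array = tensor.cpu().detach().numpy() if use_torch else tensor
        expected = ORT_TO_NUMPY_DTYPE.get(self.input_dtypes[name])
        if expected is not None and array.dtype != expected:
            array = array.astype(expected)
        onnx_inputs[name] = array
    return onnx_inputs


def _prepare_onnx_outputs(self, use_torch: bool, *onnx_outputs):
    """Name the positional session outputs and move them back to torch if needed."""
    model_outputs = {}
    for name, idx in self.output_names.items():
        output = onnx_outputs[idx]
        model_outputs[name] = torch.from_numpy(output).to(self.device) if use_torch else output
    return model_outputs
```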
73 changes: 30 additions & 43 deletions optimum/onnxruntime/modeling_decoder.py
@@ -46,7 +46,7 @@
if check_if_transformers_greater("4.25.0"):
from transformers.generation import GenerationMixin
else:
from transformers.generation_utils import GenerationMixin
from transformers.generation_utils import GenerationMixin # type: ignore # noqa: F401


logger = logging.getLogger(__name__)
@@ -139,15 +139,16 @@ def __init__(

self.num_pkv = 2
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.key_value_input_names = [key for key in self.inputs_names if (".key" in key) or (".value" in key)]
self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)]
self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)]
self.use_cache = len(self.key_value_input_names) > 0

if generation_config is None:
generation_config = GenerationConfig.from_model_config(config)

self.generation_config = generation_config
self.onnx_paths = [self.model_path]
self.use_merged = "use_cache_branch" in self.inputs_names
self.use_merged = "use_cache_branch" in self.input_names
self.model_type = self.config.model_type

self.use_fp16 = False
@@ -160,7 +161,7 @@ def __init__(

# Reference: https://github.com/huggingface/optimum/pull/1381
model_type = config.model_type.replace("_", "-")
if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.inputs_names:
if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.input_names:
logger.warning(
f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although this input is required for batched generation for the architecture {model_type}. "
"We strongly encourage to re-export the model with optimum>=1.14 for position_ids and batched inference support."
@@ -202,7 +203,6 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)

inputs = {}
known_output_shapes = {}
use_cache_branch = None
loss = None
@@ -226,10 +226,10 @@
# I suspect the reason is the contiguous python list that messes something up?
model_inputs = [input_ids.contiguous()]

if "attention_mask" in self.inputs_names:
if "attention_mask" in self.input_names:
model_inputs.append(attention_mask)

if "position_ids" in self.inputs_names:
if "position_ids" in self.input_names:
if position_ids is None:
raise ValueError("position_ids was not passed but is a required input for this ONNX model.")
model_inputs.append(position_ids.contiguous())
@@ -240,12 +240,11 @@
if use_cache_branch is not None:
model_inputs.append(use_cache_branch)

if "labels" in self.inputs_names:
if "labels" in self.input_names:
model_inputs.append(labels)
known_output_shapes.update({"loss": []})

io_binding, output_shapes, output_buffers = self._prepare_io_binding(
self.model,
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
*model_inputs,
known_output_shapes=known_output_shapes,
ordered_input_names=self._ordered_input_names,
@@ -259,53 +258,41 @@
io_binding.synchronize_outputs()

if self.use_cache:
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2)
past_key_values = ()
for name in self.key_value_output_names:
past_key_values += (output_buffers[name].view(output_shapes[name]),)
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2 for the self-attention)
past_key_values = tuple(
output_buffers[name].view(output_shapes[name]) for name in self.key_value_output_names
)

logits = output_buffers["logits"].view(output_shapes["logits"])

if "loss" in self.output_names:
loss = output_buffers["loss"].view(output_shapes["loss"])
else:
inputs["input_ids"] = input_ids.cpu().detach().numpy() if use_torch else input_ids

if "attention_mask" in self.inputs_names:
inputs["attention_mask"] = attention_mask.cpu().detach().numpy() if use_torch else attention_mask

if "labels" in self.inputs_names:
inputs["labels"] = labels.cpu().detach().numpy() if use_torch else labels

if "position_ids" in self.inputs_names:
if position_ids is None:
raise ValueError("position_ids was not passed but is a required input for this ONNX model.")
inputs["position_ids"] = position_ids.cpu().detach().numpy() if use_torch else position_ids

# Add the past_key_values to the decoder inputs
model_inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attention_mask": attention_mask,
"use_cache_branch": use_cache_branch,
"labels": labels,
}
if past_key_values is not None:
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
inputs[input_name] = past_key_value.cpu().detach().numpy() if use_torch else past_key_value
model_inputs.update(
zip(self.key_value_input_names, past_key_values),
)

if use_cache_branch is not None:
inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy() if use_torch else use_cache_branch
onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

outputs = self.model.run(None, inputs)
loss = model_outputs.get("loss", None)
logits = model_outputs["logits"]

if self.use_cache:
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 for the self-attention)
past_key_values = tuple(
torch.from_numpy(outputs[self.output_names[key]]).to(self.device)
for key in self.key_value_output_names
)

logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device)
if "loss" in self.output_names:
loss = torch.from_numpy(outputs[self.output_names["loss"]]).to(self.device)
past_key_values = tuple(model_outputs[output_name] for output_name in self.key_value_output_names)

if self.use_cache and self.model_type != "gpt_bigcode":
# Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and
# per decoder layer
# Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and per decoder layer
past_key_values = tuple(
past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
)
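At the usage level, the net effect of routing inputs and outputs through these helpers is that callers keep passing plain torch tensors and the dtype casting happens inside `forward`, with logits and past key/values converted back to torch on the model's device. A hedged end-to-end illustration; the checkpoint name is chosen for the example rather than taken from the commit:

```python
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM

# Illustrative checkpoint: any causal LM already exported to ONNX would do.
model_id = "optimum/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ORTModelForCausalLM.from_pretrained(model_id, use_cache=True)

# Torch tensors go in; the forward pass casts them to the numpy dtypes the
# InferenceSession declares before calling session.run.
inputs = tokenizer("Testing the ONNX decoder on Windows", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```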
