Merge branch 'main' into LayoutLMv3-TFLite-conversion-support
salmanmaq committed Jun 24, 2024
2 parents aa00dd6 + aad4b8b commit b24bda2
Showing 22 changed files with 513 additions and 549 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test_onnxruntime.yml
@@ -50,6 +50,8 @@ jobs:
pytest onnxruntime -m "run_in_series" --durations=0 -vvvv -s
- name: Test with pytest (in parallel)
env:
FXMARTYCLONE_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
working-directory: tests
run: |
pytest onnxruntime -m "not run_in_series" --durations=0 -vvvv -s -n auto
17 changes: 17 additions & 0 deletions .github/workflows/trufflehog.yml
@@ -0,0 +1,17 @@
on:
push:

name: Secret Leaks

jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main


3 changes: 1 addition & 2 deletions optimum/commands/__init__.py
@@ -15,5 +15,4 @@
from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand
from .env import EnvironmentCommand
from .export import ExportCommand, ONNXExportCommand, TFLiteExportCommand
from .onnxruntime import ONNXRuntimeCommand, ONNXRuntimeOptimizeCommand, ONNXRuntimeQuantizeCommand
from .optimum_cli import register_optimum_cli_subcommand
from .optimum_cli import optimum_cli_subcommand
57 changes: 51 additions & 6 deletions optimum/commands/optimum_cli.py
@@ -17,16 +17,57 @@
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union

from ..subpackages import load_subpackages
from ..utils import logging
from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand
from .env import EnvironmentCommand
from .export import ExportCommand
from .onnxruntime import ONNXRuntimeCommand


logger = logging.get_logger()

OPTIMUM_CLI_SUBCOMMANDS = [ExportCommand, EnvironmentCommand, ONNXRuntimeCommand]
# The table below contains the optimum-cli root subcommands provided by the optimum package
OPTIMUM_CLI_ROOT_SUBCOMMANDS = [ExportCommand, EnvironmentCommand]

# The table below is dynamically populated when loading subpackages
_OPTIMUM_CLI_SUBCOMMANDS = []


def optimum_cli_subcommand(parent_command: Optional[Type[BaseOptimumCLICommand]] = None):
"""
A decorator to declare optimum-cli subcommands.
The declaration of an optimum-cli subcommand looks like this:
```
@optimum_cli_subcommand()
class MySubcommand(BaseOptimumCLICommand):
<implementation>
```
or
```
@optimum_cli_subcommand(ExportCommand)
class MySubcommand(BaseOptimumCLICommand):
<implementation>
```
Args:
parent_command: (`Optional[Type[BaseOptimumCLICommand]]`):
The class of the parent command or None if this is a top-level command. Defaults to None.
"""

if parent_command is not None and not issubclass(parent_command, BaseOptimumCLICommand):
raise ValueError(f"The parent command {parent_command} must be a subclass of BaseOptimumCLICommand")

def wrapper(subcommand):
if not issubclass(subcommand, BaseOptimumCLICommand):
raise ValueError(f"The subcommand {subcommand} must be a subclass of BaseOptimumCLICommand")
_OPTIMUM_CLI_SUBCOMMANDS.append((subcommand, parent_command))

return wrapper


def resolve_command_to_command_instance(
@@ -137,15 +178,19 @@ def main():
root = RootOptimumCLICommand("Optimum CLI tool", usage="optimum-cli")
parser = root.parser

for subcommand_cls in OPTIMUM_CLI_SUBCOMMANDS:
for subcommand_cls in OPTIMUM_CLI_ROOT_SUBCOMMANDS:
register_optimum_cli_subcommand(subcommand_cls, parent_command=root)

commands_in_register = dynamic_load_commands_in_register()
# Load subpackages to give them a chance to declare their own subcommands
load_subpackages()

# Register subcommands declared by the subpackages or found in the register files under commands/register
commands_to_register = _OPTIMUM_CLI_SUBCOMMANDS + dynamic_load_commands_in_register()
command2command_instance = resolve_command_to_command_instance(
root, [parent_command_cls for _, parent_command_cls in commands_in_register if parent_command_cls is not None]
root, [parent_command_cls for _, parent_command_cls in commands_to_register if parent_command_cls is not None]
)

for command_or_command_info, parent_command in commands_in_register:
for command_or_command_info, parent_command in commands_to_register:
if parent_command is None:
parent_command_instance = root
else:
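The `optimum_cli_subcommand` decorator added above lets subpackages declare their own CLI commands instead of hard-coding them in `OPTIMUM_CLI_SUBCOMMANDS`. A minimal sketch of how a subpackage might use it, assuming a hypothetical `ExportMyBackendCommand` (the `COMMAND = CommandInfo(...)` pattern mirrors the existing optimum subcommands):

```python
# Hypothetical subpackage module; the decorator and base classes come from optimum.commands.
from optimum.commands import BaseOptimumCLICommand, CommandInfo, ExportCommand
from optimum.commands.optimum_cli import optimum_cli_subcommand


@optimum_cli_subcommand(ExportCommand)  # registered as a child of `optimum-cli export`
class ExportMyBackendCommand(BaseOptimumCLICommand):
    COMMAND = CommandInfo(name="mybackend", help="Export models to a (hypothetical) mybackend format.")

    def run(self):
        # Real export logic would go here; the decorator only handles registration.
        ...
```

Because `load_subpackages()` runs before the registration loop in `main()`, importing such a module at subpackage load time is enough for the command to show up in the CLI.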
41 changes: 7 additions & 34 deletions optimum/gptq/data.py
@@ -182,40 +182,11 @@ def get_c4_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train")


def get_ptb(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
if split == "train":
data = load_dataset("ptb_text_only", "penn_treebank", split="train")
elif split == "validation":
data = load_dataset("ptb_text_only", "penn_treebank", split="validation")

enc = tokenizer(" ".join(data["sentence"]), return_tensors="pt")

dataset = []
for _ in range(nsamples):
i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = enc.input_ids[:, i:j]
attention_mask = torch.ones_like(inp)
dataset.append({"input_ids": inp, "attention_mask": attention_mask})

return dataset
raise RuntimeError("Loading the `ptb` dataset was deprecated")


def get_ptb_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"):
if split == "train":
data = load_dataset("ptb_text_only", "penn_treebank", split="train")
elif split == "validation":
data = load_dataset("ptb_text_only", "penn_treebank", split="test")

enc = tokenizer(" ".join(data["sentence"]), return_tensors="pt")

dataset = []
for _ in range(nsamples):
i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = enc.input_ids[:, i:j]
attention_mask = torch.ones_like(inp)
dataset.append({"input_ids": inp, "attention_mask": attention_mask})
return dataset
raise RuntimeError("Loading the `ptb` dataset was deprecated")


def get_dataset(
@@ -226,7 +197,7 @@ def get_dataset(
Args:
dataset_name (`str`):
Dataset name. Available options are `['wikitext2', 'c4', 'ptb', 'c4-new', 'ptb_new']`.
Dataset name. Available options are `['wikitext2', 'c4', 'c4-new']`.
tokenizer (`Any`):
Tokenizer of the model
nsamples (`int`, defaults to `128`):
@@ -247,11 +218,13 @@
"wikitext2": get_wikitext2,
"c4": get_c4,
"c4-new": get_c4_new,
"ptb": get_ptb,
"ptb-new": get_ptb_new,
}
if split not in ["train", "validation"]:
raise ValueError(f"The split need to be 'train' or 'validation' but found {split}")
if dataset_name in {"ptb", "ptb-new"}:
raise ValueError(
f"{dataset_name} dataset was deprecated, only the following dataset are supported : {list(get_dataset_map)}"
)
if dataset_name not in get_dataset_map:
raise ValueError(f"Expected a value in {list(get_dataset_map.keys())} but found {dataset_name}")
get_dataset_fn = get_dataset_map[dataset_name]
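With the `ptb` loaders reduced to a `RuntimeError` and `get_dataset` rejecting those names up front, only `wikitext2`, `c4` and `c4-new` remain usable for GPTQ calibration. A rough usage sketch, assuming any Hugging Face tokenizer (`gpt2` is just a stand-in):

```python
from transformers import AutoTokenizer

from optimum.gptq.data import get_dataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # arbitrary example tokenizer

# Supported calibration datasets after this change: wikitext2, c4, c4-new
calibration_data = get_dataset("wikitext2", tokenizer, nsamples=128, split="train")
print(len(calibration_data), calibration_data[0]["input_ids"].shape)

try:
    get_dataset("ptb", tokenizer)  # deprecated: now raises ValueError before any download
except ValueError as err:
    print(err)
```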
10 changes: 8 additions & 2 deletions optimum/gptq/quantizer.py
@@ -432,7 +432,10 @@ def store_input_hook(_, input, *args):
for data in dataset:
for k, v in data.items():
# put the data on gpu, we won't put them back to cpu
data[k] = v.to(0)
if not has_device_map or device.type == "cpu":
data[k] = v.to(0)
else:
data[k] = v.to(device)
try:
model(**data)
except ValueError:
@@ -458,7 +461,10 @@ def store_input_hook(_, input, *args):
for data in dataset:
for k, v in data.items():
# put the data on gpu, we won't put them back to cpu
data[k] = v.to(0)
if not has_device_map or device.type == "cpu":
data[k] = v.to(0)
else:
data[k] = v.to(device)
try:
model(**data)
except ValueError:
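Both hunks above apply the same placement rule to the calibration batches. A standalone sketch of that rule, with a made-up helper name purely for illustration:

```python
import torch


def place_calibration_batch(data: dict, device: torch.device, has_device_map: bool) -> dict:
    """Mirror of the rule used in the quantizer: without a device_map (or when the target
    device is CPU) keep the previous behaviour of moving tensors to GPU 0; otherwise
    follow the device the model was actually dispatched to."""
    for name, tensor in data.items():
        if not has_device_map or device.type == "cpu":
            data[name] = tensor.to(0)  # previous behaviour: first CUDA device
        else:
            data[name] = tensor.to(device)
    return data
```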
124 changes: 30 additions & 94 deletions optimum/onnxruntime/base.py
@@ -14,7 +14,7 @@
"""Defines the base classes that are used to perform inference with ONNX Runtime of Transformers models."""

from abc import abstractmethod
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, Union
from typing import Dict, Optional, Set, Tuple, Union

import numpy as np
import torch
@@ -24,22 +24,22 @@

from ..utils import NormalizedConfigManager
from ..utils.logging import warn_once
from .modeling_ort import ORTModel
from .utils import get_ordered_input_names, logging


logger = logging.get_logger(__name__)


if TYPE_CHECKING:
from .modeling_ort import ORTModel


class ORTModelPart:
"""
For multi-file ONNX models, such as encoder-decoder models, represents a part of the model.
It has its own `onnxruntime.InferenceSession`, and can perform a forward pass.
"""

_prepare_onnx_inputs = ORTModel._prepare_onnx_inputs
_prepare_onnx_outputs = ORTModel._prepare_onnx_outputs

def __init__(
self,
session: InferenceSession,
@@ -53,6 +53,8 @@ def __init__(
self.main_input_name = self.parent_model.main_input_name
self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}
self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()}
self.output_dtypes = {output_key.name: output_key.type for output_key in session.get_outputs()}

self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward)

@@ -98,25 +100,13 @@ def forward(

last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"])
else:
if use_torch:
onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()}

# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names:
onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy()
else:
onnx_inputs = {"input_ids": input_ids}
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names:
onnx_inputs["attention_mask"] = attention_mask
onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

# Run inference
outputs = self.session.run(None, onnx_inputs)

last_hidden_state = outputs[self.output_names["last_hidden_state"]]
if use_torch:
last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device)
last_hidden_state = model_outputs["last_hidden_state"]

return BaseModelOutput(last_hidden_state=last_hidden_state)

@@ -350,83 +340,29 @@
else:
raise ValueError("Unsupported num_pkv")
else:
if use_torch:
onnx_inputs = {
"input_ids": input_ids.cpu().detach().numpy(),
}

# Add the encoder_hidden_states inputs when needed
if "encoder_hidden_states" in self.input_names:
onnx_inputs["encoder_hidden_states"] = encoder_hidden_states.cpu().detach().numpy()

# Add the decoder_attention_mask inputs when needed
if "decoder_attention_mask" in self.input_names:
onnx_inputs["decoder_attention_mask"] = decoder_attention_mask.cpu().detach().numpy()

# Add the encoder_attention_mask inputs when needed
if "encoder_attention_mask" in self.input_names:
onnx_inputs["encoder_attention_mask"] = encoder_attention_mask.cpu().detach().numpy()

if past_key_values is not None:
# Add the past_key_values to the decoder inputs
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()

if "labels" in self.input_names:
# TODO: Any preprocessing like `self._shift_right(labels)`?
onnx_inputs["labels"] = labels.cpu().detach().numpy()

if self.parent_model.use_merged is True:
onnx_inputs["use_cache_branch"] = use_cache_branch_tensor.cpu().detach().numpy()
else:
onnx_inputs = {
"input_ids": input_ids,
}

# Add the encoder_hidden_states inputs when needed
if "encoder_hidden_states" in self.input_names:
onnx_inputs["encoder_hidden_states"] = encoder_hidden_states

# Add the decoder_attention_mask inputs when needed
if "decoder_attention_mask" in self.input_names:
onnx_inputs["decoder_attention_mask"] = decoder_attention_mask

# Add the encoder_attention_mask inputs when needed
if "encoder_attention_mask" in self.input_names:
onnx_inputs["encoder_attention_mask"] = encoder_attention_mask

if past_key_values is not None:
# Add the past_key_values to the decoder inputs
for input_name, past_key_value in zip(self.key_value_input_names, past_key_values):
onnx_inputs[input_name] = past_key_value

if "labels" in self.input_names:
# TODO: Any preprocessing like `self._shift_right(labels)`?
onnx_inputs["labels"] = labels

if self.parent_model.use_merged is True:
onnx_inputs["use_cache_branch"] = use_cache_branch_tensor
model_inputs = {
"input_ids": input_ids,
"encoder_hidden_states": encoder_hidden_states,
"decoder_attention_mask": decoder_attention_mask,
"encoder_attention_mask": encoder_attention_mask,
"use_cache_branch": use_cache_branch_tensor,
"labels": labels,
}
if past_key_values is not None:
model_inputs.update(zip(self.key_value_input_names, past_key_values))

# Run inference
outputs = self.session.run(None, onnx_inputs)
onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

# TODO: using two loops here is probably unefficient
# TODO: using a new variable out_past_key_values is memory inefficient,
# past_key_values is not used anymore at this point
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
# self-attention layer and 2 to the cross-attention layer)
out_past_key_values = tuple(
torch.from_numpy(outputs[self.output_names[key]]).to(self.device)
for key in self.key_value_output_names
)

logits = outputs[self.output_names["logits"]]
if use_torch:
logits = torch.from_numpy(logits).to(self.device)
out_past_key_values = tuple(model_outputs[output_name] for output_name in self.key_value_output_names)

loss = None
if "loss" in self.output_names:
loss = outputs[self.output_names["loss"]]
if use_torch:
loss = torch.from_numpy(loss).to(self.device)
loss = model_outputs.get("loss", None)
logits = model_outputs["logits"]

# TODO: this is extremely ugly and unreadable. What if cross-attention k/v change?
# Tuple of tuple of length `n_layers`, with each tuple of length equal to:
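The rewrite above swaps the duplicated torch/numpy branches for the `_prepare_onnx_inputs` / `_prepare_onnx_outputs` helpers borrowed from `ORTModel`. A rough sketch of what such helpers are expected to do, given the `input_names`/`output_names` mappings built in `__init__` (an assumption for illustration, not the literal `ORTModel` implementation):

```python
import torch


def _prepare_onnx_inputs(self, use_torch: bool, **inputs):
    """Keep only the inputs declared by the InferenceSession and convert torch tensors to numpy."""
    onnx_inputs = {}
    for name in self.input_names:
        value = inputs.get(name)
        if value is None:
            continue
        onnx_inputs[name] = value.cpu().detach().numpy() if use_torch else value
    return onnx_inputs


def _prepare_onnx_outputs(self, use_torch: bool, *onnx_outputs):
    """Map the positional session outputs back to names, as torch tensors on self.device when needed."""
    model_outputs = {}
    for name, idx in self.output_names.items():
        output = onnx_outputs[idx]
        model_outputs[name] = torch.from_numpy(output).to(self.device) if use_torch else output
    return model_outputs
```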
(Diffs for the remaining changed files did not load and are not shown.)
