From ac951ca0ae49983b1515b84342fea15c3b7ec35c Mon Sep 17 00:00:00 2001 From: Zach Deane-Mayer <581590+zachmayer@users.noreply.github.com> Date: Wed, 5 Jun 2024 02:18:42 -0400 Subject: [PATCH 1/8] ORTOptimizer for the model type Segformer (#1820) * add segformer * black * make format * decoder_hidden_size not a list * tests pass now * use max * use zero --------- Co-authored-by: Zach Deane-Mayer --- optimum/onnxruntime/modeling_ort.py | 11 ++++++++--- optimum/onnxruntime/utils.py | 1 + optimum/utils/normalized_config.py | 15 ++++++++++++++- tests/onnxruntime/test_optimization.py | 2 ++ 4 files changed, 25 insertions(+), 4 deletions(-) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index eb38a7fef1..b65e1d3b29 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -1746,13 +1746,18 @@ class ORTModelForSemanticSegmentation(ORTModel): checkpoint="optimum/segformer-b0-finetuned-ade-512-512", ) ) - def forward(self, **kwargs): - use_torch = isinstance(next(iter(kwargs.values())), torch.Tensor) + def forward( + self, + pixel_values: Union[torch.Tensor, np.ndarray], + **kwargs, + ): + use_torch = isinstance(pixel_values, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) if self.device.type == "cuda" and self.use_io_binding: io_binding = IOBindingHelper.prepare_io_binding( self, + pixel_values, **kwargs, ordered_input_names=self._ordered_input_names, ) @@ -1769,7 +1774,7 @@ def forward(self, **kwargs): # converts output to namedtuple for pipelines post-processing return SemanticSegmenterOutput(logits=outputs["logits"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch=use_torch, **kwargs) + onnx_inputs = self._prepare_onnx_inputs(use_torch=use_torch, pixel_values=pixel_values, **kwargs) # run inference onnx_outputs = self.model.run(None, onnx_inputs) diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 0e1da447a6..37d0feefcc 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -128,6 +128,7 @@ class ORTConfigManager: "nystromformer": "bert", "pegasus": "bert", "roberta": "bert", + "segformer": "vit", "t5": "bert", "vit": "vit", "whisper": "bart", diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 682f70e3ca..81207b7649 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -102,6 +102,19 @@ class NormalizedVisionConfig(NormalizedConfig): INPUT_SIZE = "input_size" +class NormalizedSegformerConfig(NormalizedVisionConfig): + NUM_ATTENTION_HEADS = "num_attention_heads" + HIDDEN_SIZE = "hidden_sizes" + + # If the attribute is a list, return 0 + # 0 means let the optimizer infer the correct value based on the model graph + def __getattr__(self, attr_name): + attr_value = super().__getattr__(attr_name) + if isinstance(attr_value, list): + attr_value = 0 + return attr_value + + class NormalizedTextAndVisionConfig(NormalizedTextConfig, NormalizedVisionConfig): TEXT_CONFIG = None VISION_CONFIG = None @@ -203,7 +216,6 @@ class NormalizedConfigManager: 'owlvit', 'perceiver', 'roformer', - 'segformer', 'squeezebert', 'table-transformer', """ @@ -258,6 +270,7 @@ class NormalizedConfigManager: "regnet": NormalizedVisionConfig, "resnet": NormalizedVisionConfig, "roberta": NormalizedTextConfig, + "segformer": NormalizedSegformerConfig, "speech-to-text": SpeechToTextLikeNormalizedTextConfig, "splinter": NormalizedTextConfig, "t5": T5LikeNormalizedTextConfig, diff --git a/tests/onnxruntime/test_optimization.py b/tests/onnxruntime/test_optimization.py index c9cadbaa82..82109fcd11 100644 --- a/tests/onnxruntime/test_optimization.py +++ b/tests/onnxruntime/test_optimization.py @@ -36,6 +36,7 @@ AutoOptimizationConfig, ORTConfig, ORTModelForImageClassification, + ORTModelForSemanticSegmentation, ORTModelForSequenceClassification, ORTOptimizer, ) @@ -171,6 +172,7 @@ def test_compare_original_seq2seq_model_with_optimized_model(self, model_cls, mo # Contribution note: Please add test models in alphabetical order. Find test models here: https://huggingface.co/hf-internal-testing. SUPPORTED_IMAGE_ARCHITECTURES_WITH_MODEL_ID = ( + (ORTModelForSemanticSegmentation, "hf-internal-testing/tiny-random-segformer"), (ORTModelForImageClassification, "hf-internal-testing/tiny-random-vit"), ) From 113b645dc7d0b7710803f23ffbf937ce6461ed1e Mon Sep 17 00:00:00 2001 From: GoldenTeethCN Date: Thu, 6 Jun 2024 16:42:52 +0800 Subject: [PATCH 2/8] fix: device consistence (#1891) * fix: device consistence * style: make style on ./optimum/gptq/quantizer.py --- optimum/gptq/quantizer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py index 2c2c9d7e71..902af87bbb 100644 --- a/optimum/gptq/quantizer.py +++ b/optimum/gptq/quantizer.py @@ -432,7 +432,10 @@ def store_input_hook(_, input, *args): for data in dataset: for k, v in data.items(): # put the data on gpu, we won't put them back to cpu - data[k] = v.to(0) + if not has_device_map or device.type == "cpu": + data[k] = v.to(0) + else: + data[k] = v.to(device) try: model(**data) except ValueError: @@ -458,7 +461,10 @@ def store_input_hook(_, input, *args): for data in dataset: for k, v in data.items(): # put the data on gpu, we won't put them back to cpu - data[k] = v.to(0) + if not has_device_map or device.type == "cpu": + data[k] = v.to(0) + else: + data[k] = v.to(device) try: model(**data) except ValueError: From f33f2f1d84f5da1e347d64d90a393e0e02a9ac5a Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Mon, 10 Jun 2024 10:06:44 +0200 Subject: [PATCH 3/8] Allow optimum to discover and load subpackages (#1894) As an alternative to directly adding their commands in a register.py file under the root optimum directory, this adds a decorator to declare a subcommand that can be used by subpackages when they are loaded. This will fix the issue of subcommands 'disappearing' when optimum is upgraded without reinstalling the subpackage. The onnxruntime commands are moved into a subpackage loader directory. This subpackage directory is only loaded (and its commands added) when the onnxruntime is available. This avoids wrongly indicating that the onnxruntime commands are available when the package is actually not installed. --- optimum/commands/__init__.py | 3 +- optimum/commands/optimum_cli.py | 57 +++++++++++-- optimum/onnxruntime/subpackage/__init__.py | 1 + .../subpackage/commands}/__init__.py | 2 - .../subpackage/commands}/base.py | 4 +- .../subpackage/commands}/optimize.py | 4 +- .../subpackage/commands}/quantize.py | 6 +- optimum/subpackages.py | 81 +++++++++++++++++++ 8 files changed, 142 insertions(+), 16 deletions(-) create mode 100644 optimum/onnxruntime/subpackage/__init__.py rename optimum/{commands/onnxruntime => onnxruntime/subpackage/commands}/__init__.py (87%) rename optimum/{commands/onnxruntime => onnxruntime/subpackage/commands}/base.py (91%) rename optimum/{commands/onnxruntime => onnxruntime/subpackage/commands}/optimize.py (96%) rename optimum/{commands/onnxruntime => onnxruntime/subpackage/commands}/quantize.py (95%) create mode 100644 optimum/subpackages.py diff --git a/optimum/commands/__init__.py b/optimum/commands/__init__.py index 540ea4dd86..8a2a276d1c 100644 --- a/optimum/commands/__init__.py +++ b/optimum/commands/__init__.py @@ -15,5 +15,4 @@ from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand from .env import EnvironmentCommand from .export import ExportCommand, ONNXExportCommand, TFLiteExportCommand -from .onnxruntime import ONNXRuntimeCommand, ONNXRuntimeOptimizeCommand, ONNXRuntimeQuantizeCommand -from .optimum_cli import register_optimum_cli_subcommand +from .optimum_cli import optimum_cli_subcommand diff --git a/optimum/commands/optimum_cli.py b/optimum/commands/optimum_cli.py index 4bae9bb5f8..64a7075c6c 100644 --- a/optimum/commands/optimum_cli.py +++ b/optimum/commands/optimum_cli.py @@ -17,16 +17,57 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Type, Union +from ..subpackages import load_subpackages from ..utils import logging from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand from .env import EnvironmentCommand from .export import ExportCommand -from .onnxruntime import ONNXRuntimeCommand logger = logging.get_logger() -OPTIMUM_CLI_SUBCOMMANDS = [ExportCommand, EnvironmentCommand, ONNXRuntimeCommand] +# The table below contains the optimum-cli root subcommands provided by the optimum package +OPTIMUM_CLI_ROOT_SUBCOMMANDS = [ExportCommand, EnvironmentCommand] + +# The table below is dynamically populated when loading subpackages +_OPTIMUM_CLI_SUBCOMMANDS = [] + + +def optimum_cli_subcommand(parent_command: Optional[Type[BaseOptimumCLICommand]] = None): + """ + A decorator to declare optimum-cli subcommands. + + The declaration of an optimum-cli subcommand looks like this: + + ``` + @optimum_cli_subcommand() + class MySubcommand(BaseOptimumCLICommand): + + ``` + + or + + ``` + @optimum_cli_subcommand(ExportCommand) + class MySubcommand(BaseOptimumCLICommand): + + ``` + + Args: + parent_command: (`Optional[Type[BaseOptimumCLICommand]]`): + The class of the parent command or None if this is a top-level command. Defaults to None. + + """ + + if parent_command is not None and not issubclass(parent_command, BaseOptimumCLICommand): + raise ValueError(f"The parent command {parent_command} must be a subclass of BaseOptimumCLICommand") + + def wrapper(subcommand): + if not issubclass(subcommand, BaseOptimumCLICommand): + raise ValueError(f"The subcommand {subcommand} must be a subclass of BaseOptimumCLICommand") + _OPTIMUM_CLI_SUBCOMMANDS.append((subcommand, parent_command)) + + return wrapper def resolve_command_to_command_instance( @@ -137,15 +178,19 @@ def main(): root = RootOptimumCLICommand("Optimum CLI tool", usage="optimum-cli") parser = root.parser - for subcommand_cls in OPTIMUM_CLI_SUBCOMMANDS: + for subcommand_cls in OPTIMUM_CLI_ROOT_SUBCOMMANDS: register_optimum_cli_subcommand(subcommand_cls, parent_command=root) - commands_in_register = dynamic_load_commands_in_register() + # Load subpackages to give them a chance to declare their own subcommands + load_subpackages() + + # Register subcommands declared by the subpackages or found in the register files under commands/register + commands_to_register = _OPTIMUM_CLI_SUBCOMMANDS + dynamic_load_commands_in_register() command2command_instance = resolve_command_to_command_instance( - root, [parent_command_cls for _, parent_command_cls in commands_in_register if parent_command_cls is not None] + root, [parent_command_cls for _, parent_command_cls in commands_to_register if parent_command_cls is not None] ) - for command_or_command_info, parent_command in commands_in_register: + for command_or_command_info, parent_command in commands_to_register: if parent_command is None: parent_command_instance = root else: diff --git a/optimum/onnxruntime/subpackage/__init__.py b/optimum/onnxruntime/subpackage/__init__.py new file mode 100644 index 0000000000..7029af7132 --- /dev/null +++ b/optimum/onnxruntime/subpackage/__init__.py @@ -0,0 +1 @@ +from .commands import ONNXRuntimeCommand diff --git a/optimum/commands/onnxruntime/__init__.py b/optimum/onnxruntime/subpackage/commands/__init__.py similarity index 87% rename from optimum/commands/onnxruntime/__init__.py rename to optimum/onnxruntime/subpackage/commands/__init__.py index 1b9c24c3b2..44facf5ea5 100644 --- a/optimum/commands/onnxruntime/__init__.py +++ b/optimum/onnxruntime/subpackage/commands/__init__.py @@ -14,5 +14,3 @@ # limitations under the License. from .base import ONNXRuntimeCommand -from .optimize import ONNXRuntimeOptimizeCommand -from .quantize import ONNXRuntimeQuantizeCommand diff --git a/optimum/commands/onnxruntime/base.py b/optimum/onnxruntime/subpackage/commands/base.py similarity index 91% rename from optimum/commands/onnxruntime/base.py rename to optimum/onnxruntime/subpackage/commands/base.py index 53e3245ea4..df4414c19d 100644 --- a/optimum/commands/onnxruntime/base.py +++ b/optimum/onnxruntime/subpackage/commands/base.py @@ -14,11 +14,13 @@ # limitations under the License. """optimum.onnxruntime command-line interface base classes.""" -from .. import BaseOptimumCLICommand, CommandInfo +from optimum.commands import BaseOptimumCLICommand, CommandInfo, optimum_cli_subcommand + from .optimize import ONNXRuntimeOptimizeCommand from .quantize import ONNXRuntimeQuantizeCommand +@optimum_cli_subcommand() class ONNXRuntimeCommand(BaseOptimumCLICommand): COMMAND = CommandInfo( name="onnxruntime", diff --git a/optimum/commands/onnxruntime/optimize.py b/optimum/onnxruntime/subpackage/commands/optimize.py similarity index 96% rename from optimum/commands/onnxruntime/optimize.py rename to optimum/onnxruntime/subpackage/commands/optimize.py index 5890e0a07c..1dd82f0ee2 100644 --- a/optimum/commands/onnxruntime/optimize.py +++ b/optimum/onnxruntime/subpackage/commands/optimize.py @@ -75,8 +75,8 @@ def parse_args(parser: "ArgumentParser"): return parse_args_onnxruntime_optimize(parser) def run(self): - from ...onnxruntime.configuration import AutoOptimizationConfig, ORTConfig - from ...onnxruntime.optimization import ORTOptimizer + from ...configuration import AutoOptimizationConfig, ORTConfig + from ...optimization import ORTOptimizer if self.args.output == self.args.onnx_model: raise ValueError("The output directory must be different than the directory hosting the ONNX model.") diff --git a/optimum/commands/onnxruntime/quantize.py b/optimum/onnxruntime/subpackage/commands/quantize.py similarity index 95% rename from optimum/commands/onnxruntime/quantize.py rename to optimum/onnxruntime/subpackage/commands/quantize.py index 2613cb33ba..6f6d843cc7 100644 --- a/optimum/commands/onnxruntime/quantize.py +++ b/optimum/onnxruntime/subpackage/commands/quantize.py @@ -17,7 +17,7 @@ from pathlib import Path from typing import TYPE_CHECKING -from .. import BaseOptimumCLICommand +from optimum.commands import BaseOptimumCLICommand if TYPE_CHECKING: @@ -69,8 +69,8 @@ def parse_args(parser: "ArgumentParser"): return parse_args_onnxruntime_quantize(parser) def run(self): - from ...onnxruntime.configuration import AutoQuantizationConfig, ORTConfig - from ...onnxruntime.quantization import ORTQuantizer + from ...configuration import AutoQuantizationConfig, ORTConfig + from ...quantization import ORTQuantizer if self.args.output == self.args.onnx_model: raise ValueError("The output directory must be different than the directory hosting the ONNX model.") diff --git a/optimum/subpackages.py b/optimum/subpackages.py new file mode 100644 index 0000000000..8729581521 --- /dev/null +++ b/optimum/subpackages.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import logging +import sys + + +if sys.version_info >= (3, 8): + from importlib import metadata as importlib_metadata +else: + import importlib_metadata +from importlib.util import find_spec, module_from_spec + +from .utils import is_onnxruntime_available + + +logger = logging.getLogger(__name__) + + +def load_namespace_modules(namespace: str, module: str): + """Load modules with a specific name inside a namespace + + This method operates on namespace packages: + https://packaging.python.org/en/latest/guides/packaging-namespace-packages/ + + For each package inside the specified `namespace`, it looks for the specified `module` and loads it. + + Args: + namespace (`str`): + The namespace containing modules to be loaded. + module (`str`): + The name of the module to load in each namespace package. + """ + for dist in importlib_metadata.distributions(): + dist_name = dist.metadata["Name"] + if not dist_name.startswith(f"{namespace}-"): + continue + package_import_name = dist_name.replace("-", ".") + module_import_name = f"{package_import_name}.{module}" + if module_import_name in sys.modules: + # Module already loaded + continue + backend_spec = find_spec(module_import_name) + if backend_spec is None: + continue + try: + imported_module = module_from_spec(backend_spec) + sys.modules[module_import_name] = imported_module + backend_spec.loader.exec_module(imported_module) + logger.debug(f"Successfully loaded {module_import_name}") + except Exception as e: + logger.error(f"An exception occured while loading {module_import_name}: {e}.") + + +def load_subpackages(): + """Load optimum subpackages + + This method goes through packages inside the `optimum` namespace and loads the `subpackage` module if it exists. + + This module is then in charge of registering the subpackage commands. + """ + SUBPACKAGE_LOADER = "subpackage" + load_namespace_modules("optimum", SUBPACKAGE_LOADER) + + # Load subpackages from internal modules not explicitly defined as namespace packages + loader_name = "." + SUBPACKAGE_LOADER + if is_onnxruntime_available(): + importlib.import_module(loader_name, package="optimum.onnxruntime") From 35f636707f18d9c3f996ee31a8d32515424b94af Mon Sep 17 00:00:00 2001 From: Luc Georges Date: Mon, 10 Jun 2024 10:07:16 +0200 Subject: [PATCH 4/8] feat(ci): add trufflehog secrets detector (#1899) --- .github/workflows/trufflehog.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 .github/workflows/trufflehog.yml diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml new file mode 100644 index 0000000000..164b4f2f8f --- /dev/null +++ b/.github/workflows/trufflehog.yml @@ -0,0 +1,23 @@ +on: + push: + +name: Secret Leaks + +permissions: + contents: read + id-token: write + issues: write + pull-requests: write + +jobs: + trufflehog: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Secret Scanning + uses: trufflesecurity/trufflehog@main + + From db51410ae5ef4cbde7518cf01a997239dffbde1d Mon Sep 17 00:00:00 2001 From: Luc Georges Date: Mon, 10 Jun 2024 11:42:08 +0200 Subject: [PATCH 5/8] fix(ci): remove unnecessary permissions (#1904) --- .github/workflows/trufflehog.yml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml index 164b4f2f8f..c71afbbb45 100644 --- a/.github/workflows/trufflehog.yml +++ b/.github/workflows/trufflehog.yml @@ -3,12 +3,6 @@ on: name: Secret Leaks -permissions: - contents: read - id-token: write - issues: write - pull-requests: write - jobs: trufflehog: runs-on: ubuntu-latest From f4809307e409d5ce698364ad48b69d38e0c406e9 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 14 Jun 2024 16:33:33 +0200 Subject: [PATCH 6/8] Remove read token (#1903) * remove read token * rename var & use org model * style & remove token * fix failing tests on datasets release --- .github/workflows/test_onnxruntime.yml | 2 ++ optimum/utils/testing_utils.py | 3 --- tests/onnxruntime/test_modeling.py | 11 +++++++---- tests/utils/test_task_processors.py | 7 ++++++- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test_onnxruntime.yml b/.github/workflows/test_onnxruntime.yml index 4893b681a6..291a3b0833 100644 --- a/.github/workflows/test_onnxruntime.yml +++ b/.github/workflows/test_onnxruntime.yml @@ -50,6 +50,8 @@ jobs: pytest onnxruntime -m "run_in_series" --durations=0 -vvvv -s - name: Test with pytest (in parallel) + env: + FXMARTYCLONE_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }} working-directory: tests run: | pytest onnxruntime -m "not run_in_series" --durations=0 -vvvv -s -n auto diff --git a/optimum/utils/testing_utils.py b/optimum/utils/testing_utils.py index f1c2f668e3..a7c2b8bb05 100644 --- a/optimum/utils/testing_utils.py +++ b/optimum/utils/testing_utils.py @@ -36,9 +36,6 @@ # Used to test the hub USER = "__DUMMY_OPTIMUM_USER__" -# Not critical, only usable on the sandboxed CI instance. -TOKEN = "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" - def flatten_dict(dictionary: Dict): """ diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 3fe2c5e14d..7b2c8a66b9 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -938,11 +938,14 @@ def test_stable_diffusion_model_on_rocm_ep_str(self): self.assertListEqual(model.providers, ["ROCMExecutionProvider", "CPUExecutionProvider"]) def test_load_model_from_hub_private(self): - subprocess.run("huggingface-cli logout", shell=True) - # Read token of fxmartyclone (dummy user). - token = "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" + token = os.environ.get("HF_HUB_READ_TOKEN", None) - model = ORTModelForCustomTasks.from_pretrained("fxmartyclone/tiny-onnx-private-2", use_auth_token=token) + if token is None: + self.skipTest("Test requires a token for fxmartyclone in the environment variable `HF_HUB_READ_TOKEN`.") + + model = ORTModelForCustomTasks.from_pretrained( + "optimum-internal-testing/tiny-random-phi-private", use_auth_token=token + ) self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) diff --git a/tests/utils/test_task_processors.py b/tests/utils/test_task_processors.py index af89aec2b9..1656704807 100644 --- a/tests/utils/test_task_processors.py +++ b/tests/utils/test_task_processors.py @@ -50,7 +50,7 @@ "dataset_data_keys": {"question": "question", "context": "answer"}, }, "image-classification": { - "dataset_args": "mnist", + "dataset_args": "sasha/dog-food", "dataset_data_keys": {"image": "image"}, }, } @@ -232,6 +232,11 @@ def test_load_dataset_with_max_length(self): input_ids = dataset[0]["input_ids"] self.assertEqual(len(input_ids), max_length) + def test_load_default_dataset(self): + self.skipTest( + "Skipping so as not to execute conll2003 remote code (test would require trust_remote_code=True)" + ) + class QuestionAnsweringProcessorTest(TestCase, TaskProcessorTestBase): TASK_NAME = "question-answering" From 8b43dd2f9fa17c2e08520bf61d1bdc17b8115d69 Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Tue, 18 Jun 2024 16:14:50 +0200 Subject: [PATCH 7/8] Remove dataset with restrictive license (#1910) * rm dataset with restrictive license * format --- optimum/gptq/data.py | 41 ++++++--------------------------- tests/gptq/test_quantization.py | 2 +- 2 files changed, 8 insertions(+), 35 deletions(-) diff --git a/optimum/gptq/data.py b/optimum/gptq/data.py index 37a42714fc..b8734da478 100644 --- a/optimum/gptq/data.py +++ b/optimum/gptq/data.py @@ -182,40 +182,11 @@ def get_c4_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train") def get_ptb(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"): - if split == "train": - data = load_dataset("ptb_text_only", "penn_treebank", split="train") - elif split == "validation": - data = load_dataset("ptb_text_only", "penn_treebank", split="validation") - - enc = tokenizer(" ".join(data["sentence"]), return_tensors="pt") - - dataset = [] - for _ in range(nsamples): - i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = enc.input_ids[:, i:j] - attention_mask = torch.ones_like(inp) - dataset.append({"input_ids": inp, "attention_mask": attention_mask}) - - return dataset + raise RuntimeError("Loading the `ptb` dataset was deprecated") def get_ptb_new(tokenizer: Any, seqlen: int, nsamples: int, split: str = "train"): - if split == "train": - data = load_dataset("ptb_text_only", "penn_treebank", split="train") - elif split == "validation": - data = load_dataset("ptb_text_only", "penn_treebank", split="test") - - enc = tokenizer(" ".join(data["sentence"]), return_tensors="pt") - - dataset = [] - for _ in range(nsamples): - i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = enc.input_ids[:, i:j] - attention_mask = torch.ones_like(inp) - dataset.append({"input_ids": inp, "attention_mask": attention_mask}) - return dataset + raise RuntimeError("Loading the `ptb` dataset was deprecated") def get_dataset( @@ -226,7 +197,7 @@ def get_dataset( Args: dataset_name (`str`): - Dataset name. Available options are `['wikitext2', 'c4', 'ptb', 'c4-new', 'ptb_new']`. + Dataset name. Available options are `['wikitext2', 'c4', 'c4-new']`. tokenizer (`Any`): Tokenizer of the model nsamples (`int`, defaults to `128`): @@ -247,11 +218,13 @@ def get_dataset( "wikitext2": get_wikitext2, "c4": get_c4, "c4-new": get_c4_new, - "ptb": get_ptb, - "ptb-new": get_ptb_new, } if split not in ["train", "validation"]: raise ValueError(f"The split need to be 'train' or 'validation' but found {split}") + if dataset_name in {"ptb", "ptb-new"}: + raise ValueError( + f"{dataset_name} dataset was deprecated, only the following dataset are supported : {list(get_dataset_map)}" + ) if dataset_name not in get_dataset_map: raise ValueError(f"Expected a value in {list(get_dataset_map.keys())} but found {dataset_name}") get_dataset_fn = get_dataset_map[dataset_name] diff --git a/tests/gptq/test_quantization.py b/tests/gptq/test_quantization.py index 0c070f8c9e..5ed1619fde 100644 --- a/tests/gptq/test_quantization.py +++ b/tests/gptq/test_quantization.py @@ -394,7 +394,7 @@ class GPTQDataTest(unittest.TestCase): def setUp(self): self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True) - @parameterized.expand(["wikitext2", "c4", "ptb", "c4-new", "ptb-new"]) + @parameterized.expand(["wikitext2", "c4", "c4-new"]) def test_dataset(self, dataset): train_dataset = get_dataset( dataset, self.tokenizer, nsamples=self.NBSAMPLES, seqlen=self.SEQLEN, split="train" From aad4b8beff3194af2679f762e2097113943c9f07 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Mon, 24 Jun 2024 11:13:47 +0100 Subject: [PATCH 8/8] Fix Windows and onnx dtype compatibility (#1886) * fix pkv and audio * add t5 test * fix seq2seq * fix vision2seq tests as it seems to have had always outputed kv cache in torch format before * fix folder deletion on windows * fix temporary directory removal on windows * remove attention_mask creation as ORTModelForxxx's corresponding processors will create it * remove_directory utility function --- optimum/onnxruntime/base.py | 124 ++---- optimum/onnxruntime/modeling_decoder.py | 73 ++-- optimum/onnxruntime/modeling_ort.py | 515 +++++++++--------------- optimum/utils/testing_utils.py | 14 + tests/onnxruntime/test_modeling.py | 58 +-- 5 files changed, 299 insertions(+), 485 deletions(-) diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index bf9c80a86c..16461dce95 100644 --- a/optimum/onnxruntime/base.py +++ b/optimum/onnxruntime/base.py @@ -14,7 +14,7 @@ """Defines the base classes that are used to perform inference with ONNX Runtime of Transformers models.""" from abc import abstractmethod -from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, Union +from typing import Dict, Optional, Set, Tuple, Union import numpy as np import torch @@ -24,22 +24,22 @@ from ..utils import NormalizedConfigManager from ..utils.logging import warn_once +from .modeling_ort import ORTModel from .utils import get_ordered_input_names, logging logger = logging.get_logger(__name__) -if TYPE_CHECKING: - from .modeling_ort import ORTModel - - class ORTModelPart: """ For multi-file ONNX models, such as encoder-decoder models, represents a part of the model. It has its own `onnxruntime.InferenceSession`, and can perform a forward pass. """ + _prepare_onnx_inputs = ORTModel._prepare_onnx_inputs + _prepare_onnx_outputs = ORTModel._prepare_onnx_outputs + def __init__( self, session: InferenceSession, @@ -53,6 +53,8 @@ def __init__( self.main_input_name = self.parent_model.main_input_name self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} + self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()} + self.output_dtypes = {output_key.name: output_key.type for output_key in session.get_outputs()} self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward) @@ -98,25 +100,13 @@ def forward( last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - if use_torch: - onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()} - - # Add the attention_mask inputs when needed - if "attention_mask" in self.input_names: - onnx_inputs["attention_mask"] = attention_mask.cpu().detach().numpy() - else: - onnx_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} - # Add the attention_mask inputs when needed - if "attention_mask" in self.input_names: - onnx_inputs["attention_mask"] = attention_mask + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - # Run inference - outputs = self.session.run(None, onnx_inputs) - - last_hidden_state = outputs[self.output_names["last_hidden_state"]] - if use_torch: - last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device) + last_hidden_state = model_outputs["last_hidden_state"] return BaseModelOutput(last_hidden_state=last_hidden_state) @@ -350,83 +340,29 @@ def forward( else: raise ValueError("Unsupported num_pkv") else: - if use_torch: - onnx_inputs = { - "input_ids": input_ids.cpu().detach().numpy(), - } - - # Add the encoder_hidden_states inputs when needed - if "encoder_hidden_states" in self.input_names: - onnx_inputs["encoder_hidden_states"] = encoder_hidden_states.cpu().detach().numpy() - - # Add the decoder_attention_mask inputs when needed - if "decoder_attention_mask" in self.input_names: - onnx_inputs["decoder_attention_mask"] = decoder_attention_mask.cpu().detach().numpy() - - # Add the encoder_attention_mask inputs when needed - if "encoder_attention_mask" in self.input_names: - onnx_inputs["encoder_attention_mask"] = encoder_attention_mask.cpu().detach().numpy() - - if past_key_values is not None: - # Add the past_key_values to the decoder inputs - for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): - onnx_inputs[input_name] = past_key_value.cpu().detach().numpy() - - if "labels" in self.input_names: - # TODO: Any preprocessing like `self._shift_right(labels)`? - onnx_inputs["labels"] = labels.cpu().detach().numpy() - - if self.parent_model.use_merged is True: - onnx_inputs["use_cache_branch"] = use_cache_branch_tensor.cpu().detach().numpy() - else: - onnx_inputs = { - "input_ids": input_ids, - } - - # Add the encoder_hidden_states inputs when needed - if "encoder_hidden_states" in self.input_names: - onnx_inputs["encoder_hidden_states"] = encoder_hidden_states - - # Add the decoder_attention_mask inputs when needed - if "decoder_attention_mask" in self.input_names: - onnx_inputs["decoder_attention_mask"] = decoder_attention_mask - - # Add the encoder_attention_mask inputs when needed - if "encoder_attention_mask" in self.input_names: - onnx_inputs["encoder_attention_mask"] = encoder_attention_mask - - if past_key_values is not None: - # Add the past_key_values to the decoder inputs - for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): - onnx_inputs[input_name] = past_key_value - - if "labels" in self.input_names: - # TODO: Any preprocessing like `self._shift_right(labels)`? - onnx_inputs["labels"] = labels - - if self.parent_model.use_merged is True: - onnx_inputs["use_cache_branch"] = use_cache_branch_tensor + model_inputs = { + "input_ids": input_ids, + "encoder_hidden_states": encoder_hidden_states, + "decoder_attention_mask": decoder_attention_mask, + "encoder_attention_mask": encoder_attention_mask, + "use_cache_branch": use_cache_branch_tensor, + "labels": labels, + } + if past_key_values is not None: + model_inputs.update(zip(self.key_value_input_names, past_key_values)) - # Run inference - outputs = self.session.run(None, onnx_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - # TODO: using two loops here is probably unefficient + # TODO: using a new variable out_past_key_values is memory inefficient, + # past_key_values is not used anymore at this point # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the # self-attention layer and 2 to the cross-attention layer) - out_past_key_values = tuple( - torch.from_numpy(outputs[self.output_names[key]]).to(self.device) - for key in self.key_value_output_names - ) - - logits = outputs[self.output_names["logits"]] - if use_torch: - logits = torch.from_numpy(logits).to(self.device) + out_past_key_values = tuple(model_outputs[output_name] for output_name in self.key_value_output_names) - loss = None - if "loss" in self.output_names: - loss = outputs[self.output_names["loss"]] - if use_torch: - loss = torch.from_numpy(loss).to(self.device) + loss = model_outputs.get("loss", None) + logits = model_outputs["logits"] # TODO: this is extremely ugly and unreadable. What if cross-attention k/v change? # Tuple of tuple of length `n_layers`, with each tuple of length equal to: diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 2d9be2d757..5d4bbe184e 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -46,7 +46,7 @@ if check_if_transformers_greater("4.25.0"): from transformers.generation import GenerationMixin else: - from transformers.generation_utils import GenerationMixin + from transformers.generation_utils import GenerationMixin # type: ignore # noqa: F401 logger = logging.getLogger(__name__) @@ -139,15 +139,16 @@ def __init__( self.num_pkv = 2 self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) - self.key_value_input_names = [key for key in self.inputs_names if (".key" in key) or (".value" in key)] + self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)] self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)] self.use_cache = len(self.key_value_input_names) > 0 if generation_config is None: generation_config = GenerationConfig.from_model_config(config) + self.generation_config = generation_config self.onnx_paths = [self.model_path] - self.use_merged = "use_cache_branch" in self.inputs_names + self.use_merged = "use_cache_branch" in self.input_names self.model_type = self.config.model_type self.use_fp16 = False @@ -160,7 +161,7 @@ def __init__( # Reference: https://github.com/huggingface/optimum/pull/1381 model_type = config.model_type.replace("_", "-") - if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.inputs_names: + if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.input_names: logger.warning( f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although this input is required for batched generation for the architecture {model_type}. " "We strongly encourage to re-export the model with optimum>=1.14 for position_ids and batched inference support." @@ -202,7 +203,6 @@ def forward( use_torch = isinstance(input_ids, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) - inputs = {} known_output_shapes = {} use_cache_branch = None loss = None @@ -226,10 +226,10 @@ def forward( # I suspect the reason is the contiguous python list that messes something up? model_inputs = [input_ids.contiguous()] - if "attention_mask" in self.inputs_names: + if "attention_mask" in self.input_names: model_inputs.append(attention_mask) - if "position_ids" in self.inputs_names: + if "position_ids" in self.input_names: if position_ids is None: raise ValueError("position_ids was not passed but is a required input for this ONNX model.") model_inputs.append(position_ids.contiguous()) @@ -240,12 +240,11 @@ def forward( if use_cache_branch is not None: model_inputs.append(use_cache_branch) - if "labels" in self.inputs_names: + if "labels" in self.input_names: model_inputs.append(labels) known_output_shapes.update({"loss": []}) - io_binding, output_shapes, output_buffers = self._prepare_io_binding( - self.model, + io_binding, output_shapes, output_buffers = self.prepare_io_binding( *model_inputs, known_output_shapes=known_output_shapes, ordered_input_names=self._ordered_input_names, @@ -259,53 +258,41 @@ def forward( io_binding.synchronize_outputs() if self.use_cache: - # Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2) - past_key_values = () - for name in self.key_value_output_names: - past_key_values += (output_buffers[name].view(output_shapes[name]),) + # Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2 for the self-attention) + past_key_values = tuple( + output_buffers[name].view(output_shapes[name]) for name in self.key_value_output_names + ) logits = output_buffers["logits"].view(output_shapes["logits"]) if "loss" in self.output_names: loss = output_buffers["loss"].view(output_shapes["loss"]) else: - inputs["input_ids"] = input_ids.cpu().detach().numpy() if use_torch else input_ids - - if "attention_mask" in self.inputs_names: - inputs["attention_mask"] = attention_mask.cpu().detach().numpy() if use_torch else attention_mask - - if "labels" in self.inputs_names: - inputs["labels"] = labels.cpu().detach().numpy() if use_torch else labels - - if "position_ids" in self.inputs_names: - if position_ids is None: - raise ValueError("position_ids was not passed but is a required input for this ONNX model.") - inputs["position_ids"] = position_ids.cpu().detach().numpy() if use_torch else position_ids - - # Add the past_key_values to the decoder inputs + model_inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": attention_mask, + "use_cache_branch": use_cache_branch, + "labels": labels, + } if past_key_values is not None: - for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): - inputs[input_name] = past_key_value.cpu().detach().numpy() if use_torch else past_key_value + model_inputs.update( + zip(self.key_value_input_names, past_key_values), + ) - if use_cache_branch is not None: - inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy() if use_torch else use_cache_branch + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - outputs = self.model.run(None, inputs) + loss = model_outputs.get("loss", None) + logits = model_outputs["logits"] if self.use_cache: # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 for the self-attention) - past_key_values = tuple( - torch.from_numpy(outputs[self.output_names[key]]).to(self.device) - for key in self.key_value_output_names - ) - - logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device) - if "loss" in self.output_names: - loss = torch.from_numpy(outputs[self.output_names["loss"]]).to(self.device) + past_key_values = tuple(model_outputs[output_name] for output_name in self.key_value_output_names) if self.use_cache and self.model_type != "gpt_bigcode": - # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and - # per decoder layer + # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and per decoder layer past_key_values = tuple( past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv) ) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index b65e1d3b29..734c9b6551 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -267,10 +267,13 @@ def __init__( **kwargs, ) - self.inputs_names = {input_key.name: idx for idx, input_key in enumerate(model.get_inputs())} + self.input_names = {input_key.name: idx for idx, input_key in enumerate(model.get_inputs())} + self.input_dtypes = {input_key.name: input_key.type for input_key in model.get_inputs()} + self.output_names = {output_key.name: idx for idx, output_key in enumerate(model.get_outputs())} + self.output_dtypes = {output_key.name: output_key.type for output_key in model.get_outputs()} - self._ordered_input_names = get_ordered_input_names(self.inputs_names.keys(), func=self.forward) + self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward) # TODO: why do we make device a property since we are only access the value, and do not do any check when setting the value? @property @@ -736,6 +739,7 @@ def _output_shape_inference(self, axis_name: Union[str, int], dimensions: Dict[s # exception. return int(eval(" ".join(tokens))) + # TODO: this method is bloated with state arguments (that are accesible using self) why ? def _prepare_io_binding( self, model: ort.InferenceSession, @@ -833,9 +837,15 @@ def _prepare_io_binding( return io_binding, output_shapes, output_buffers - def prepare_io_binding(self, *model_inputs, ordered_input_names, known_output_shapes=None): + def prepare_io_binding( + self, *model_inputs, ordered_input_names, outputs_to_not_bind=None, known_output_shapes=None + ): return self._prepare_io_binding( - self.model, ordered_input_names=ordered_input_names, known_output_shapes=known_output_shapes, *model_inputs + self.model, + *model_inputs, + ordered_input_names=ordered_input_names, + known_output_shapes=known_output_shapes, + outputs_to_not_bind=outputs_to_not_bind, ) def raise_on_numpy_input_io_binding(self, use_torch: bool): @@ -852,6 +862,39 @@ def raise_on_numpy_input_io_binding(self, use_torch: bool): " with model.use_io_binding = False, or pass torch.Tensor inputs instead." ) + def _prepare_onnx_inputs( + self, use_torch: bool, **inputs: Union[torch.Tensor, np.ndarray] + ) -> Dict[str, np.ndarray]: + onnx_inputs = {} + + # converts pytorch inputs into numpy inputs for onnx + for input_name in self.input_names.keys(): + onnx_inputs[input_name] = inputs.pop(input_name) + + if use_torch: + onnx_inputs[input_name] = onnx_inputs[input_name].cpu().detach().numpy() + + if onnx_inputs[input_name].dtype != self.input_dtypes[input_name]: + onnx_inputs[input_name] = onnx_inputs[input_name].astype( + TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name]) + ) + + return onnx_inputs + + def _prepare_onnx_outputs( + self, use_torch: bool, *onnx_outputs: np.ndarray + ) -> Dict[str, Union[torch.Tensor, np.ndarray]]: + model_outputs = {} + + # converts onnxruntime outputs into tensor for standard outputs + for output_name, idx in self.output_names.items(): + model_outputs[output_name] = onnx_outputs[idx] + + if use_torch: + model_outputs[output_name] = torch.from_numpy(model_outputs[output_name]).to(self.device) + + return model_outputs + @staticmethod def _cached_file( model_path: Union[Path, str], @@ -970,9 +1013,6 @@ def forward( self.raise_on_numpy_input_io_binding(use_torch) if self.device.type == "cuda" and self.use_io_binding: - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - io_binding, output_shapes, output_buffers = self.prepare_io_binding( input_ids, attention_mask, @@ -985,35 +1025,21 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return BaseModelOutput( - last_hidden_state=output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) - ) + last_hidden_state = output_buffers["last_hidden_state"].view(output_shapes["last_hidden_state"]) else: - if use_torch: - input_ids = input_ids.cpu().detach().numpy() - if attention_mask is None: - attention_mask = np.ones_like(input_ids) - else: - attention_mask = attention_mask.cpu().detach().numpy() - if token_type_ids is not None: - token_type_ids = token_type_ids.cpu().detach().numpy() + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - onnx_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - if token_type_ids is not None: - onnx_inputs["token_type_ids"] = token_type_ids - - outputs = self.model.run(None, onnx_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - last_hidden_state = outputs[self.output_names["last_hidden_state"]] - if use_torch: - last_hidden_state = torch.from_numpy(last_hidden_state).to(self.device) + # TODO: why do we only return last_hidden_state? why not all outputs? + # that way, there will be less need for ORTModelForCustomTask in cases where + # we just want to extend model outputs with attentions, hidden_states, etc. + last_hidden_state = model_outputs["last_hidden_state"] - # converts output to namedtuple for pipelines post-processing - return BaseModelOutput(last_hidden_state=last_hidden_state) + # converts output to namedtuple for pipelines post-processing + return BaseModelOutput(last_hidden_state=last_hidden_state) @classmethod def _export( @@ -1144,32 +1170,18 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return MaskedLMOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - if use_torch: - input_ids = input_ids.cpu().detach().numpy() - attention_mask = attention_mask.cpu().detach().numpy() - if token_type_ids is not None: - token_type_ids = token_type_ids.cpu().detach().numpy() - - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - if token_type_ids is not None: - onnx_inputs["token_type_ids"] = token_type_ids - - # run inference - outputs = self.model.run(None, onnx_inputs) - logits = outputs[self.output_names["logits"]] + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - if use_torch: - logits = torch.from_numpy(logits).to(self.device) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + + logits = model_outputs["logits"] - # converts output to namedtuple for pipelines post-processing - return MaskedLMOutput(logits=logits) + # converts output to namedtuple for pipelines post-processing + return MaskedLMOutput(logits=logits) QUESTION_ANSWERING_EXAMPLE = r""" @@ -1247,37 +1259,21 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return QuestionAnsweringModelOutput( - start_logits=output_buffers["start_logits"].view(output_shapes["start_logits"]), - end_logits=output_buffers["end_logits"].view(output_shapes["end_logits"]), - ) + # TODO: this is the same routine in all io binding branches, should we refactor it into a prepare_io_binding_outputs method? + start_logits = output_buffers["start_logits"].view(output_shapes["start_logits"]) + end_logits = output_buffers["end_logits"].view(output_shapes["end_logits"]) else: - if use_torch: - input_ids = input_ids.cpu().detach().numpy() - attention_mask = attention_mask.cpu().detach().numpy() - if token_type_ids is not None: - token_type_ids = token_type_ids.cpu().detach().numpy() - - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - if token_type_ids is not None: - onnx_inputs["token_type_ids"] = token_type_ids - - # run inference - outputs = self.model.run(None, onnx_inputs) - - start_logits = outputs[self.output_names["start_logits"]] - end_logits = outputs[self.output_names["end_logits"]] - if use_torch: - start_logits = torch.from_numpy(start_logits).to(self.device) - end_logits = torch.from_numpy(end_logits).to(self.device) + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - # converts output to namedtuple for pipelines post-processing - return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + + start_logits = model_outputs["start_logits"] + end_logits = model_outputs["end_logits"] + + # converts output to namedtuple for pipelines post-processing + return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits) SEQUENCE_CLASSIFICATION_EXAMPLE = r""" @@ -1370,30 +1366,18 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return SequenceClassifierOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - if use_torch: - input_ids = input_ids.cpu().detach().numpy() - attention_mask = attention_mask.cpu().detach().numpy() - if token_type_ids is not None: - token_type_ids = token_type_ids.cpu().detach().numpy() + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - onnx_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - if token_type_ids is not None: - onnx_inputs["token_type_ids"] = token_type_ids - - outputs = self.model.run(None, onnx_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - logits = outputs[self.output_names["logits"]] - if use_torch: - logits = torch.from_numpy(logits).to(self.device) + logits = model_outputs["logits"] - # converts output to namedtuple for pipelines post-processing - return SequenceClassifierOutput(logits=logits) + # converts output to namedtuple for pipelines post-processing + return SequenceClassifierOutput(logits=logits) TOKEN_CLASSIFICATION_EXAMPLE = r""" @@ -1472,32 +1456,17 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return TokenClassifierOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - if use_torch: - input_ids = input_ids.cpu().detach().numpy() - attention_mask = attention_mask.cpu().detach().numpy() - if token_type_ids is not None: - token_type_ids = token_type_ids.cpu().detach().numpy() - - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - if token_type_ids is not None: - onnx_inputs["token_type_ids"] = token_type_ids - - # run inference - outputs = self.model.run(None, onnx_inputs) - logits = outputs[self.output_names["logits"]] + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - if use_torch: - logits = torch.from_numpy(logits).to(self.device) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + + logits = model_outputs["logits"] - # converts output to namedtuple for pipelines post-processing - return TokenClassifierOutput(logits=logits) + return TokenClassifierOutput(logits=logits) MULTIPLE_CHOICE_EXAMPLE = r""" @@ -1570,31 +1539,18 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return MultipleChoiceModelOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - if use_torch: - input_ids = input_ids.cpu().detach().numpy() - attention_mask = attention_mask.cpu().detach().numpy() - if token_type_ids is not None: - token_type_ids = token_type_ids.cpu().detach().numpy() - - onnx_inputs = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - if token_type_ids is not None: - onnx_inputs["token_type_ids"] = token_type_ids - - # run inference - outputs = self.model.run(None, onnx_inputs) - logits = outputs[self.output_names["logits"]] + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids} - if use_torch: - logits = torch.from_numpy(logits).to(self.device) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + + logits = model_outputs["logits"] - # converts output to namedtuple for pipelines post-processing - return MultipleChoiceModelOutput(logits=logits) + # converts output to namedtuple for pipelines post-processing + return MultipleChoiceModelOutput(logits=logits) IMAGE_CLASSIFICATION_EXAMPLE = r""" @@ -1662,7 +1618,8 @@ def forward( if self.device.type == "cuda" and self.use_io_binding: io_binding, output_shapes, output_buffers = self.prepare_io_binding( - pixel_values, ordered_input_names=self._ordered_input_names + pixel_values, + ordered_input_names=self._ordered_input_names, ) # run inference with binding & synchronize in case of multiple CUDA streams @@ -1670,25 +1627,18 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return ImageClassifierOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - if use_torch: - pixel_values = pixel_values.cpu().detach().numpy() + model_inputs = {"pixel_values": pixel_values} - onnx_inputs = { - "pixel_values": pixel_values, - } - - # run inference - outputs = self.model.run(None, onnx_inputs) - logits = outputs[self.output_names["logits"]] + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - if use_torch: - logits = torch.from_numpy(logits).to(self.device) + logits = model_outputs["logits"] - # converts output to namedtuple for pipelines post-processing - return ImageClassifierOutput(logits=logits) + # converts output to namedtuple for pipelines post-processing + return ImageClassifierOutput(logits=logits) SEMANTIC_SEGMENTATION_EXAMPLE = r""" @@ -1755,47 +1705,28 @@ def forward( self.raise_on_numpy_input_io_binding(use_torch) if self.device.type == "cuda" and self.use_io_binding: - io_binding = IOBindingHelper.prepare_io_binding( - self, + io_binding, output_shapes, output_buffers = self.prepare_io_binding( pixel_values, - **kwargs, ordered_input_names=self._ordered_input_names, ) - # run inference with binding + # run inference with binding & synchronize in case of multiple CUDA streams io_binding.synchronize_inputs() self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - outputs = {} - for name, output in zip(self.output_names.keys(), io_binding._iobinding.get_outputs()): - outputs[name] = IOBindingHelper.to_pytorch(output) - - # converts output to namedtuple for pipelines post-processing - return SemanticSegmenterOutput(logits=outputs["logits"]) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - onnx_inputs = self._prepare_onnx_inputs(use_torch=use_torch, pixel_values=pixel_values, **kwargs) + model_inputs = {"pixel_values": pixel_values} - # run inference + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - logits = onnx_outputs[self.output_names["logits"]] - if use_torch: - logits = torch.from_numpy(logits).to(self.device) - - # converts output to namedtuple for pipelines post-processing - return SemanticSegmenterOutput(logits=logits) - - def _prepare_onnx_inputs(self, use_torch: bool, **kwargs): - onnx_inputs = {} - # converts pytorch inputs into numpy inputs for onnx - for input in self.inputs_names.keys(): - onnx_inputs[input] = kwargs.pop(input) - - if use_torch: - onnx_inputs[input] = onnx_inputs[input].cpu().detach().numpy() + logits = model_outputs["logits"] - return onnx_inputs + # converts output to namedtuple for pipelines post-processing + return SemanticSegmenterOutput(logits=logits) AUDIO_CLASSIFICATION_EXAMPLE = r""" @@ -1883,18 +1814,28 @@ def __init__( ) def forward( self, - input_values: Optional[torch.Tensor] = None, - attenton_mask: Optional[torch.Tensor] = None, + input_values: Optional[Union[torch.Tensor, np.ndarray]] = None, + attention_mask: Optional[Union[torch.Tensor, np.ndarray]] = None, + input_features: Optional[Union[torch.Tensor, np.ndarray]] = None, **kwargs, ): - if input_values is None: - # Whisper uses input_features and not input_values. - input_values = kwargs["input_features"] - use_torch = isinstance(input_values, torch.Tensor) + if self.input_name == "input_features": + assert input_features is not None, "input_features must be provided for this model" + model_input = input_features + elif self.input_name == "input_values": + assert input_values is not None, "input_values must be provided for this model" + model_input = input_values + else: + raise ValueError(f"Input {self.input_name} not supported for Audio Classification") + + use_torch = isinstance(model_input, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) + if self.device.type == "cuda" and self.use_io_binding: io_binding, output_shapes, output_buffers = self.prepare_io_binding( - input_values, ordered_input_names=self._ordered_input_names + model_input, + attention_mask, + ordered_input_names=self._ordered_input_names, ) # run inference with binding & synchronize in case of multiple CUDA streams @@ -1902,28 +1843,18 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return SequenceClassifierOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - if use_torch: - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = { - self.input_name: input_values.cpu().detach().numpy(), - } - else: - onnx_inputs = { - self.input_name: input_values, - } + model_inputs = {self.input_name: model_input, "attention_mask": attention_mask} - # run inference - outputs = self.model.run(None, onnx_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - logits = outputs[self.output_names["logits"]] - if use_torch: - logits = torch.from_numpy(logits).to(self.device) + logits = model_outputs["logits"] - # converts output to namedtuple for pipelines post-processing - return SequenceClassifierOutput(logits=logits) + # converts output to namedtuple for pipelines post-processing + return SequenceClassifierOutput(logits=logits) CTC_EXAMPLE = r""" @@ -1971,11 +1902,12 @@ class ORTModelForCTC(ORTModel): ) def forward( self, - input_values: Optional[torch.Tensor] = None, + input_values: Optional[Union[torch.Tensor, np.ndarray]] = None, **kwargs, ): use_torch = isinstance(input_values, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) + if self.device.type == "cuda" and self.use_io_binding: input_size = input_values.shape[1] output_sizes = [] @@ -1990,9 +1922,7 @@ def _conv_output_size(input_size, kernel_size, stride): known_output_shapes = {"logits": [input_values.shape[0], output_sizes[-1], self.config.vocab_size]} io_binding, output_shapes, output_buffers = self.prepare_io_binding( - input_values, - ordered_input_names=self._ordered_input_names, - known_output_shapes=known_output_shapes, + input_values, ordered_input_names=self._ordered_input_names, known_output_shapes=known_output_shapes ) # run inference with binding & synchronize in case of multiple CUDA streams @@ -2000,28 +1930,18 @@ def _conv_output_size(input_size, kernel_size, stride): self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - outputs = {} - - return CausalLMOutput(logits=output_buffers["logits"].view(output_shapes["logits"])) + logits = output_buffers["logits"].view(output_shapes["logits"]) else: - if use_torch: - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = { - "input_values": input_values.cpu().detach().numpy(), - } - else: - onnx_inputs = { - "input_values": input_values, - } + model_inputs = {"input_values": input_values} - # run inference - outputs = self.model.run(None, onnx_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - logits = outputs[self.output_names["logits"]] - if use_torch: - logits = torch.from_numpy(logits).to(self.device) - # converts output to namedtuple for pipelines post-processing - return CausalLMOutput(logits=logits) + logits = model_outputs["logits"] + + # converts output to namedtuple for pipelines post-processing + return CausalLMOutput(logits=logits) AUDIO_XVECTOR_EXAMPLE = r""" @@ -2077,11 +1997,12 @@ class ORTModelForAudioXVector(ORTModel): ) def forward( self, - input_values: Optional[torch.Tensor] = None, + input_values: Optional[Union[torch.Tensor, np.ndarray]] = None, **kwargs, ): use_torch = isinstance(input_values, torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) + if self.device.type == "cuda" and self.use_io_binding: io_binding, output_shapes, output_buffers = self.prepare_io_binding( input_values, ordered_input_names=self._ordered_input_names @@ -2092,33 +2013,21 @@ def forward( self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - # converts output to namedtuple for pipelines post-processing - return XVectorOutput( - logits=output_buffers["logits"].view(output_shapes["logits"]), - embeddings=output_buffers["embeddings"].view(output_shapes["embeddings"]), - ) + logits = output_buffers["logits"].view(output_shapes["logits"]) + embeddings = output_buffers["embeddings"].view(output_shapes["embeddings"]) + else: - if use_torch: - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = { - "input_values": input_values.cpu().detach().numpy(), - } - else: - onnx_inputs = { - "input_values": input_values, - } + model_inputs = {"input_values": input_values} - # run inference - outputs = self.model.run(None, onnx_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - logits = outputs[self.output_names["logits"]] - embeddings = outputs[self.output_names["embeddings"]] - if use_torch: - logits = torch.from_numpy(logits).to(self.device) - embeddings = torch.from_numpy(embeddings).to(self.device) + logits = model_outputs["logits"] + embeddings = model_outputs["embeddings"] - # converts output to namedtuple for pipelines post-processing - return XVectorOutput(logits=logits, embeddings=embeddings) + # converts output to namedtuple for pipelines post-processing + return XVectorOutput(logits=logits, embeddings=embeddings) AUDIO_FRAME_CLASSIFICATION_EXAMPLE = r""" @@ -2166,7 +2075,7 @@ class ORTModelForAudioFrameClassification(ORTModel): ) def forward( self, - input_values: Optional[torch.Tensor] = None, + input_values: Optional[Union[torch.Tensor, np.ndarray]] = None, **kwargs, ): use_torch = isinstance(input_values, torch.Tensor) @@ -2175,24 +2084,16 @@ def forward( if self.device.type == "cuda" and self.use_io_binding: raise NotImplementedError() else: - if use_torch: - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = { - "input_values": input_values.cpu().detach().numpy(), - } - else: - onnx_inputs = { - "input_values": input_values, - } + model_inputs = {"input_values": input_values} - # run inference - outputs = self.model.run(None, onnx_inputs) + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - logits = outputs[self.output_names["logits"]] - if use_torch: - logits = torch.from_numpy(logits).to(self.device) - # converts output to namedtuple for pipelines post-processing - return TokenClassifierOutput(logits=logits) + logits = model_outputs["logits"] + + # converts output to namedtuple for pipelines post-processing + return TokenClassifierOutput(logits=logits) CUSTOM_TASKS_EXAMPLE = r""" @@ -2241,57 +2142,27 @@ class ORTModelForCustomTasks(ORTModel): checkpoint="optimum/sbert-all-MiniLM-L6-with-pooler", ) ) - def forward(self, **kwargs): - use_torch = isinstance(next(iter(kwargs.values())), torch.Tensor) + def forward(self, **model_inputs: Union[torch.Tensor, np.ndarray]): + use_torch = isinstance(next(iter(model_inputs.values())), torch.Tensor) self.raise_on_numpy_input_io_binding(use_torch) if self.device.type == "cuda" and self.use_io_binding: - io_binding = IOBindingHelper.prepare_io_binding( - self, - **kwargs, - ordered_input_names=self._ordered_input_names, - ) + # TODO: should this be used in favor of `model.prepare_io_binding`? + io_binding = IOBindingHelper.prepare_io_binding(self, **model_inputs) # run inference with binding io_binding.synchronize_inputs() self.model.run_with_iobinding(io_binding) io_binding.synchronize_outputs() - outputs = {} + model_outputs = {} for name, output in zip(self.output_names.keys(), io_binding._iobinding.get_outputs()): - outputs[name] = IOBindingHelper.to_pytorch(output) + model_outputs[name] = IOBindingHelper.to_pytorch(output) - # converts output to namedtuple for pipelines post-processing - return ModelOutput(**outputs) else: - # converts pytorch inputs into numpy inputs for onnx - onnx_inputs = self._prepare_onnx_inputs(use_torch=use_torch, **kwargs) - - # run inference + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) onnx_outputs = self.model.run(None, onnx_inputs) - outputs = self._prepare_onnx_outputs(onnx_outputs, use_torch=use_torch) - - # converts output to namedtuple for pipelines post-processing - return ModelOutput(outputs) - - def _prepare_onnx_inputs(self, use_torch: bool, **kwargs): - onnx_inputs = {} - # converts pytorch inputs into numpy inputs for onnx - for input in self.inputs_names.keys(): - onnx_inputs[input] = kwargs.pop(input) - - if use_torch: - onnx_inputs[input] = onnx_inputs[input].cpu().detach().numpy() - - return onnx_inputs - - def _prepare_onnx_outputs(self, onnx_outputs, use_torch: bool): - outputs = {} - # converts onnxruntime outputs into tensor for standard outputs - for output, idx in self.output_names.items(): - outputs[output] = onnx_outputs[idx] - - if use_torch: - outputs[output] = torch.from_numpy(outputs[output]).to(self.device) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) - return outputs + # converts output to namedtuple for pipelines post-processing + return ModelOutput(**model_outputs) diff --git a/optimum/utils/testing_utils.py b/optimum/utils/testing_utils.py index a7c2b8bb05..41bd140d86 100644 --- a/optimum/utils/testing_utils.py +++ b/optimum/utils/testing_utils.py @@ -16,6 +16,7 @@ import importlib.util import itertools import os +import shutil import subprocess import sys import unittest @@ -181,3 +182,16 @@ def grid_parameters( else: returned_list = [test_name] + list(params) if add_test_name is True else list(params) yield returned_list + + +def remove_directory(dirpath): + """ + Remove a directory and its content. + This is a cross-platform solution to remove a directory and its content that avoids the use of `shutil.rmtree` on Windows. + Reference: https://github.com/python/cpython/issues/107408 + """ + if os.path.exists(dirpath) and os.path.isdir(dirpath): + if os.name == "nt": + os.system(f"rmdir /S /Q {dirpath}") + else: + shutil.rmtree(dirpath) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 7b2c8a66b9..6c88fddb40 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -14,7 +14,6 @@ # limitations under the License. import gc import os -import shutil import subprocess import tempfile import time @@ -109,7 +108,7 @@ DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, logging, ) -from optimum.utils.testing_utils import grid_parameters, require_hf_token, require_ort_rocm +from optimum.utils.testing_utils import grid_parameters, remove_directory, require_hf_token, require_ort_rocm logger = logging.get_logger() @@ -184,9 +183,8 @@ def test_load_model_from_cache(self): def test_load_model_from_empty_cache(self): dirpath = os.path.join(default_cache_path, "models--" + self.TINY_ONNX_MODEL_ID.replace("/", "--")) + remove_directory(dirpath) - if os.path.exists(dirpath) and os.path.isdir(dirpath): - shutil.rmtree(dirpath) with self.assertRaises(Exception): _ = ORTModel.from_pretrained(self.TINY_ONNX_MODEL_ID, local_files_only=True) @@ -202,9 +200,8 @@ def test_load_seq2seq_model_from_cache(self): def test_load_seq2seq_model_from_empty_cache(self): dirpath = os.path.join(default_cache_path, "models--" + self.TINY_ONNX_SEQ2SEQ_MODEL_ID.replace("/", "--")) + remove_directory(dirpath) - if os.path.exists(dirpath) and os.path.isdir(dirpath): - shutil.rmtree(dirpath) with self.assertRaises(Exception): _ = ORTModelForSeq2SeqLM.from_pretrained(self.TINY_ONNX_SEQ2SEQ_MODEL_ID, local_files_only=True) @@ -225,9 +222,8 @@ def test_load_stable_diffusion_model_from_empty_cache(self): dirpath = os.path.join( default_cache_path, "models--" + self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID.replace("/", "--") ) + remove_directory(dirpath) - if os.path.exists(dirpath) and os.path.isdir(dirpath): - shutil.rmtree(dirpath) with self.assertRaises(Exception): _ = ORTStableDiffusionPipeline.from_pretrained( self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID, local_files_only=True @@ -1008,6 +1004,7 @@ def test_save_load_ort_model_with_external_data(self): # verify loading from local folder works model = ORTModelForSequenceClassification.from_pretrained(tmpdirname, export=False) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") + remove_directory(tmpdirname) @parameterized.expand([(False,), (True,)]) @pytest.mark.run_slow @@ -1015,11 +1012,7 @@ def test_save_load_ort_model_with_external_data(self): def test_save_load_decoder_model_with_external_data(self, use_cache: bool): with tempfile.TemporaryDirectory() as tmpdirname: model = ORTModelForCausalLM.from_pretrained( - "gpt2-large", - use_cache=use_cache, - export=True, - use_merged=False, - use_io_binding=False, + "gpt2-large", use_cache=use_cache, export=True, use_merged=False, use_io_binding=False ) model.save_pretrained(tmpdirname) @@ -1033,6 +1026,7 @@ def test_save_load_decoder_model_with_external_data(self, use_cache: bool): model = ORTModelForCausalLM.from_pretrained( tmpdirname, use_cache=use_cache, export=False, use_io_binding=False ) + remove_directory(tmpdirname) @parameterized.expand([(False,), (True,)]) def test_save_load_seq2seq_model_with_external_data(self, use_cache: bool): @@ -1055,6 +1049,7 @@ def test_save_load_seq2seq_model_with_external_data(self, use_cache: bool): # verify loading from local folder works model = ORTModelForSeq2SeqLM.from_pretrained(tmpdirname, use_cache=use_cache, export=False) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") + remove_directory(tmpdirname) def test_save_load_stable_diffusion_model_with_external_data(self): with tempfile.TemporaryDirectory() as tmpdirname: @@ -1076,6 +1071,7 @@ def test_save_load_stable_diffusion_model_with_external_data(self): # verify loading from local folder works model = ORTStableDiffusionPipeline.from_pretrained(tmpdirname, export=False) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") + remove_directory(tmpdirname) @parameterized.expand([(False,), (True,)]) @unittest.skip("Skipping as this test consumes too much memory") @@ -2278,6 +2274,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): @parameterized.expand([(False,), (True,)]) @pytest.mark.run_in_series + # TODO: still gotta find out why this needs to be ran in series / why it fails in parallel + # my guess is that the model surgery is happening in parallel and that's causing the issue def test_inference_old_onnx_model(self, use_cache): tokenizer = get_preprocessor("gpt2") model = AutoModelForCausalLM.from_pretrained("gpt2") @@ -2290,9 +2288,9 @@ def test_inference_old_onnx_model(self, use_cache): tokens = tokenizer(text, return_tensors="pt") onnx_outputs = onnx_model.generate( - **tokens, num_beams=1, do_sample=False, min_new_tokens=10, max_new_tokens=10 + **tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30 ) - outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=10, max_new_tokens=10) + outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30) onnx_text_outputs = tokenizer.decode(onnx_outputs[0], skip_special_tokens=True) text_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True) self.assertEqual(onnx_text_outputs, text_outputs) @@ -3605,13 +3603,20 @@ def _get_onnx_model_dir(self, model_id, model_arch, test_name): @pytest.mark.run_in_series def test_inference_old_onnx_model(self): - model = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small") + tokenizer = get_preprocessor("t5-small") + model = AutoModelForSeq2SeqLM.from_pretrained("t5-small") + onnx_model = ORTModelForSeq2SeqLM.from_pretrained("optimum/t5-small") - tokenizer = get_preprocessor("optimum/t5-small") text = "This is a sample output" tokens = tokenizer(text, return_tensors="pt") - model.generate(**tokens) + outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30) + onnx_outputs = onnx_model.generate( + **tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30 + ) + onnx_text_outputs = tokenizer.decode(onnx_outputs[0], skip_special_tokens=True) + text_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True) + self.assertEqual(onnx_text_outputs, text_outputs) def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: @@ -4760,6 +4765,9 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach self.assertTrue("logits" in onnx_outputs) self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type]) + self.assertTrue( + torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-3) + ) if use_cache: self.assertEqual( @@ -4768,19 +4776,17 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach self.assertEqual( len(onnx_outputs["past_key_values"][0]), len(transformers_outputs["past_key_values"][0]) ) - for i, _ in enumerate(onnx_outputs["past_key_values"]): - for j, ort_pkv in enumerate(onnx_outputs["past_key_values"][i]): - trfs_pkv = transformers_outputs["past_key_values"][i][j] + for i in range(len(onnx_outputs["past_key_values"])): + print(onnx_outputs["past_key_values"][i]) + for ort_pkv, trfs_pkv in zip( + onnx_outputs["past_key_values"][i], transformers_outputs["past_key_values"][i] + ): + ort_pkv = torch.Tensor(ort_pkv) self.assertTrue( torch.allclose(ort_pkv, trfs_pkv, atol=1e-3), f" Maxdiff: {torch.abs(ort_pkv - trfs_pkv).max()}", ) - # Compare tensor outputs - self.assertTrue( - torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-3) - ) - gc.collect() @parameterized.expand(grid_parameters(FULL_GRID))