From 53f7ed43ba35e18f65281ab66359c7bfbf09e2c6 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Thu, 23 May 2024 09:07:50 +0000 Subject: [PATCH] feat: add NeuronModel base class This base class will implement transformers PreTrainedModel methods that are not implemented in optimum PreTrainedModel base class. --- docs/source/package_reference/modeling.mdx | 8 +- optimum/neuron/__init__.py | 4 +- optimum/neuron/modeling.py | 20 +- optimum/neuron/modeling_base.py | 597 +---------------- optimum/neuron/modeling_decoder.py | 4 +- optimum/neuron/modeling_diffusion.py | 18 +- optimum/neuron/modeling_seq2seq.py | 10 +- optimum/neuron/modeling_traced.py | 611 ++++++++++++++++++ optimum/neuron/pipelines/transformers/base.py | 12 +- tests/inference/inference_utils.py | 2 +- tests/inference/test_modeling.py | 16 +- tests/pipelines/test_encoder_pipelines.py | 4 +- 12 files changed, 673 insertions(+), 633 deletions(-) create mode 100644 optimum/neuron/modeling_traced.py diff --git a/docs/source/package_reference/modeling.mdx b/docs/source/package_reference/modeling.mdx index a35f8a24e..bbf63f191 100644 --- a/docs/source/package_reference/modeling.mdx +++ b/docs/source/package_reference/modeling.mdx @@ -18,12 +18,12 @@ limitations under the License. ## Generic model classes -### NeuronBaseModel +### NeuronTracedModel -The `NeuronBaseModel` class is available for instantiating a base Neuron model without a specific head. +The `NeuronTracedModel` class is available for instantiating a base Neuron model without a specific head. It is used as the base class for all tasks but text generation. -[[autodoc]] modeling_base.NeuronBaseModel +[[autodoc]] modeling_traced.NeuronTracedModel ### NeuronDecoderModel @@ -104,4 +104,4 @@ The following Neuron model classes are available for natural language processing ### NeuronStableDiffusionXLInpaintPipeline [[autodoc]] modeling_diffusion.NeuronStableDiffusionXLInpaintPipeline - - __call__ \ No newline at end of file + - __call__ diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py index f2b43ff74..369107cc7 100644 --- a/optimum/neuron/__init__.py +++ b/optimum/neuron/__init__.py @@ -29,7 +29,7 @@ "hf_argparser": ["NeuronHfArgumentParser"], "trainers": ["NeuronTrainer", "Seq2SeqNeuronTrainer"], "training_args": ["NeuronTrainingArguments", "Seq2SeqNeuronTrainingArguments"], - "modeling_base": ["NeuronBaseModel"], + "modeling_traced": ["NeuronTracedModel"], "modeling": [ "NeuronModelForFeatureExtraction", "NeuronModelForSentenceTransformers", @@ -73,7 +73,6 @@ NeuronModelForSequenceClassification, NeuronModelForTokenClassification, ) - from .modeling_base import NeuronBaseModel from .modeling_decoder import NeuronDecoderModel from .modeling_diffusion import ( NeuronLatentConsistencyModelPipeline, @@ -85,6 +84,7 @@ NeuronStableDiffusionXLPipeline, ) from .modeling_seq2seq import NeuronModelForSeq2SeqLM + from .modeling_traced import NeuronTracedModel from .pipelines import pipeline from .trainers import NeuronTrainer, Seq2SeqNeuronTrainer from .training_args import NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments diff --git a/optimum/neuron/modeling.py b/optimum/neuron/modeling.py index a2ee7c67e..ad516eb8d 100644 --- a/optimum/neuron/modeling.py +++ b/optimum/neuron/modeling.py @@ -43,8 +43,8 @@ from transformers.utils import ModelOutput from .generation import TokenSelector -from .modeling_base import NeuronBaseModel from .modeling_decoder import NeuronDecoderModel +from .modeling_traced import NeuronTracedModel if TYPE_CHECKING: @@ -61,13 +61,13 
@@ _TOKENIZER_FOR_DOC = "AutoTokenizer" NEURON_MODEL_START_DOCSTRING = r""" - This model inherits from [`~neuron.modeling.NeuronBaseModel`]. Check the superclass documentation for the generic methods the + This model inherits from [`~neuron.modeling.NeuronTracedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving) Args: config (`transformers.PretrainedConfig`): [PretrainedConfig](https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig) is the Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`optimum.neuron.modeling.NeuronBaseModel.from_pretrained`] method to load the model weights. + configuration. Check out the [`optimum.neuron.modeling.NeuronTracedModel.from_pretrained`] method to load the model weights. model (`torch.jit._script.ScriptModule`): [torch.jit._script.ScriptModule](https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html) is the TorchScript module with embedded NEFF(Neuron Executable File Format) compiled by neuron(x) compiler. """ @@ -125,7 +125,7 @@ """, NEURON_MODEL_START_DOCSTRING, ) -class NeuronModelForFeatureExtraction(NeuronBaseModel): +class NeuronModelForFeatureExtraction(NeuronTracedModel): """ Feature Extraction model on Neuron devices. """ @@ -198,7 +198,7 @@ def forward( """, NEURON_MODEL_START_DOCSTRING, ) -class NeuronModelForSentenceTransformers(NeuronBaseModel): +class NeuronModelForSentenceTransformers(NeuronTracedModel): """ Sentence Transformers model on Neuron devices. """ @@ -283,7 +283,7 @@ def forward( """, NEURON_MODEL_START_DOCSTRING, ) -class NeuronModelForMaskedLM(NeuronBaseModel): +class NeuronModelForMaskedLM(NeuronTracedModel): """ Masked language model for on Neuron devices. """ @@ -353,7 +353,7 @@ def forward( """, NEURON_MODEL_START_DOCSTRING, ) -class NeuronModelForQuestionAnswering(NeuronBaseModel): +class NeuronModelForQuestionAnswering(NeuronTracedModel): """ Question Answering model on Neuron devices. """ @@ -422,7 +422,7 @@ def forward( """, NEURON_MODEL_START_DOCSTRING, ) -class NeuronModelForSequenceClassification(NeuronBaseModel): +class NeuronModelForSequenceClassification(NeuronTracedModel): """ Sequence Classification model on Neuron devices. """ @@ -490,7 +490,7 @@ def forward( """, NEURON_MODEL_START_DOCSTRING, ) -class NeuronModelForTokenClassification(NeuronBaseModel): +class NeuronModelForTokenClassification(NeuronTracedModel): """ Token Classification model on Neuron devices. """ @@ -571,7 +571,7 @@ def forward( """, NEURON_MODEL_START_DOCSTRING, ) -class NeuronModelForMultipleChoice(NeuronBaseModel): +class NeuronModelForMultipleChoice(NeuronTracedModel): """ Multiple choice model on Neuron devices. """ diff --git a/optimum/neuron/modeling_base.py b/optimum/neuron/modeling_base.py index 6c52c6f5e..fd6812e8c 100644 --- a/optimum/neuron/modeling_base.py +++ b/optimum/neuron/modeling_base.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,600 +12,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -"""NeuronBaseModel base classe for inference on neuron devices using the same API as Transformers.""" -import logging -import os -import shutil -from contextlib import contextmanager -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Union import torch -from huggingface_hub import HfApi, HfFolder, hf_hub_download -from huggingface_hub.utils import is_google_colab -from transformers import AutoConfig, AutoModel -from ..exporters.neuron import main_export -from ..exporters.neuron.model_configs import * # noqa: F403 -from ..exporters.tasks import TasksManager from ..modeling_base import OptimizedModel -from .utils import ( - NEURON_FILE_NAME, - check_if_weights_replacable, - is_neuron_available, - replace_weights, - store_compilation_config, -) -from .utils.hub_neuronx_cache import ModelCacheEntry, build_cache_config, create_hub_compile_cache_proxy -from .utils.import_utils import is_neuronx_available -from .utils.misc import maybe_load_preprocessors -from .utils.version_utils import check_compiler_compatibility, get_neuroncc_version, get_neuronxcc_version if TYPE_CHECKING: - from transformers import PretrainedConfig + from transformers import PretrainedConfig, PreTrainedModel - from ..exporters.neuron import NeuronDefaultConfig -if is_neuron_available(): - NEURON_COMPILER_TYPE = "neuron-cc" - NEURON_COMPILER_VERSION = get_neuroncc_version() +class NeuronModel(OptimizedModel): -if is_neuronx_available(): - NEURON_COMPILER_TYPE = "neuronx-cc" - NEURON_COMPILER_VERSION = get_neuronxcc_version() - -logger = logging.getLogger(__name__) - - -class NeuronBaseModel(OptimizedModel): - """ - Base class running compiled and optimized models on Neuron devices. - - It implements generic methods for interacting with the Hugging Face Hub as well as compiling vanilla - transformers models to neuron-optimized TorchScript module and export it using `optimum.exporters.neuron` toolchain. - - Class attributes: - - model_type (`str`, *optional*, defaults to `"neuron_model"`) -- The name of the model type to use when - registering the NeuronBaseModel classes. - - auto_model_class (`Type`, *optional*, defaults to `AutoModel`) -- The `AutoModel` class to be represented by the - current NeuronBaseModel class. - - Common attributes: - - model (`torch.jit._script.ScriptModule`) -- The loaded `ScriptModule` compiled for neuron devices. - - config ([`~transformers.PretrainedConfig`]) -- The configuration of the model. - - model_save_dir (`Path`) -- The directory where a neuron compiled model is saved. - By default, if the loaded model is local, the directory where the original model will be used. Otherwise, the - cache directory will be used. 
- """ - - model_type = "neuron_model" - auto_model_class = AutoModel - library_name = "transformers" - - def __init__( - self, - model: torch.jit._script.ScriptModule, - config: "PretrainedConfig", - model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, - model_file_name: Optional[str] = None, - preprocessors: Optional[List] = None, - neuron_config: Optional["NeuronDefaultConfig"] = None, - **kwargs, - ): + def __init__(self, model: "PreTrainedModel", config: "PretrainedConfig"): super().__init__(model, config) - - self.model = model - self.model_file_name = model_file_name or NEURON_FILE_NAME - self.config = config - self.neuron_config = self._neuron_config_init(self.config) if neuron_config is None else neuron_config - self.input_static_shapes = NeuronBaseModel.get_input_static_shapes(self.neuron_config) - self._attributes_init(model_save_dir, preprocessors, **kwargs) - - @staticmethod - def load_model(path: Union[str, Path]) -> torch.jit._script.ScriptModule: - """ - Loads a TorchScript module compiled by neuron(x)-cc compiler. It will be first loaded onto CPU and then moved to - one or multiple [NeuronCore](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/neuroncores-arch.html). - - Args: - path (`Union[str, Path]`): - Path of the compiled model. - """ - if not isinstance(path, Path): - path = Path(path) - - if path.is_file(): - model = torch.jit.load(path) - return model - - def replace_weights(self, weights: Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]] = None): - check_if_weights_replacable(self.config, weights) - if weights is not None: - replace_weights(self.model, weights) - - def _save_pretrained(self, save_directory: Union[str, Path]): - """ - Saves a model and its configuration file to a directory, so that it can be re-loaded using the - [`~optimum.neuron.modeling_base.NeuronBaseModel.from_pretrained`] class method. - - Args: - save_directory (`Union[str, Path]`): - Directory where to save the model file. - """ - src_path = self.model_save_dir / self.model_file_name - dst_path = Path(save_directory) / self.model_file_name - - shutil.copyfile(src_path, dst_path) - - @classmethod - def _from_pretrained( - cls, - model_id: Union[str, Path], - config: "PretrainedConfig", - use_auth_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - force_download: bool = False, - cache_dir: Optional[str] = None, - file_name: Optional[str] = None, - subfolder: str = "", - local_files_only: bool = False, - model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, - neuron_config: Optional["NeuronDefaultConfig"] = None, - **kwargs, - ) -> "NeuronBaseModel": - model_path = Path(model_id) - - if file_name is None: - if model_path.is_dir(): - neuron_files = list(model_path.glob("*.neuron")) - else: - if isinstance(use_auth_token, bool): - token = HfFolder().get_token() - else: - token = use_auth_token - repo_files = map(Path, HfApi().list_repo_files(model_id, revision=revision, token=token)) - pattern = "*.neuron" if subfolder == "" else f"{subfolder}/*.neuron" - neuron_files = [p for p in repo_files if p.match(pattern)] - - if len(neuron_files) == 0: - raise FileNotFoundError(f"Could not find any neuron model file in {model_path}") - elif len(neuron_files) > 1: - raise RuntimeError( - f"Too many neuron model files were found in {model_path}, specify which one to load by using the " - "file_name argument." 
- ) - else: - file_name = neuron_files[0].name - - # Check compiler compatibility(compiler type and version) of the saved model vs. system. - if hasattr(config, "neuron") and "compiler_type" in config.neuron: - model_compiler_type = config.neuron.get("compiler_type") - model_compiler_version = config.neuron.get("compiler_version") - check_compiler_compatibility(model_compiler_type, model_compiler_version) - - preprocessors = None - if model_path.is_dir(): - model = NeuronBaseModel.load_model(model_path / file_name) - new_model_save_dir = model_path - else: - model_cache_path = hf_hub_download( - repo_id=model_id, - filename=file_name, - subfolder=subfolder, - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - - model = NeuronBaseModel.load_model(model_cache_path) - new_model_save_dir = Path(model_cache_path).parent - - preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) - - # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it - # instead of the path only. - if model_save_dir is None: - model_save_dir = new_model_save_dir - - return cls( - model=model, - config=config, - model_save_dir=model_save_dir, - model_file_name=file_name, - preprocessors=preprocessors, - neuron_config=neuron_config, - ) - - @classmethod - def _from_transformers(cls, *args, **kwargs): - # Deprecate it when optimum uses `_export` as from_pretrained_method in a stable release. - return cls._export(*args, **kwargs) - - @classmethod - def _export( - cls, - model_id: str, - config: "PretrainedConfig", - use_auth_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, - library_name: Optional[str] = None, - force_download: bool = False, - cache_dir: Optional[str] = None, - compiler_workdir: Optional[Union[str, Path]] = None, - disable_neuron_cache: bool = False, - inline_weights_to_neff: bool = True, - optlevel: str = "2", - subfolder: str = "", - local_files_only: bool = False, - trust_remote_code: bool = False, - task: Optional[str] = None, - auto_cast: Optional[str] = None, - auto_cast_type: Optional[str] = None, - disable_fast_relayout: Optional[bool] = False, - disable_fallback: bool = False, - dynamic_batch_size: bool = False, - **kwargs_shapes, - ) -> "NeuronBaseModel": - """ - Exports a vanilla Transformers model into a neuron-compiled TorchScript Module using `optimum.exporters.neuron.export`. - - Args: - kwargs_shapes (`Dict[str, int]`): - Shapes to use during inference. This argument allows to override the default shapes used during the export. 
- """ - if task is None: - task = TasksManager.infer_task_from_model(cls.auto_model_class) - task = TasksManager.map_from_synonym(task) - library_name = TasksManager.infer_library_from_model(model_id, subfolder=subfolder, library_name=library_name) - - # Get compilation arguments - if is_neuron_available() and dynamic_batch_size is True and "batch_size" in kwargs_shapes: - kwargs_shapes["batch_size"] = 1 - disable_fallback = True # Turn off the fallback for neuron, otherwise dynamic batching will still fail - auto_cast_type = None if auto_cast is None else auto_cast_type - compiler_kwargs = { - "auto_cast": auto_cast, - "auto_cast_type": auto_cast_type, - "disable_fast_relayout": disable_fast_relayout, - "disable_fallback": disable_fallback, - } - - if ( - not inline_weights_to_neff and not disable_neuron_cache and is_neuronx_available() - ): # TODO: support caching of Inf1 as well - # Check if the cache exists - compilation_config = store_compilation_config( - config=config, - input_shapes=kwargs_shapes, - compiler_kwargs=compiler_kwargs, - dynamic_batch_size=dynamic_batch_size, - compiler_type=NEURON_COMPILER_TYPE, - compiler_version=NEURON_COMPILER_VERSION, - inline_weights_to_neff=inline_weights_to_neff, - optlevel=optlevel, - model_type=getattr(config, "model_type", None), - task=task, - ) - cache_config = build_cache_config(compilation_config) - cache_entry = ModelCacheEntry(model_id=model_id, config=cache_config) - compile_cache = create_hub_compile_cache_proxy() - model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(f"MODULE_{cache_entry.hash}") - cache_available = compile_cache.download_folder(model_cache_dir, model_cache_dir) - else: - cache_available = False - - # load cache - if cache_available: - try: - neuron_model = cls.from_pretrained(model_cache_dir) - model = TasksManager.get_model_from_task( - task=task, - model_name_or_path=model_id, - subfolder=subfolder, - revision=revision, - framework="pt", - library_name=library_name, - cache_dir=cache_dir, - use_auth_token=use_auth_token, - local_files_only=local_files_only, - force_download=force_download, - trust_remote_code=trust_remote_code, - ) - # replace weights - neuron_model.replace_weights(weights=model) - return neuron_model - except Exception as e: - logger.warning( - f"Found the cached artifacts but failed to re-load them with error: {e}. \n Falling back to recompilation." 
- ) - cache_available = False - - # compile - if not cache_available: - # compile - save_dir = TemporaryDirectory() - save_dir_path = Path(save_dir.name) - main_export( - model_name_or_path=model_id, - output=save_dir_path, - compiler_kwargs=compiler_kwargs, - task=task, - dynamic_batch_size=dynamic_batch_size, - cache_dir=cache_dir, - disable_neuron_cache=disable_neuron_cache, - compiler_workdir=compiler_workdir, - inline_weights_to_neff=inline_weights_to_neff, - optlevel=optlevel, - trust_remote_code=trust_remote_code, - subfolder=subfolder, - revision=revision, - force_download=force_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - do_validation=False, - library_name=library_name, - **kwargs_shapes, - ) - config = AutoConfig.from_pretrained(save_dir_path) - - return cls._from_pretrained(save_dir_path, config, model_save_dir=save_dir) - - def push_to_hub( - self, - save_directory: str, - repository_id: str, - private: Optional[bool] = None, - revision: Optional[str] = None, - use_auth_token: Union[bool, str] = True, - endpoint: Optional[str] = None, - ) -> str: - if isinstance(use_auth_token, str): - huggingface_token = use_auth_token - elif use_auth_token: - huggingface_token = HfFolder.get_token() + if hasattr(model, "device"): + self.device = model.device else: - raise ValueError("You need to provide `use_auth_token` to be able to push to the hub") - api = HfApi(endpoint=endpoint) - - user = api.whoami(huggingface_token) - if is_google_colab(): - # Only in Google Colab to avoid the warning message - self.git_config_username_and_email(git_email=user["email"], git_user=user["fullname"]) - - api.create_repo( - token=huggingface_token, - repo_id=repository_id, - exist_ok=True, - private=private, - ) - for path, subdirs, files in os.walk(save_directory): - for name in files: - local_file_path = os.path.join(path, name) - hub_file_path = os.path.relpath(local_file_path, save_directory) - api.upload_file( - token=huggingface_token, - repo_id=repository_id, - path_or_fileobj=os.path.join(os.getcwd(), local_file_path), - path_in_repo=hub_file_path, - revision=revision, - ) - - def forward(self, *args, **kwargs): - raise NotImplementedError - - def _attributes_init( - self, - model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, - preprocessors: Optional[List] = None, - **kwargs, - ): - """ - Initializes attributes. - """ - self._path_tempdirectory_instance = None - if isinstance(model_save_dir, TemporaryDirectory): - self._path_tempdirectory_instance = model_save_dir - self.model_save_dir = Path(model_save_dir.name) - elif isinstance(model_save_dir, str): - self.model_save_dir = Path(model_save_dir) - else: - self.model_save_dir = model_save_dir - - self.preprocessors = preprocessors if preprocessors is not None else [] - - # Registers the NeuronModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating - # a pipeline https://github.com/huggingface/transformers/blob/3d3204c025b6b5de013e07dd364208e28b4d9589/src/transformers/pipelines/base.py#L940 - AutoConfig.register(self.model_type, AutoConfig) - if hasattr(self.auto_model_class, "register"): - self.auto_model_class.register(AutoConfig, self.__class__) - - @classmethod - def _neuron_config_init(cls, config: "PretrainedConfig") -> "NeuronDefaultConfig": - """ - Builds a `NeuronDefaultConfig` with an instance of the `PretrainedConfig` and the task. 
- """ - if not hasattr(config, "neuron"): - raise ValueError( - "Unable to identify neuron configuration with the keyword `neuron`, make sure that your config file contains necessary information" - ) - - neuron_config = config.neuron - # Fetch compiler information - compiler_type = neuron_config.get("compiler_type") - compiler_version = neuron_config.get("compiler_version") - - # Fetch mandatory shapes from config - compile_shapes = { - key.replace("static_", ""): value - for (key, value) in config.to_diff_dict().get("neuron").items() - if key.startswith("static_") - } - - # Neuron config constructuor - task = getattr(config, "task") or TasksManager.infer_task_from_model(cls.auto_model_class) - task = TasksManager.map_from_synonym(task) - model_type = neuron_config.get("model_type", None) or config.model_type - neuron_config_constructor = TasksManager.get_exporter_config_constructor( - model_type=model_type, - exporter="neuron", - task=task, - library_name=cls.library_name, - ) - - return neuron_config_constructor( - config, - dynamic_batch_size=neuron_config.get("dynamic_batch_size", False), - compiler_type=compiler_type, - compiler_version=compiler_version, - **compile_shapes, - ) - - @classmethod - def get_input_static_shapes(cls, neuron_config: "NeuronDefaultConfig") -> Dict[str, int]: - """ - Gets a dictionary of inputs with their valid static shapes. - """ - axes = neuron_config._axes - input_static_shapes = { - name: value.shape - for name, value in neuron_config.generate_dummy_inputs(return_tuple=False, **axes).items() - } - return input_static_shapes - - def _validate_static_shape(self, input_shapes: List[int], target_shapes: List[int]) -> bool: - """ - Checks if a input needs to be padded. - """ - if self.neuron_config.dynamic_batch_size is True: - batch_size_check = input_shapes[0] % target_shapes[0] == 0 - other_check = input_shapes[1:] == target_shapes[1:] if len(input_shapes) > 1 else True - return batch_size_check and other_check - else: - return input_shapes == target_shapes - - def _raise_if_invalid_padding(self, input_name, input_tensor, target_shapes, to_pad, dim): - if to_pad < 0: - extra = ", unless you set `dynamic_batch_size=True` during the compilation" if dim == 0 else "" - raise ValueError( - f"Unable to pad {input_name} with shape: {input_tensor.shape} on dimension {dim} as input shapes must be inferior" - f" than the static shapes used for compilation: {target_shapes}{extra}." - ) - - def _pad_to_compiled_shape( - self, inputs: Dict[str, "torch.Tensor"], padding_side: Literal["right", "left"] = "right" - ): - """ - Pads input tensors if they are not in valid shape. - - Args: - inputs (`Dict[str, "torch.Tensor"]`): - Dictionary of input torch tensors. - padding_side (`Literal["right", "left"]`, defaults to "right"): - The side on which to apply the padding. 
- """ - logger.info(f"Padding input tensors, the padding side is: {padding_side}.") - for input_name, input_tensor in inputs.items(): - target_shapes = self.input_static_shapes[input_name] - padding = () - if self._validate_static_shape(input_tensor.shape, target_shapes): - continue - - # Dimensions other than 0 - for i in reversed(range(1, input_tensor.dim())): - to_pad = target_shapes[i] - input_tensor.size(i) - - self._raise_if_invalid_padding(input_name, input_tensor, target_shapes, to_pad, i) - padding += (0, to_pad) if padding_side == "right" else (to_pad, 0) - - if ( - self.preprocessors is not None - and len(self.preprocessors) > 0 - and self.preprocessors[0].pad_token_id is not None - and input_name == "input_ids" - ): - pad_id = self.preprocessors[0].pad_token_id - else: - pad_id = 0 - - input_tensor = torch.nn.functional.pad(input_tensor, padding, mode="constant", value=pad_id) - - # Pad to batch size: dimension 0 (pad_token_id can't be 0) - padding = (0,) * len(padding) - is_encoder_decoder = getattr(self.config, "is_encoder_decoder", False) - if ( - not is_encoder_decoder - and self.neuron_config.dynamic_batch_size is True - and input_tensor.size(0) % target_shapes[0] == 0 - ): - inputs[input_name] = input_tensor - continue - elif not is_encoder_decoder and self.neuron_config.dynamic_batch_size is True: - target_shape = (input_tensor.size(0) // target_shapes[0] + 1) * target_shapes[0] - to_pad = target_shape - input_tensor.size(0) - else: - to_pad = target_shapes[0] - input_tensor.size(0) - self._raise_if_invalid_padding(input_name, input_tensor, target_shapes, to_pad, 0) - padding += (0, to_pad) if padding_side == "right" else (to_pad, 0) - - pad_id = 1 - inputs[input_name] = torch.nn.functional.pad(input_tensor, padding, mode="constant", value=pad_id) - - return inputs - - @contextmanager - def neuron_padding_manager(self, inputs: Dict[str, "torch.Tensor"]): - inputs = tuple(self._pad_to_compiled_shape(inputs).values()) - yield inputs - - @staticmethod - def remove_padding( - outputs: List[torch.Tensor], - dims: List[int], - indices: List[int], - padding_side: Literal["right", "left"] = "right", - ) -> List[torch.Tensor]: - """ - Removes padding from output tensors. - - Args: - outputs (`List[torch.Tensor]`): - List of torch tensors which are inference output. - dims (`List[int]`): - List of dimensions in which we slice a tensor. - indices (`List[int]`): - List of indices in which we slice a tensor along an axis. - padding_side (`Literal["right", "left"]`, defaults to "right"): - The side on which the padding has been applied. - """ - if len(dims) != len(indices): - raise ValueError(f"The size of `dims`({len(dims)}) and indices`({len(indices)}) must be equal.") - - for dim, indice in zip(dims, indices): - if padding_side == "right": - outputs = [ - torch.index_select(output_tensor, dim, torch.LongTensor(range(indice))) - for output_tensor in outputs - ] - elif padding_side == "left": - outputs = [ - torch.index_select( - output_tensor, - dim, - torch.LongTensor(range(output_tensor.shape[dim] - indice, output_tensor.shape[dim])), - ) - for output_tensor in outputs - ] - - return outputs + self.device = torch.device("cpu") - @property - def is_weights_neff_separated(self) -> bool: - """ - Whether the Neuron model has separated weights and neff graph (by setting `inline_weights_to_neff=False` during the compilation). 
- """ - return not self.config.neuron.get("inline_weights_to_neff", True) + def to(self, device: Union[str, torch.device]): + if not isinstance(device, torch.device): + device = torch.device(device) + if device.type != self.model.device.type: + raise ValueError(f"Neuron models cannot be moved to {device.type}.") diff --git a/optimum/neuron/modeling_decoder.py b/optimum/neuron/modeling_decoder.py index 3b37a6f78..38575e1bb 100644 --- a/optimum/neuron/modeling_decoder.py +++ b/optimum/neuron/modeling_decoder.py @@ -30,7 +30,7 @@ from ..exporters.neuron.model_configs import * # noqa: F403 from ..exporters.tasks import TasksManager -from ..modeling_base import OptimizedModel +from .modeling_base import NeuronModel from .utils import ModelCacheEntry, hub_neuronx_cache, is_transformers_neuronx_available from .utils.require_utils import requires_transformers_neuronx from .utils.version_utils import check_compiler_compatibility, get_neuronxcc_version @@ -111,7 +111,7 @@ def get_available_cores() -> int: return visible_cores -class NeuronDecoderModel(OptimizedModel): +class NeuronDecoderModel(NeuronModel): """ Base class to convert and run pre-trained transformers decoder models on Neuron devices. diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index 3712ffe45..fdc4f50a2 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -39,7 +39,7 @@ from ..exporters.neuron.model_configs import * # noqa: F403 from ..exporters.tasks import TasksManager from ..utils import is_diffusers_available -from .modeling_base import NeuronBaseModel +from .modeling_traced import NeuronTracedModel from .utils import ( DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, DIFFUSION_MODEL_TEXT_ENCODER_NAME, @@ -101,7 +101,7 @@ logger = logging.getLogger(__name__) -class NeuronStableDiffusionPipelineBase(NeuronBaseModel): +class NeuronStableDiffusionPipelineBase(NeuronTracedModel): auto_model_class = StableDiffusionPipeline library_name = "diffusers" base_model_prefix = "neuron_model" @@ -319,7 +319,7 @@ def load_model( submodels.pop("unet") for submodel_name, submodel_path in submodels.items(): if submodel_path is not None and submodel_path.is_file(): - submodels[submodel_name] = NeuronBaseModel.load_model(submodel_path) + submodels[submodel_name] = NeuronTracedModel.load_model(submodel_path) else: submodels[submodel_name] = None submodels["unet"] = torch_neuronx.DataParallel( @@ -331,7 +331,7 @@ def load_model( logger.info("Loading the pipeline without any data parallelism...") for submodel_name, submodel_path in submodels.items(): if submodel_path is not None and submodel_path.is_file(): - submodels[submodel_name] = NeuronBaseModel.load_model(submodel_path) + submodels[submodel_name] = NeuronTracedModel.load_model(submodel_path) else: raise ValueError("You need to pass `data_parallel_mode` to define Neuron Core allocation.") @@ -820,7 +820,7 @@ class _NeuronDiffusionModelPart: def __init__( self, model: torch.jit._script.ScriptModule, - parent_model: NeuronBaseModel, + parent_model: NeuronTracedModel, config: Optional[Union[DiffusersPretrainedConfig, PretrainedConfig]] = None, neuron_config: Optional["NeuronDefaultConfig"] = None, model_type: str = "unet", @@ -845,7 +845,7 @@ class NeuronModelTextEncoder(_NeuronDiffusionModelPart): def __init__( self, model: torch.jit._script.ScriptModule, - parent_model: NeuronBaseModel, + parent_model: NeuronTracedModel, config: Optional[DiffusersPretrainedConfig] = None, neuron_config: Optional[Dict[str, str]] = None, ): 
@@ -882,7 +882,7 @@ class NeuronModelUnet(_NeuronDiffusionModelPart): def __init__( self, model: torch.jit._script.ScriptModule, - parent_model: NeuronBaseModel, + parent_model: NeuronTracedModel, config: Optional[DiffusersPretrainedConfig] = None, neuron_config: Optional[Dict[str, str]] = None, ): @@ -918,7 +918,7 @@ class NeuronModelVaeEncoder(_NeuronDiffusionModelPart): def __init__( self, model: torch.jit._script.ScriptModule, - parent_model: NeuronBaseModel, + parent_model: NeuronTracedModel, config: Optional[DiffusersPretrainedConfig] = None, neuron_config: Optional[Dict[str, str]] = None, ): @@ -935,7 +935,7 @@ class NeuronModelVaeDecoder(_NeuronDiffusionModelPart): def __init__( self, model: torch.jit._script.ScriptModule, - parent_model: NeuronBaseModel, + parent_model: NeuronTracedModel, config: Optional[DiffusersPretrainedConfig] = None, neuron_config: Optional[Dict[str, str]] = None, ): diff --git a/optimum/neuron/modeling_seq2seq.py b/optimum/neuron/modeling_seq2seq.py index c212a3848..1332b44c8 100644 --- a/optimum/neuron/modeling_seq2seq.py +++ b/optimum/neuron/modeling_seq2seq.py @@ -37,7 +37,7 @@ from ..exporters.tasks import TasksManager from ..utils.save_utils import maybe_load_preprocessors from .generation import NeuronGenerationMixin -from .modeling_base import NeuronBaseModel +from .modeling_traced import NeuronTracedModel from .utils import ( DECODER_NAME, ENCODER_NAME, @@ -55,7 +55,7 @@ logger = logging.getLogger(__name__) -class NeuronModelForConditionalGeneration(NeuronBaseModel, ABC): +class NeuronModelForConditionalGeneration(NeuronTracedModel, ABC): base_model_prefix = "neuron_model" config_name = "config.json" @@ -531,7 +531,7 @@ class _NeuronSeq2SeqModelPart: def __init__( self, model: torch.jit._script.ScriptModule, - parent_model: NeuronBaseModel, + parent_model: NeuronTracedModel, config: Optional["PretrainedConfig"] = None, neuron_config: Optional["NeuronDefaultConfig"] = None, model_type: str = "encoder", @@ -562,7 +562,7 @@ class NeuronEncoder(_NeuronSeq2SeqModelPart): def __init__( self, model: torch.jit._script.ScriptModule, - parent_model: NeuronBaseModel, + parent_model: NeuronTracedModel, config: Optional["PretrainedConfig"] = None, neuron_config: Optional[Dict[str, str]] = None, ): @@ -585,7 +585,7 @@ class NeuronDecoder(_NeuronSeq2SeqModelPart): def __init__( self, model: torch.jit._script.ScriptModule, - parent_model: NeuronBaseModel, + parent_model: NeuronTracedModel, config: Optional["PretrainedConfig"] = None, neuron_config: Optional[Dict[str, str]] = None, ): diff --git a/optimum/neuron/modeling_traced.py b/optimum/neuron/modeling_traced.py new file mode 100644 index 000000000..9149bae98 --- /dev/null +++ b/optimum/neuron/modeling_traced.py @@ -0,0 +1,611 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""NeuronTracedModel base classe for inference on neuron devices using the same API as Transformers.""" + +import logging +import os +import shutil +from contextlib import contextmanager +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union + +import torch +from huggingface_hub import HfApi, HfFolder, hf_hub_download +from huggingface_hub.utils import is_google_colab +from transformers import AutoConfig, AutoModel + +from ..exporters.neuron import main_export +from ..exporters.neuron.model_configs import * # noqa: F403 +from ..exporters.tasks import TasksManager +from .modeling_base import NeuronModel +from .utils import ( + NEURON_FILE_NAME, + check_if_weights_replacable, + is_neuron_available, + replace_weights, + store_compilation_config, +) +from .utils.hub_neuronx_cache import ModelCacheEntry, build_cache_config, create_hub_compile_cache_proxy +from .utils.import_utils import is_neuronx_available +from .utils.misc import maybe_load_preprocessors +from .utils.version_utils import check_compiler_compatibility, get_neuroncc_version, get_neuronxcc_version + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ..exporters.neuron import NeuronDefaultConfig + +if is_neuron_available(): + NEURON_COMPILER_TYPE = "neuron-cc" + NEURON_COMPILER_VERSION = get_neuroncc_version() + +if is_neuronx_available(): + NEURON_COMPILER_TYPE = "neuronx-cc" + NEURON_COMPILER_VERSION = get_neuronxcc_version() + +logger = logging.getLogger(__name__) + + +class NeuronTracedModel(NeuronModel): + """ + Base class running compiled and optimized models on Neuron devices. + + It implements generic methods for interacting with the Hugging Face Hub as well as compiling vanilla + transformers models to neuron-optimized TorchScript module and export it using `optimum.exporters.neuron` toolchain. + + Class attributes: + - model_type (`str`, *optional*, defaults to `"neuron_model"`) -- The name of the model type to use when + registering the NeuronTracedModel classes. + - auto_model_class (`Type`, *optional*, defaults to `AutoModel`) -- The `AutoModel` class to be represented by the + current NeuronTracedModel class. + + Common attributes: + - model (`torch.jit._script.ScriptModule`) -- The loaded `ScriptModule` compiled for neuron devices. + - config ([`~transformers.PretrainedConfig`]) -- The configuration of the model. + - model_save_dir (`Path`) -- The directory where a neuron compiled model is saved. + By default, if the loaded model is local, the directory where the original model will be used. Otherwise, the + cache directory will be used. 
+ """ + + model_type = "neuron_model" + auto_model_class = AutoModel + library_name = "transformers" + + def __init__( + self, + model: torch.jit._script.ScriptModule, + config: "PretrainedConfig", + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + model_file_name: Optional[str] = None, + preprocessors: Optional[List] = None, + neuron_config: Optional["NeuronDefaultConfig"] = None, + **kwargs, + ): + super().__init__(model, config) + + self.model = model + self.model_file_name = model_file_name or NEURON_FILE_NAME + self.config = config + self.neuron_config = self._neuron_config_init(self.config) if neuron_config is None else neuron_config + self.input_static_shapes = NeuronTracedModel.get_input_static_shapes(self.neuron_config) + self._attributes_init(model_save_dir, preprocessors, **kwargs) + + @staticmethod + def load_model(path: Union[str, Path]) -> torch.jit._script.ScriptModule: + """ + Loads a TorchScript module compiled by neuron(x)-cc compiler. It will be first loaded onto CPU and then moved to + one or multiple [NeuronCore](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/neuroncores-arch.html). + + Args: + path (`Union[str, Path]`): + Path of the compiled model. + """ + if not isinstance(path, Path): + path = Path(path) + + if path.is_file(): + model = torch.jit.load(path) + return model + + def replace_weights(self, weights: Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]] = None): + check_if_weights_replacable(self.config, weights) + if weights is not None: + replace_weights(self.model, weights) + + def _save_pretrained(self, save_directory: Union[str, Path]): + """ + Saves a model and its configuration file to a directory, so that it can be re-loaded using the + [`~optimum.neuron.modeling_traced.NeuronTracedModel.from_pretrained`] class method. + + Args: + save_directory (`Union[str, Path]`): + Directory where to save the model file. + """ + src_path = self.model_save_dir / self.model_file_name + dst_path = Path(save_directory) / self.model_file_name + + shutil.copyfile(src_path, dst_path) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: "PretrainedConfig", + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + subfolder: str = "", + local_files_only: bool = False, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + neuron_config: Optional["NeuronDefaultConfig"] = None, + **kwargs, + ) -> "NeuronTracedModel": + model_path = Path(model_id) + + if file_name is None: + if model_path.is_dir(): + neuron_files = list(model_path.glob("*.neuron")) + else: + if isinstance(use_auth_token, bool): + token = HfFolder().get_token() + else: + token = use_auth_token + repo_files = map(Path, HfApi().list_repo_files(model_id, revision=revision, token=token)) + pattern = "*.neuron" if subfolder == "" else f"{subfolder}/*.neuron" + neuron_files = [p for p in repo_files if p.match(pattern)] + + if len(neuron_files) == 0: + raise FileNotFoundError(f"Could not find any neuron model file in {model_path}") + elif len(neuron_files) > 1: + raise RuntimeError( + f"Too many neuron model files were found in {model_path}, specify which one to load by using the " + "file_name argument." + ) + else: + file_name = neuron_files[0].name + + # Check compiler compatibility(compiler type and version) of the saved model vs. system. 
+ if hasattr(config, "neuron") and "compiler_type" in config.neuron: + model_compiler_type = config.neuron.get("compiler_type") + model_compiler_version = config.neuron.get("compiler_version") + check_compiler_compatibility(model_compiler_type, model_compiler_version) + + preprocessors = None + if model_path.is_dir(): + model = NeuronTracedModel.load_model(model_path / file_name) + new_model_save_dir = model_path + else: + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=file_name, + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + + model = NeuronTracedModel.load_model(model_cache_path) + new_model_save_dir = Path(model_cache_path).parent + + preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) + + # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it + # instead of the path only. + if model_save_dir is None: + model_save_dir = new_model_save_dir + + return cls( + model=model, + config=config, + model_save_dir=model_save_dir, + model_file_name=file_name, + preprocessors=preprocessors, + neuron_config=neuron_config, + ) + + @classmethod + def _from_transformers(cls, *args, **kwargs): + # Deprecate it when optimum uses `_export` as from_pretrained_method in a stable release. + return cls._export(*args, **kwargs) + + @classmethod + def _export( + cls, + model_id: str, + config: "PretrainedConfig", + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + library_name: Optional[str] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + compiler_workdir: Optional[Union[str, Path]] = None, + disable_neuron_cache: bool = False, + inline_weights_to_neff: bool = True, + optlevel: str = "2", + subfolder: str = "", + local_files_only: bool = False, + trust_remote_code: bool = False, + task: Optional[str] = None, + auto_cast: Optional[str] = None, + auto_cast_type: Optional[str] = None, + disable_fast_relayout: Optional[bool] = False, + disable_fallback: bool = False, + dynamic_batch_size: bool = False, + **kwargs_shapes, + ) -> "NeuronTracedModel": + """ + Exports a vanilla Transformers model into a neuron-compiled TorchScript Module using `optimum.exporters.neuron.export`. + + Args: + kwargs_shapes (`Dict[str, int]`): + Shapes to use during inference. This argument allows to override the default shapes used during the export. 
+ """ + if task is None: + task = TasksManager.infer_task_from_model(cls.auto_model_class) + task = TasksManager.map_from_synonym(task) + library_name = TasksManager.infer_library_from_model(model_id, subfolder=subfolder, library_name=library_name) + + # Get compilation arguments + if is_neuron_available() and dynamic_batch_size is True and "batch_size" in kwargs_shapes: + kwargs_shapes["batch_size"] = 1 + disable_fallback = True # Turn off the fallback for neuron, otherwise dynamic batching will still fail + auto_cast_type = None if auto_cast is None else auto_cast_type + compiler_kwargs = { + "auto_cast": auto_cast, + "auto_cast_type": auto_cast_type, + "disable_fast_relayout": disable_fast_relayout, + "disable_fallback": disable_fallback, + } + + if ( + not inline_weights_to_neff and not disable_neuron_cache and is_neuronx_available() + ): # TODO: support caching of Inf1 as well + # Check if the cache exists + compilation_config = store_compilation_config( + config=config, + input_shapes=kwargs_shapes, + compiler_kwargs=compiler_kwargs, + dynamic_batch_size=dynamic_batch_size, + compiler_type=NEURON_COMPILER_TYPE, + compiler_version=NEURON_COMPILER_VERSION, + inline_weights_to_neff=inline_weights_to_neff, + optlevel=optlevel, + model_type=getattr(config, "model_type", None), + task=task, + ) + cache_config = build_cache_config(compilation_config) + cache_entry = ModelCacheEntry(model_id=model_id, config=cache_config) + compile_cache = create_hub_compile_cache_proxy() + model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(f"MODULE_{cache_entry.hash}") + cache_available = compile_cache.download_folder(model_cache_dir, model_cache_dir) + else: + cache_available = False + + # load cache + if cache_available: + try: + neuron_model = cls.from_pretrained(model_cache_dir) + model = TasksManager.get_model_from_task( + task=task, + model_name_or_path=model_id, + subfolder=subfolder, + revision=revision, + framework="pt", + library_name=library_name, + cache_dir=cache_dir, + use_auth_token=use_auth_token, + local_files_only=local_files_only, + force_download=force_download, + trust_remote_code=trust_remote_code, + ) + # replace weights + neuron_model.replace_weights(weights=model) + return neuron_model + except Exception as e: + logger.warning( + f"Found the cached artifacts but failed to re-load them with error: {e}. \n Falling back to recompilation." 
+ ) + cache_available = False + + # compile + if not cache_available: + # compile + save_dir = TemporaryDirectory() + save_dir_path = Path(save_dir.name) + main_export( + model_name_or_path=model_id, + output=save_dir_path, + compiler_kwargs=compiler_kwargs, + task=task, + dynamic_batch_size=dynamic_batch_size, + cache_dir=cache_dir, + disable_neuron_cache=disable_neuron_cache, + compiler_workdir=compiler_workdir, + inline_weights_to_neff=inline_weights_to_neff, + optlevel=optlevel, + trust_remote_code=trust_remote_code, + subfolder=subfolder, + revision=revision, + force_download=force_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + do_validation=False, + library_name=library_name, + **kwargs_shapes, + ) + config = AutoConfig.from_pretrained(save_dir_path) + + return cls._from_pretrained(save_dir_path, config, model_save_dir=save_dir) + + def push_to_hub( + self, + save_directory: str, + repository_id: str, + private: Optional[bool] = None, + revision: Optional[str] = None, + use_auth_token: Union[bool, str] = True, + endpoint: Optional[str] = None, + ) -> str: + if isinstance(use_auth_token, str): + huggingface_token = use_auth_token + elif use_auth_token: + huggingface_token = HfFolder.get_token() + else: + raise ValueError("You need to provide `use_auth_token` to be able to push to the hub") + api = HfApi(endpoint=endpoint) + + user = api.whoami(huggingface_token) + if is_google_colab(): + # Only in Google Colab to avoid the warning message + self.git_config_username_and_email(git_email=user["email"], git_user=user["fullname"]) + + api.create_repo( + token=huggingface_token, + repo_id=repository_id, + exist_ok=True, + private=private, + ) + for path, subdirs, files in os.walk(save_directory): + for name in files: + local_file_path = os.path.join(path, name) + hub_file_path = os.path.relpath(local_file_path, save_directory) + api.upload_file( + token=huggingface_token, + repo_id=repository_id, + path_or_fileobj=os.path.join(os.getcwd(), local_file_path), + path_in_repo=hub_file_path, + revision=revision, + ) + + def forward(self, *args, **kwargs): + raise NotImplementedError + + def _attributes_init( + self, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + preprocessors: Optional[List] = None, + **kwargs, + ): + """ + Initializes attributes. + """ + self._path_tempdirectory_instance = None + if isinstance(model_save_dir, TemporaryDirectory): + self._path_tempdirectory_instance = model_save_dir + self.model_save_dir = Path(model_save_dir.name) + elif isinstance(model_save_dir, str): + self.model_save_dir = Path(model_save_dir) + else: + self.model_save_dir = model_save_dir + + self.preprocessors = preprocessors if preprocessors is not None else [] + + # Registers the NeuronModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating + # a pipeline https://github.com/huggingface/transformers/blob/3d3204c025b6b5de013e07dd364208e28b4d9589/src/transformers/pipelines/base.py#L940 + AutoConfig.register(self.model_type, AutoConfig) + if hasattr(self.auto_model_class, "register"): + self.auto_model_class.register(AutoConfig, self.__class__) + + @classmethod + def _neuron_config_init(cls, config: "PretrainedConfig") -> "NeuronDefaultConfig": + """ + Builds a `NeuronDefaultConfig` with an instance of the `PretrainedConfig` and the task. 
+        """
+        if not hasattr(config, "neuron"):
+            raise ValueError(
+                "Unable to identify neuron configuration with the keyword `neuron`, make sure that your config file contains the necessary information"
+            )
+
+        neuron_config = config.neuron
+        # Fetch compiler information
+        compiler_type = neuron_config.get("compiler_type")
+        compiler_version = neuron_config.get("compiler_version")
+
+        # Fetch mandatory shapes from config
+        compile_shapes = {
+            key.replace("static_", ""): value
+            for (key, value) in config.to_diff_dict().get("neuron").items()
+            if key.startswith("static_")
+        }
+
+        # Neuron config constructor
+        task = getattr(config, "task") or TasksManager.infer_task_from_model(cls.auto_model_class)
+        task = TasksManager.map_from_synonym(task)
+        model_type = neuron_config.get("model_type", None) or config.model_type
+        neuron_config_constructor = TasksManager.get_exporter_config_constructor(
+            model_type=model_type,
+            exporter="neuron",
+            task=task,
+            library_name=cls.library_name,
+        )
+
+        return neuron_config_constructor(
+            config,
+            dynamic_batch_size=neuron_config.get("dynamic_batch_size", False),
+            compiler_type=compiler_type,
+            compiler_version=compiler_version,
+            **compile_shapes,
+        )
+
+    @classmethod
+    def get_input_static_shapes(cls, neuron_config: "NeuronDefaultConfig") -> Dict[str, int]:
+        """
+        Gets a dictionary of inputs with their valid static shapes.
+        """
+        axes = neuron_config._axes
+        input_static_shapes = {
+            name: value.shape
+            for name, value in neuron_config.generate_dummy_inputs(return_tuple=False, **axes).items()
+        }
+        return input_static_shapes
+
+    def _validate_static_shape(self, input_shapes: List[int], target_shapes: List[int]) -> bool:
+        """
+        Checks if an input needs to be padded.
+        """
+        if self.neuron_config.dynamic_batch_size is True:
+            batch_size_check = input_shapes[0] % target_shapes[0] == 0
+            other_check = input_shapes[1:] == target_shapes[1:] if len(input_shapes) > 1 else True
+            return batch_size_check and other_check
+        else:
+            return input_shapes == target_shapes
+
+    def _raise_if_invalid_padding(self, input_name, input_tensor, target_shapes, to_pad, dim):
+        if to_pad < 0:
+            extra = ", unless you set `dynamic_batch_size=True` during the compilation" if dim == 0 else ""
+            raise ValueError(
+                f"Unable to pad {input_name} with shape: {input_tensor.shape} on dimension {dim} as input shapes must not be larger"
+                f" than the static shapes used for compilation: {target_shapes}{extra}."
+            )
+
+    def _pad_to_compiled_shape(
+        self, inputs: Dict[str, "torch.Tensor"], padding_side: Literal["right", "left"] = "right"
+    ):
+        """
+        Pads input tensors if they do not have valid static shapes.
+
+        Args:
+            inputs (`Dict[str, "torch.Tensor"]`):
+                Dictionary of input torch tensors.
+            padding_side (`Literal["right", "left"]`, defaults to "right"):
+                The side on which to apply the padding.
+ """ + logger.info(f"Padding input tensors, the padding side is: {padding_side}.") + for input_name, input_tensor in inputs.items(): + target_shapes = self.input_static_shapes[input_name] + padding = () + if self._validate_static_shape(input_tensor.shape, target_shapes): + continue + + # Dimensions other than 0 + for i in reversed(range(1, input_tensor.dim())): + to_pad = target_shapes[i] - input_tensor.size(i) + + self._raise_if_invalid_padding(input_name, input_tensor, target_shapes, to_pad, i) + padding += (0, to_pad) if padding_side == "right" else (to_pad, 0) + + if ( + self.preprocessors is not None + and len(self.preprocessors) > 0 + and self.preprocessors[0].pad_token_id is not None + and input_name == "input_ids" + ): + pad_id = self.preprocessors[0].pad_token_id + else: + pad_id = 0 + + input_tensor = torch.nn.functional.pad(input_tensor, padding, mode="constant", value=pad_id) + + # Pad to batch size: dimension 0 (pad_token_id can't be 0) + padding = (0,) * len(padding) + is_encoder_decoder = getattr(self.config, "is_encoder_decoder", False) + if ( + not is_encoder_decoder + and self.neuron_config.dynamic_batch_size is True + and input_tensor.size(0) % target_shapes[0] == 0 + ): + inputs[input_name] = input_tensor + continue + elif not is_encoder_decoder and self.neuron_config.dynamic_batch_size is True: + target_shape = (input_tensor.size(0) // target_shapes[0] + 1) * target_shapes[0] + to_pad = target_shape - input_tensor.size(0) + else: + to_pad = target_shapes[0] - input_tensor.size(0) + self._raise_if_invalid_padding(input_name, input_tensor, target_shapes, to_pad, 0) + padding += (0, to_pad) if padding_side == "right" else (to_pad, 0) + + pad_id = 1 + inputs[input_name] = torch.nn.functional.pad(input_tensor, padding, mode="constant", value=pad_id) + + return inputs + + @contextmanager + def neuron_padding_manager(self, inputs: Dict[str, "torch.Tensor"]): + inputs = tuple(self._pad_to_compiled_shape(inputs).values()) + yield inputs + + @staticmethod + def remove_padding( + outputs: List[torch.Tensor], + dims: List[int], + indices: List[int], + padding_side: Literal["right", "left"] = "right", + ) -> List[torch.Tensor]: + """ + Removes padding from output tensors. + + Args: + outputs (`List[torch.Tensor]`): + List of torch tensors which are inference output. + dims (`List[int]`): + List of dimensions in which we slice a tensor. + indices (`List[int]`): + List of indices in which we slice a tensor along an axis. + padding_side (`Literal["right", "left"]`, defaults to "right"): + The side on which the padding has been applied. + """ + if len(dims) != len(indices): + raise ValueError(f"The size of `dims`({len(dims)}) and indices`({len(indices)}) must be equal.") + + for dim, indice in zip(dims, indices): + if padding_side == "right": + outputs = [ + torch.index_select(output_tensor, dim, torch.LongTensor(range(indice))) + for output_tensor in outputs + ] + elif padding_side == "left": + outputs = [ + torch.index_select( + output_tensor, + dim, + torch.LongTensor(range(output_tensor.shape[dim] - indice, output_tensor.shape[dim])), + ) + for output_tensor in outputs + ] + + return outputs + + @property + def is_weights_neff_separated(self) -> bool: + """ + Whether the Neuron model has separated weights and neff graph (by setting `inline_weights_to_neff=False` during the compilation). 
+ """ + return not self.config.neuron.get("inline_weights_to_neff", True) diff --git a/optimum/neuron/pipelines/transformers/base.py b/optimum/neuron/pipelines/transformers/base.py index a9be42b40..a91baae2b 100644 --- a/optimum/neuron/pipelines/transformers/base.py +++ b/optimum/neuron/pipelines/transformers/base.py @@ -34,8 +34,7 @@ from transformers.feature_extraction_utils import PreTrainedFeatureExtractor from transformers.onnx.utils import get_preprocessor -from optimum.modeling_base import OptimizedModel -from optimum.neuron.modeling_base import NeuronBaseModel +from optimum.neuron.modeling_base import NeuronModel from optimum.neuron.pipelines.transformers.sentence_transformers import ( FeatureExtractionPipeline, is_sentence_transformer_model, @@ -134,7 +133,7 @@ def load_pipeline( model, export=export, **compiler_args, **input_shapes, **hub_kwargs, **kwargs ) # uses neuron model - elif isinstance(model, (NeuronBaseModel, NeuronModelForCausalLM)): + elif isinstance(model, NeuronModel): if tokenizer is None and load_tokenizer: for preprocessor in model.preprocessors: if isinstance(preprocessor, (PreTrainedTokenizer, PreTrainedTokenizerFast)): @@ -142,7 +141,7 @@ def load_pipeline( break if tokenizer is None: raise ValueError( - "Could not automatically find a tokenizer for the NeuronBaseModel, you must pass a tokenizer explicitly" + "Could not automatically find a tokenizer for the NeuronModel, you must pass a tokenizer explicitly" ) if feature_extractor is None and load_feature_extractor: for preprocessor in model.preprocessors: @@ -165,7 +164,7 @@ def load_pipeline( def pipeline( task: str = None, - model: Optional[Union[str, NeuronBaseModel]] = None, + model: Optional[Union[str, NeuronModel]] = None, tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, use_fast: bool = True, @@ -195,7 +194,7 @@ def pipeline( if isinstance(model, str): config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **kwargs) hub_kwargs["_commit_hash"] = config._commit_hash - elif isinstance(model, (PreTrainedModel, OptimizedModel)): + elif isinstance(model, (PreTrainedModel, NeuronModel)): config = model.config if export: @@ -279,5 +278,6 @@ def pipeline( use_fast=use_fast, batch_size=batch_size, pipeline_class=NEURONX_SUPPORTED_TASKS[task]["impl"], + device=model.device, **kwargs, ) diff --git a/tests/inference/inference_utils.py b/tests/inference/inference_utils.py index 57cca94d5..bd68d67f8 100644 --- a/tests/inference/inference_utils.py +++ b/tests/inference/inference_utils.py @@ -110,7 +110,7 @@ def _setup(self, model_args: Dict): dynamic_batch_size = model_args.get("dynamic_batch_size", False) if model_arch_and_params not in self.neuron_model_dirs: - # model_args will contain kwargs to pass to NeuronBaseModel.from_pretrained() + # model_args will contain kwargs to pass to NeuronTracedModel.from_pretrained() model_args.pop("test_name") model_args.pop("model_arch") model_args.pop("dynamic_batch_size", None) diff --git a/tests/inference/test_modeling.py b/tests/inference/test_modeling.py index fab907d7d..cfdaab941 100644 --- a/tests/inference/test_modeling.py +++ b/tests/inference/test_modeling.py @@ -38,7 +38,6 @@ from transformers.onnx.utils import get_preprocessor from optimum.neuron import ( - NeuronBaseModel, NeuronModelForFeatureExtraction, NeuronModelForMaskedLM, NeuronModelForMultipleChoice, @@ -46,6 +45,7 @@ NeuronModelForSentenceTransformers, NeuronModelForSequenceClassification, 
NeuronModelForTokenClassification, + NeuronTracedModel, pipeline, ) from optimum.neuron.utils import NEURON_FILE_NAME, is_neuron_available, is_neuronx_available @@ -79,12 +79,12 @@ class NeuronModelIntegrationTest(NeuronModelIntegrationTestMixin): TINY_MODEL_REMOTE = "Jingya/tiny-random-bert-remote-code" def test_load_model_from_local_path(self): - model = NeuronBaseModel.from_pretrained(self.local_model_path) + model = NeuronTracedModel.from_pretrained(self.local_model_path) self.assertIsInstance(model.model, torch.jit._script.ScriptModule) self.assertIsInstance(model.config, PretrainedConfig) def test_load_model_from_hub(self): - model = NeuronBaseModel.from_pretrained(self.neuron_model_id) + model = NeuronTracedModel.from_pretrained(self.neuron_model_id) self.assertIsInstance(model.model, torch.jit._script.ScriptModule) self.assertIsInstance(model.config, PretrainedConfig) @@ -96,9 +96,9 @@ def test_load_model_from_hub_subfolder(self): self.assertIsInstance(model.config, PretrainedConfig) def test_load_model_from_cache(self): - _ = NeuronBaseModel.from_pretrained(self.neuron_model_id) # caching + _ = NeuronTracedModel.from_pretrained(self.neuron_model_id) # caching - model = NeuronBaseModel.from_pretrained(self.neuron_model_id, local_files_only=True) + model = NeuronTracedModel.from_pretrained(self.neuron_model_id, local_files_only=True) self.assertIsInstance(model.model, torch.jit._script.ScriptModule) self.assertIsInstance(model.config, PretrainedConfig) @@ -109,15 +109,15 @@ def test_load_model_from_empty_cache(self): if os.path.exists(dirpath) and os.path.isdir(dirpath): shutil.rmtree(dirpath) with self.assertRaises(Exception): - _ = NeuronBaseModel.from_pretrained(self.neuron_model_id, local_files_only=True) + _ = NeuronTracedModel.from_pretrained(self.neuron_model_id, local_files_only=True) def test_load_model_from_hub_without_neuron_model(self): with self.assertRaises(FileNotFoundError): - NeuronBaseModel.from_pretrained(self.FAIL_NEURON_MODEL_ID) + NeuronTracedModel.from_pretrained(self.FAIL_NEURON_MODEL_ID) def test_save_model(self): with tempfile.TemporaryDirectory() as tmpdirname: - model = NeuronBaseModel.from_pretrained(self.local_model_path) + model = NeuronTracedModel.from_pretrained(self.local_model_path) model.save_pretrained(tmpdirname) # folder contains all config files and neuron exported model folder_contents = os.listdir(tmpdirname) diff --git a/tests/pipelines/test_encoder_pipelines.py b/tests/pipelines/test_encoder_pipelines.py index 3ec69a35d..5b2a40076 100644 --- a/tests/pipelines/test_encoder_pipelines.py +++ b/tests/pipelines/test_encoder_pipelines.py @@ -1,4 +1,4 @@ -from optimum.neuron import NeuronBaseModel +from optimum.neuron import NeuronTracedModel from optimum.neuron.pipelines import pipeline from optimum.neuron.utils.testing_utils import is_inferentia_test @@ -7,4 +7,4 @@ def test_export_no_parameters(std_text_task, inf_encoder_model): p = pipeline(std_text_task, inf_encoder_model, export=True) assert p.task == std_text_task - assert isinstance(p.model, NeuronBaseModel) + assert isinstance(p.model, NeuronTracedModel)
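
Usage sketch (a minimal illustration, assuming a placeholder checkpoint and the existing `from_pretrained(..., export=True)` export API): the rename from `NeuronBaseModel` to `NeuronTracedModel` and the new `NeuronModel` base class leave the task-level API unchanged.

    from transformers import AutoTokenizer

    from optimum.neuron import NeuronModelForSequenceClassification, NeuronTracedModel

    # Compile a vanilla checkpoint into a Neuron TorchScript module with static shapes;
    # `export=True` goes through the export path implemented by the traced base class.
    model = NeuronModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased-finetuned-sst-2-english",  # placeholder checkpoint
        export=True,
        batch_size=1,
        sequence_length=128,
    )
    assert isinstance(model, NeuronTracedModel)  # task classes now derive from the renamed base

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
    inputs = tokenizer("Neuron models keep the Transformers API.", return_tensors="pt")
    logits = model(**inputs).logits

    # Saving and reloading the compiled artifact also goes through the renamed base class,
    # mirroring the updated tests in tests/inference/test_modeling.py.
    model.save_pretrained("./distilbert_neuron")
    reloaded = NeuronTracedModel.from_pretrained("./distilbert_neuron")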