From c114fc83f72983647415b441433480ef9c0d9f92 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Tue, 30 Jan 2024 17:03:56 +0100 Subject: [PATCH] TGI: export model if configuration is cached (#445) * feat(cache): use one registry per optimum version * feat(registry): use model_type as primary key This allows to identify cached configurations that can be applied to models that differ only by their weights, like meta-llama/Llama-2-7b-hf and meta-llama/Llama-2-7b-chat-hf. This also allows to lookup cached configurations for local model folders containing a model config. * doc(cache): fix image link * doc(cache): add cache lookup * refactor(decoder): add get_export_config helper * feat(tgi): export model if cached * review: addressing code comments * wip * review: address doc comments --- docs/source/benchmarks/inferentia-llama2.mdx | 2 +- docs/source/guides/cache_system.mdx | 108 +++++++++++++++--- optimum/neuron/modeling_decoder.py | 90 ++++++++++----- optimum/neuron/utils/__init__.py | 2 +- optimum/neuron/utils/hub_neuronx_cache.py | 72 ++++++++---- tests/cache/test_neuronx_cache.py | 3 +- text-generation-inference/README.md | 88 +++++++++++--- .../server/text_generation_server/cli.py | 4 +- .../text_generation_server/generator.py | 21 ++-- .../server/text_generation_server/model.py | 100 ++++++++++++++-- .../server/text_generation_server/server.py | 11 +- 11 files changed, 386 insertions(+), 115 deletions(-) diff --git a/docs/source/benchmarks/inferentia-llama2.mdx b/docs/source/benchmarks/inferentia-llama2.mdx index c7c6d7fbe..1601ee2a9 100644 --- a/docs/source/benchmarks/inferentia-llama2.mdx +++ b/docs/source/benchmarks/inferentia-llama2.mdx @@ -48,7 +48,7 @@ while 768 is more typical of a Retrieval Augmented Generation (RAG) use-case. Encoding time is expressed in **seconds**. -![Llama2 inferentia2 encoding-time](https://raw.githubusercontent.com/huggingface/optimum-neuron/main/docs/assets/benchmarks/inferentia-llama2/encoding-times.png "Encoding time") +![Llama2 inferentia2 encoding-time](https://raw.githubusercontent.com/huggingface/optimum-neuron/main/docs/assets/benchmarks/inferentia-llama2/encoding_times.png "Encoding time") We can see that all deployed models exhibit excellent response times, even for long contexts. diff --git a/docs/source/guides/cache_system.mdx b/docs/source/guides/cache_system.mdx index 84ed2dff3..087fac390 100644 --- a/docs/source/guides/cache_system.mdx +++ b/docs/source/guides/cache_system.mdx @@ -13,35 +13,111 @@ specific language governing permissions and limitations under the License. # Neuron Model Cache The Neuron Model Cache is a remote cache for compiled Neuron models in the `neff` format. -It is integrated into the [`NeuronTrainer` and `NeuronModelForCausalLM`] classes to enable loading pretrained models from the cache instead of compiling them locally. +It is integrated into the `NeuronTrainer` and `NeuronModelForCausalLM` classes to enable loading pretrained models from the cache instead of compiling them locally. + +**Note: it is not available for models exported using any other NeuronModelXX classes, that use a different export mechanism.** The Neuron Model Cache is hosted on the [Hugging Face Hub](https://huggingface.co/aws-neuron/optimum-neuron-cache) and includes compiled files for all popular and supported `optimum-neuron` pre-trained models. 
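+
+For illustration, the minimal sketch below shows how the cache is used transparently when exporting a model with
+`NeuronModelForCausalLM` (the checkpoint and export parameters are only examples): any cached NEFF files matching the
+export configuration are fetched from the hub instead of being recompiled locally.
+
+```python
+from optimum.neuron import NeuronModelForCausalLM
+
+# Export a supported checkpoint: compiled NEFF files found in the remote cache
+# are downloaded and reused instead of being recompiled locally (when available).
+model = NeuronModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-2-7b-chat-hf",  # example checkpoint
+    export=True,
+    batch_size=1,
+    sequence_length=2048,
+    num_cores=24,
+    auto_cast_type="fp16",
+)
+```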
-When loading a Transformers or Diffusion model, it needs to be compiled to neuron format with [`torch-neuronx`](https://github.com/aws-neuron/aws-neuron-samples/tree/master/torch-neuronx), -in order to run on Neuron platforms. -The compilation produces several compilation files stored in a local directory, usually `/var/tmp/neuron-compile-cache`. -This means that every time you train or export a model on a new host, you need to recompile it, which takes a lot of time. +Before training a Transformers or Diffusion model or loading a NeuronModelForCausalLM on Neuron platforms, it needs to be exported to neuron format +with [`torch-neuronx`](https://github.com/aws-neuron/aws-neuron-samples/tree/master/torch-neuronx). + +When exporting a model, [`torch-neuronx`](https://github.com/aws-neuron/aws-neuron-samples/tree/master/torch-neuronx) will: + +- convert it to a set of [XLA](https://github.com/pytorch/xla/) subgraphs, +- compile each subgraph with the neuronx compiler into a Neuron Executable File Format (NEFF) binary file. + +The first step is relatively fast, but the compilation takes a lot of time. +To avoid recompiling all NEFF files every time a model is loaded on a NeuronX host, [`torch-neuronx`](https://github.com/aws-neuron/aws-neuron-samples/tree/master/torch-neuronx) + stores NEFF files in a local directory, usually `/var/tmp/neuron-compile-cache`. + +However, this local cache is not shared between platforms, which means that every time you train or export a model on a new host, you need to recompile it. + +We created the Neuron Model Cache to solve this limitation by providing a public repository of precompiled model graphs. + +Note: we also support the creation of private, secured, remote model cache. -We created the Neuron Model Cache to solve this limitation by providing a public cache of precompiled available models and a private cache to create your private, secured, remote model cache. +## How to use the Neuron model cache -## How the caching system works +The public model cache will be used when you use the `NeuronTrainer` or `NeuronModelForCausalLM` classes. There are no additional changes needed. -### Hash computation +When exporting a model to neuron format, `optimum-neuron` will simply look for cached NEFF files in the hub repository during the compilation of the +model subgraphs. -Many factors can trigger compilation among which: +If the NEFF files are cached, they will be fetched from the hub and directly loaded instead of being recompiled. -- The input shapes, -- The precision of the model, full-precision or bf16, +## How caching works + +The Optimum Neuron Cache is built on top of the [NeuronX compiler cache](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-features/neuron-caching.html). + +It is important to understand that the cache operates on NEFF binaries, and not on the model itself. + +As explained previously, each model exported to Neuron using the `NeuronTrainer` or `NeuronModelForCausalLM` is composed of [XLA](https://github.com/pytorch/xla/) subgraphs. + +Each subgraph is unique, and results from the combination of: +- the `transformers` or `transformers_neuronx` python modeling code, +- the `transformers` model config, +- the `input_shapes` selected during the export, +- The precision of the model, full-precision, fp16 or bf16. + +When compiling a subgraph to a NEFF file, other parameters influence the result: - The version of the Neuron X compiler, -- The number of Neuron cores used. 
+- The number of Neuron cores used,
+- The compilation parameters (such as the optimization level).
+
+All these parameters are combined to create a unique hash that identifies a NEFF file.

-These parameters are used to compute a hash that uniquely identifies each compilation file.
+This has two very important consequences:
+- it is only when actually exporting a model that the associated NEFF files can be identified,
+- even a small change in the model configuration will lead to a different set of NEFF files.

-**It is important to keep in mind that even a small change in the model configuration will trigger a recompilation.**
+It is therefore very difficult to know in advance whether the NEFF files associated with a specific model configuration are cached.
+
+## Neuron model cache lookup (inferentia only)
+
+The neuron cache lookup is a feature that lets users check for compatible cached model configurations before exporting
+a model for inference.
+
+It is based on a dedicated registry composed of stored cached configurations.
+
+Cached model configurations are stored as entries under a specific subfolder in the Neuron Model Cache:
+
+```
+neuronxcc-2.12.54.0+f631c2365
+├── 0_REGISTRY
+│   └── 0.0.18
+│       └── llama
+│           └── meta-llama
+│               └── Llama-2-7b-chat-hf
+│                   └── 54c1f6689cd88f246fce.json
+```
+
+Each entry corresponds to the combination of a model configuration and its export parameters: this is as close as we can get to
+uniquely identifying the exported model.
+
+You can use the `optimum-cli` to look up compatible cached entries by passing it a hub model_id or the path to a file
+containing a model `config.json`.
+
+```shell
+$ optimum-cli neuron cache lookup meta-llama/Llama-2-7b-chat-hf
+
+*** 1 entrie(s) found in cache for meta-llama/Llama-2-7b-chat-hf ***
+
+task: text-generation
+batch_size: 1
+num_cores: 24
+auto_cast_type: fp16
+sequence_length: 2048
+compiler_type: neuronx-cc
+compiler_version: 2.12.54.0+f631c2365
+checkpoint_id: meta-llama/Llama-2-7b-chat-hf
+checkpoint_revision: c1b0db933684edbfe29a06fa47eb19cc48025e93
+```

-### How to use the Neuron model cache
+**Note that even if compatible cached entries exist, this does not guarantee that the model will not be recompiled during export
+if you modify the compilation parameters or update the neuronx packages.**

-The public model cache will be used when you use the [`NeuronTrainer` or `NeuronModelForCausalLM`] classes. There are no additional changes needed.
+## Advanced usage (trainium only)

### How to use a private Neuron model cache (trainium only)
diff --git a/optimum/neuron/modeling_decoder.py b/optimum/neuron/modeling_decoder.py
index 10b2ce516..1fe713bca 100644
--- a/optimum/neuron/modeling_decoder.py
+++ b/optimum/neuron/modeling_decoder.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""Base class for text-generation model architectures on neuron devices.""" +import copy import logging import os import shutil @@ -28,7 +29,7 @@ from ..exporters.neuron.model_configs import * # noqa: F403 from ..exporters.tasks import TasksManager from ..modeling_base import OptimizedModel -from .utils import CacheEntry, hub_neuronx_cache, is_transformers_neuronx_available +from .utils import ModelCacheEntry, hub_neuronx_cache, is_transformers_neuronx_available from .utils.require_utils import requires_transformers_neuronx from .utils.version_utils import check_compiler_compatibility, get_neuronxcc_version @@ -126,7 +127,7 @@ def __init__( os.environ["NEURON_CC_FLAGS"] = neuron_cc_flags + " --model-type=transformer" checkpoint_id = neuron_config.get("checkpoint_id", None) # Only create a cache entry if the model comes from the hub - cache_entry = None if checkpoint_id is None else CacheEntry(neuron_config["checkpoint_id"], neuron_config) + cache_entry = None if checkpoint_id is None else ModelCacheEntry(checkpoint_id, config) with hub_neuronx_cache(entry=cache_entry): neuronx_model.to_neuron() os.environ["NEURON_CC_FLAGS"] = neuron_cc_flags @@ -170,14 +171,7 @@ def _create_checkpoint( return checkpoint_dir @classmethod - @requires_transformers_neuronx - def _from_transformers(cls, *args, **kwargs): - # Deprecate it when optimum uses `_export` as from_pretrained_method in a stable release. - return cls._export(*args, **kwargs) - - @classmethod - @requires_transformers_neuronx - def _export( + def get_export_config( cls, model_id: str, config: "PretrainedConfig", @@ -187,23 +181,11 @@ def _export( batch_size: Optional[int] = None, sequence_length: Optional[int] = None, num_cores: Optional[int] = None, - auto_cast_type: Optional[str] = "fp32", - **kwargs, - ) -> "NeuronDecoderModel": - if not os.path.isdir("/sys/class/neuron_device/"): - raise SystemError("Decoder models can only be exported on a neuron platform.") - + auto_cast_type: Optional[str] = None, + ) -> "PretrainedConfig": if task is None: task = TasksManager.infer_task_from_model(cls.auto_model_class) - # Instantiate the transformers model checkpoint - checkpoint_dir = cls._create_checkpoint( - model_id, - task=task, - revision=revision, - **kwargs, - ) - if os.path.isdir(model_id): checkpoint_id = None checkpoint_revision = None @@ -223,9 +205,15 @@ def _export( if num_cores is None: # Use all available cores num_cores = len(os.listdir("/sys/class/neuron_device/")) * 2 - - # Update the config - config.neuron = { + if auto_cast_type is None: + auto_cast_type = "fp32" + if config.torch_dtype == "float16": + auto_cast_type = "fp16" + elif config.torch_dtype == "bfloat16": + auto_cast_type = "bf16" + + new_config = copy.deepcopy(config) + new_config.neuron = { "task": task, "batch_size": batch_size, "num_cores": num_cores, @@ -236,6 +224,52 @@ def _export( "checkpoint_id": checkpoint_id, "checkpoint_revision": checkpoint_revision, } + return new_config + + @classmethod + @requires_transformers_neuronx + def _from_transformers(cls, *args, **kwargs): + # Deprecate it when optimum uses `_export` as from_pretrained_method in a stable release. 
+ return cls._export(*args, **kwargs) + + @classmethod + @requires_transformers_neuronx + def _export( + cls, + model_id: str, + config: "PretrainedConfig", + use_auth_token: Optional[str] = None, + revision: Optional[str] = None, + task: Optional[str] = None, + batch_size: Optional[int] = None, + sequence_length: Optional[int] = None, + num_cores: Optional[int] = None, + auto_cast_type: Optional[str] = "fp32", + **kwargs, + ) -> "NeuronDecoderModel": + if not os.path.isdir("/sys/class/neuron_device/"): + raise SystemError("Decoder models can only be exported on a neuron platform.") + + # Update the config + new_config = cls.get_export_config( + model_id, + config, + use_auth_token=use_auth_token, + revision=revision, + task=task, + batch_size=batch_size, + sequence_length=sequence_length, + num_cores=num_cores, + auto_cast_type=auto_cast_type, + ) + + # Instantiate the transformers model checkpoint + checkpoint_dir = cls._create_checkpoint( + model_id, + task=new_config.neuron["task"], + revision=revision, + **kwargs, + ) # Try to reload the generation config (if any) generation_config = None @@ -244,7 +278,7 @@ def _export( except OSError: pass - return cls(config, checkpoint_dir, generation_config=generation_config) + return cls(new_config, checkpoint_dir, generation_config=generation_config) @classmethod def _get_neuron_dirs(cls, model_path: Union[str, Path]) -> Tuple[str, str]: diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py index 59eb94b26..4bce752b7 100644 --- a/optimum/neuron/utils/__init__.py +++ b/optimum/neuron/utils/__init__.py @@ -24,7 +24,7 @@ ENCODER_NAME, NEURON_FILE_NAME, ) -from .hub_neuronx_cache import CacheEntry, get_hub_cached_entries, hub_neuronx_cache, synchronize_hub_cache +from .hub_neuronx_cache import ModelCacheEntry, get_hub_cached_entries, hub_neuronx_cache, synchronize_hub_cache from .import_utils import ( is_accelerate_available, is_neuron_available, diff --git a/optimum/neuron/utils/hub_neuronx_cache.py b/optimum/neuron/utils/hub_neuronx_cache.py index 0eb484aa8..22ccd64f7 100644 --- a/optimum/neuron/utils/hub_neuronx_cache.py +++ b/optimum/neuron/utils/hub_neuronx_cache.py @@ -17,12 +17,12 @@ import logging import os from contextlib import contextmanager -from dataclasses import dataclass from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Dict, Optional +from typing import Optional from huggingface_hub import HfApi, get_token +from transformers import AutoConfig, PretrainedConfig from ..version import __version__ from .import_utils import is_neuronx_available @@ -201,24 +201,47 @@ def _create_hub_compile_cache_proxy( return CompileCacheHfProxy(cache_repo_id, default_cache, endpoint=endpoint, token=token) -@dataclass -class CacheEntry: - key: str - metadata: Dict[str, Any] +class ModelCacheEntry: + """A class describing a model cache entry + + Args: + model_id (`str`): + The model id, used as a key for the cache entry. + config (`transformers.PretrainedConfig`): + The configuration of the model. 
+ + """ + + def __init__(self, model_id: str, config: PretrainedConfig): + self.model_id = model_id + # Remove keys set to default values + self.config = config.to_diff_dict() + excluded_keys = ["_name_or_path", "transformers_version"] + for key in excluded_keys: + self.config.pop(key, None) + + def to_json(self) -> str: + return json.dumps(self.config) + + @property + def hash(self): + hash_gen = hashlib.sha512() + hash_gen.update(self.to_json().encode("utf-8")) + return str(hash_gen.hexdigest())[:20] -REGISTRY_FOLDER = "0_REGISTRY" +REGISTRY_FOLDER = f"0_REGISTRY/{__version__}" @requires_torch_neuronx @contextmanager -def hub_neuronx_cache(entry: Optional[CacheEntry] = None): +def hub_neuronx_cache(entry: Optional[ModelCacheEntry] = None): """A context manager to activate the Hugging Face Hub proxy compiler cache. Args: - entry (`Optional[CacheEntry]`, defaults to `None`): - An optional dataclass containing metadata associated with the cache session. - Will create a dedicated entry in the cache registry. + entry (`Optional[ModelCacheEntry]`, defaults to `None`): + An optional dataclass containing metadata associated with the model corresponding + to the cache session. Will create a dedicated entry in the cache registry. """ def hf_create_compile_cache(cache_url): @@ -239,17 +262,14 @@ def hf_create_compile_cache(cache_url): else: # Create cache entry in local cache: it can be later synchronized with the hub cache registry_path = default_cache.get_cache_dir_with_cache_key(REGISTRY_FOLDER) - entry_path = f"{registry_path}/{entry.key}" - metadata_json = json.dumps(entry.metadata, indent=4) - hash_gen = hashlib.sha512() - hash_gen.update(metadata_json.encode("utf-8")) - metadata_key = str(hash_gen.hexdigest())[:20] - metadata_path = f"{entry_path}/{metadata_key}.json" - if not default_cache.exists(metadata_path): + model_type = entry.config["model_type"] + entry_path = f"{registry_path}/{model_type}/{entry.model_id}" + config_path = f"{entry_path}/{entry.hash}.json" + if not default_cache.exists(config_path): oldmask = os.umask(000) Path(entry_path).mkdir(parents=True, exist_ok=True) os.umask(oldmask) - default_cache.upload_string_to_file(metadata_path, metadata_json) + default_cache.upload_string_to_file(config_path, entry.to_json()) finally: patch_everywhere("create_compile_cache", create_compile_cache, "libneuronxla") @@ -267,8 +287,6 @@ def synchronize_hub_cache(cache_repo_id: Optional[str] = None): def get_hub_cached_entries(model_id: str, cache_repo_id: Optional[str] = None): - if os.path.isdir(model_id): - raise ValueError("Please pass a hub model_id in the form /.") if cache_repo_id is None: cache_repo_id = get_hub_cache() # Allocate a Hub API with refreshed information (required for tests altering the env) @@ -276,12 +294,20 @@ def get_hub_cached_entries(model_id: str, cache_repo_id: Optional[str] = None): token = get_token() api = HfApi(endpoint=endpoint, token=token) repo_files = api.list_repo_files(cache_repo_id) - registry_pattern = REGISTRY_FOLDER + "/" + model_id + # Get the config corresponding to the model + target_entry = ModelCacheEntry(model_id, (AutoConfig.from_pretrained(model_id))) + # Extract model type: it will be used as primary key for lookup + model_type = target_entry.config["model_type"] + registry_pattern = REGISTRY_FOLDER + "/" + model_type model_files = [path for path in repo_files if registry_pattern in path] model_entries = [] with TemporaryDirectory() as tmpdir: for model_path in model_files: local_path = api.hf_hub_download(cache_repo_id, model_path, 
local_dir=tmpdir)
            with open(local_path) as f:
-                model_entries.append(json.load(f))
+                entry_config = json.load(f)
+                # All config parameters except the neuron config must match
+                neuron_config = entry_config.pop("neuron")
+                if entry_config == target_entry.config:
+                    model_entries.append(neuron_config)
    return model_entries
diff --git a/tests/cache/test_neuronx_cache.py b/tests/cache/test_neuronx_cache.py
index 86dd29e75..0d3a5656a 100644
--- a/tests/cache/test_neuronx_cache.py
+++ b/tests/cache/test_neuronx_cache.py
@@ -85,7 +85,8 @@ def check_decoder_generation(model):


def get_local_cached_files(cache_path, extension="*"):
-    return glob.glob(f"{cache_path}/**/*/*.{extension}", recursive=True)
+    links = glob.glob(f"{cache_path}/**/*/*.{extension}", recursive=True)
+    return [link for link in links if os.path.isfile(link)]


def check_cache_entry(model, cache_path):
diff --git a/text-generation-inference/README.md b/text-generation-inference/README.md
index 455ab879b..704fad430 100644
--- a/text-generation-inference/README.md
+++ b/text-generation-inference/README.md
@@ -39,29 +39,32 @@ docker run ghcr.io/huggingface/neuronx-tgi:latest
```

-Note that we export a shared volume mounted as `/data` in the container: this is where the hub model will be cached to
-speed up further instantiations of the service.
-
-Note also that all neuron devices have to be explicitly made visible to the container.
-
-For instance, if your instance has 12 neuron devices, the launch command becomes:
+If your instance has 12 neuron devices, the launch command becomes:

```
docker run -p 8080:80 \
@@ -78,10 +81,63 @@ docker run -p 8080:80 \
       --device=/dev/neuron9 \
       --device=/dev/neuron10 \
       --device=/dev/neuron11 \
-       ...
+       -e HF_TOKEN=${HF_TOKEN} \
+       ghcr.io/huggingface/neuronx-tgi:latest \
+
+```
+
+
+### Using a neuron model from the 🤗 [HuggingFace Hub](https://huggingface.co/aws-neuron) (recommended)
+
+There are plenty of already exported neuron models on the hub, under the [aws-neuron](https://huggingface.co/aws-neuron) organization.
+
+The snippet below shows how you can deploy a service from a hub neuron model:
+
+```
+docker run -p 8080:80 \
+       -v $(pwd)/data:/data \
+       --device=/dev/neuron0 \
+       -e HF_TOKEN=${HF_TOKEN} \
+       ghcr.io/huggingface/neuronx-tgi:latest \
+       --model-id aws-neuron/Llama-2-7b-hf-neuron-budget \
+       --max-concurrent-requests 1 \
+       --max-input-length 1024 \
+       --max-total-tokens 2048 \
+       --max-batch-prefill-tokens 1024 \
+       --max-batch-total-tokens 2048
+```
+
+### Using a standard model from the 🤗 [HuggingFace Hub](https://huggingface.co)
+
+We maintain a Neuron Model Cache of the most popular architectures and deployment parameters under [aws-neuron/optimum-neuron-cache](https://huggingface.co/aws-neuron/optimum-neuron-cache).
+
+If you just want to try the service quickly using a model that has not been exported yet, it is still
+possible to export it dynamically, provided that:
+- you specify the export parameters when launching the service (or use the default parameters),
+- the model configuration is cached.
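+
+You can check in advance whether a given model configuration is already cached by using the `optimum-cli neuron cache lookup`
+command documented in the cache system guide. The command below is only an illustration, and its output depends on the
+current contents of the cache:
+
+```shell
+optimum-cli neuron cache lookup meta-llama/Llama-2-7b-chat-hf
+```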
+ +The snippet below shows how you can deploy a service from a hub standard model: + +``` +docker run -p 8080:80 \ + -v $(pwd)/data:/data \ + --device=/dev/neuron0 \ + -e HF_TOKEN=${HF_TOKEN} \ + -e HF_BATCH_SIZE=1 \ + -e HF_SEQUENCE_LENGTH=1024 \ + -e HF_AUTO_CAST_TYPE="fp16" \ + -e HF_NUM_CORES=2 \ + ghcr.io/huggingface/neuronx-tgi:latest \ + --model-id aws-neuron/Llama-2-7b-hf-neuron-budget \ + --max-concurrent-requests 1 \ + --max-input-length 512 \ + --max-total-tokens 1024 \ + --max-batch-prefill-tokens 512 \ + --max-batch-total-tokens 1024 ``` -### From a local path +### Using a model exported to a local path Alternatively, you can first [export the model to neuron format](https://huggingface.co/docs/optimum-neuron/main/en/guides/models#configuring-the-export-of-a-generative-model) locally, and deploy the service inside the shared volume: diff --git a/text-generation-inference/server/text_generation_server/cli.py b/text-generation-inference/server/text_generation_server/cli.py index a2b22c4ae..fbcbc3d4e 100644 --- a/text-generation-inference/server/text_generation_server/cli.py +++ b/text-generation-inference/server/text_generation_server/cli.py @@ -56,9 +56,11 @@ def serve( logger.warning("'trust_remote_code' argument is not supported and will be ignored.") # Import here after the logger is added to log potential import exceptions + from .model import fetch_model from .server import serve - serve(model_id, revision, uds_path) + model_path = fetch_model(model_id, revision) + serve(model_path, uds_path) @app.command() diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py index 8a61f20ad..0d717b1ae 100644 --- a/text-generation-inference/server/text_generation_server/generator.py +++ b/text-generation-inference/server/text_generation_server/generator.py @@ -1,5 +1,6 @@ import copy import logging +import time from abc import ABC from enum import Enum from typing import List, Optional, Tuple @@ -12,7 +13,6 @@ from optimum.neuron import NeuronModelForCausalLM from optimum.neuron.generation import TokenSelector -from .model import fetch_model from .pb.generate_pb2 import ( Batch, CachedBatch, @@ -504,22 +504,21 @@ def _clear(self, request_ids: List): @classmethod def from_pretrained( cls, - model_id: str, - revision: Optional[str], + model_path: str, ): """Instantiate a NeuronGenerator. Args: - model_id (`str`): - The *model_id* of a model on the HuggingFace hub or the path to a local model. - In either case, the hub or local path must also contain a Tokenizer. - revision (`str`): - The revision of the model on the HuggingFace hub. + model_path (`str`): + The path to a local neuron model. This path must also contain a Tokenizer. Returns: A NeuronGenerator. 
""" - model_path = fetch_model(model_id, revision) - model = NeuronModelForCausalLM.from_pretrained(model_path, revision=revision) - tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision) + logger.info("Loading model on neuron devices (this can take a few minutes).") + start = time.time() + model = NeuronModelForCausalLM.from_pretrained(model_path) + end = time.time() + logger.info(f"Model successfully loaded in {end - start:.2f} s.") + tokenizer = AutoTokenizer.from_pretrained(model_path) return cls(model, tokenizer) diff --git a/text-generation-inference/server/text_generation_server/model.py b/text-generation-inference/server/text_generation_server/model.py index 16db5230d..c759fa1b4 100644 --- a/text-generation-inference/server/text_generation_server/model.py +++ b/text-generation-inference/server/text_generation_server/model.py @@ -1,9 +1,53 @@ import os +import time from typing import Optional from huggingface_hub import snapshot_download +from huggingface_hub.constants import HF_HUB_CACHE from loguru import logger -from transformers import AutoConfig +from transformers import AutoConfig, AutoTokenizer + +from optimum.neuron import NeuronModelForCausalLM +from optimum.neuron.utils import ModelCacheEntry, get_hub_cached_entries + + +def get_export_kwargs_from_env(): + batch_size = os.environ.get("HF_BATCH_SIZE", None) + if batch_size is not None: + batch_size = int(batch_size) + sequence_length = os.environ.get("HF_SEQUENCE_LENGTH", None) + if sequence_length is not None: + sequence_length = int(sequence_length) + num_cores = os.environ.get("HF_NUM_CORES", None) + if num_cores is not None: + num_cores = int(num_cores) + auto_cast_type = os.environ.get("HF_AUTO_CAST_TYPE", None) + return { + "task": "text-generation", + "batch_size": batch_size, + "sequence_length": sequence_length, + "num_cores": num_cores, + "auto_cast_type": auto_cast_type, + } + + +def is_cached(model_id, neuron_config): + # Look for cached entries for the specified model + in_cache = False + entries = get_hub_cached_entries(model_id) + # Look for compatible entries + for entry in entries: + compatible = True + for key, value in neuron_config.items(): + # Only weights can be different + if key in ["checkpoint_id", "checkpoint_revision"]: + continue + if entry[key] != value: + compatible = False + if compatible: + in_cache = True + break + return in_cache def fetch_model( @@ -21,17 +65,51 @@ def fetch_model( Returns: Local folder path (string) of the model. """ + if not os.path.isdir("/sys/class/neuron_device/"): + raise SystemError("No neuron cores detected on the host.") if os.path.isdir(model_id): if revision is not None: logger.warning("Revision {} ignored for local model at {}".format(revision, model_id)) - model_path = model_id - else: - # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model) - # Note that the model may already be present in the cache. - logger.info("Fetching revision {} for {}".format(revision, model_id)) - model_path = snapshot_download(model_id, revision=revision) - config = AutoConfig.from_pretrained(model_path) + return model_id + # Download the model from the Hub (HUGGING_FACE_HUB_TOKEN must be set for a private or gated model) + # Note that the model may already be present in the cache. + config = AutoConfig.from_pretrained(model_id, revision=revision) neuron_config = getattr(config, "neuron", None) - if neuron_config is None: - raise ValueError("The target model is not a Neuron model. 
Please export it to neuron first.") - return model_path + if neuron_config is not None: + logger.info("Fetching revision {} for neuron model {}".format(revision, model_id)) + return snapshot_download(model_id, revision=revision) + # Not a neuron model: evaluate the export config and check if it has been exported locally + export_kwargs = get_export_kwargs_from_env() + export_config = NeuronModelForCausalLM.get_export_config(model_id, config, revision=revision, **export_kwargs) + entry = ModelCacheEntry(model_id, export_config) + export_path = f"{HF_HUB_CACHE}/{entry.hash}" + if os.path.exists(export_path): + # The model has already been exported for that configuration + logger.info(f"Neuron model for {model_id} with {export_config.neuron} found under {export_path}.") + return export_path + # Look for compatible cached entries on the hub + neuron_config = export_config.neuron + if not is_cached(model_id, neuron_config): + error_msg = ( + f"No cached version found for {model_id} with {neuron_config}." + "You can start a discussion to request it on https://huggingface.co/aws-neuron/optimum-neuron-cache." + ) + raise ValueError(error_msg) + # Export the model + logger.warning(f"{model_id} is not a neuron model: it will be exported using cached artifacts.") + start = time.time() + logger.info(f"Fetching revision {revision} of model {model_id}.") + model_path = snapshot_download(model_id, revision=revision) + end = time.time() + logger.info(f"Model successfully fetched in {end - start:.2f} s.") + logger.info(f"Exporting model to neuron with config {neuron_config}.") + start = time.time() + model = NeuronModelForCausalLM.from_pretrained(model_path, export=True, **export_kwargs) + # Save for later retrieval + model.save_pretrained(export_path) + end = time.time() + # We also need to fetch and save the tokenizer + tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision) + tokenizer.save_pretrained(export_path) + logger.info(f"Model successfully exported in {end - start:.2f} s under {export_path}.") + return export_path diff --git a/text-generation-inference/server/text_generation_server/server.py b/text-generation-inference/server/text_generation_server/server.py index 05e4e3256..87c2b60e3 100644 --- a/text-generation-inference/server/text_generation_server/server.py +++ b/text-generation-inference/server/text_generation_server/server.py @@ -1,6 +1,6 @@ import asyncio from pathlib import Path -from typing import List, Optional +from typing import List from grpc import aio from grpc_reflection.v1alpha import reflection @@ -49,17 +49,16 @@ async def Decode(self, request, context): def serve( - model_id: str, - revision: Optional[str], + model_path: str, uds_path: Path, ): - async def serve_inner(model_id: str, revision: Optional[str]): + async def serve_inner(model_path: str): unix_socket_template = "unix://{}-{}" local_url = unix_socket_template.format(uds_path, 0) server_urls = [local_url] try: - generator = NeuronGenerator.from_pretrained(model_id, revision) + generator = NeuronGenerator.from_pretrained(model_path) except Exception: logger.exception("Error when initializing model") raise @@ -85,4 +84,4 @@ async def serve_inner(model_id: str, revision: Optional[str]): logger.info("Signal received. Shutting down") await server.stop(0) - asyncio.run(serve_inner(model_id, revision)) + asyncio.run(serve_inner(model_path))
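
Once the service has been launched following the README instructions above, it can be queried through the standard
text-generation-inference REST API. The request below is only an illustration and assumes the `-p 8080:80` port mapping
used in the README examples:

```shell
curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json'
```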