From 50e457c7d4b2b9ec7704349029af73060f5d09d0 Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Mon, 17 Jul 2023 14:32:28 +0200
Subject: [PATCH] Add support for pushing and fetching Neuron compiled text-generation models from the hub (#133)

* feat(generate): override OptimizedModel.push_to_hub

  The optimum implementation flattens the file hierarchy, but Neuron decoder
  model checkpoints must be stored in a subdirectory as they have their own
  config.json.

* feat(generate): add support for downloading Neuron models from the Hub

* review: simplify push_to_hub
---
 examples/text-generation/run_generation.py |  6 +--
 optimum/neuron/modeling_decoder.py         | 58 ++++++++++++++++++++--
 tests/test_modeling_decoder.py             | 52 ++++++++++++++++---
 3 files changed, 100 insertions(+), 16 deletions(-)

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index d5c02ca54..efdc2fa75 100644
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -1,15 +1,15 @@
 import argparse
-import os
 import time
 
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 from optimum.neuron import NeuronModelForCausalLM
 
 
 def load_llm_optimum(model_id_or_path, batch_size, seq_length, num_cores, auto_cast_type):
-    export = not os.path.isdir(model_id_or_path)
+    config = AutoConfig.from_pretrained(model_id_or_path)
+    export = getattr(config, "neuron", None) is None
 
     # Load and convert the Hub model to Neuron format
     return NeuronModelForCausalLM.from_pretrained(
diff --git a/optimum/neuron/modeling_decoder.py b/optimum/neuron/modeling_decoder.py
index d495854c1..800570c1e 100644
--- a/optimum/neuron/modeling_decoder.py
+++ b/optimum/neuron/modeling_decoder.py
@@ -22,6 +22,7 @@
 from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import torch
+from huggingface_hub import HfApi, HfFolder, snapshot_download
 from transformers import GenerationConfig
 
 from ..exporters.neuron.model_configs import *  # noqa: F403
@@ -143,7 +144,9 @@ def _from_transformers(
         return cls._from_pretrained(checkpoint_dir, config)
 
     @classmethod
-    def _get_neuron_paths(cls, model_dir: Union[str, Path, TemporaryDirectory]) -> Tuple[str, str, str]:
+    def _get_neuron_paths(
+        cls, model_dir: Union[str, Path, TemporaryDirectory], token: Optional[str] = None
+    ) -> Tuple[str, str, str]:
         if isinstance(model_dir, TemporaryDirectory):
             model_path = model_dir.name
             # We are in the middle of an export: the checkpoint is in the temporary model directory
@@ -151,8 +154,13 @@ def _get_neuron_paths(cls, model_dir: Union[str, Path, TemporaryDirectory]) -> T
             # There are no compiled artifacts yet
             compiled_path = None
         else:
-            model_path = model_dir
-            # The model has already been exported, the checkpoint is in a subdirectory
+            # The model has already been exported
+            if os.path.isdir(model_dir):
+                model_path = model_dir
+            else:
+                # Download the neuron model from the Hub
+                model_path = snapshot_download(model_dir, token=token)
+            # The checkpoint is in a subdirectory
             checkpoint_path = os.path.join(model_path, cls.CHECKPOINT_DIR)
             # So are the compiled artifacts
             compiled_path = os.path.join(model_path, cls.COMPILED_DIR)
@@ -160,7 +168,11 @@ def _get_neuron_paths(cls, model_dir: Union[str, Path, TemporaryDirectory]) -> T
 
     @classmethod
     def _from_pretrained(
-        cls, model_id: Union[str, Path, TemporaryDirectory], config: "PretrainedConfig", **kwargs
+        cls,
+        model_id: Union[str, Path, TemporaryDirectory],
+        config: "PretrainedConfig",
+        use_auth_token: Optional[str] = None,
+        **kwargs,
     ) -> "NeuronDecoderModel":
         # Verify we are actually trying to load a neuron model
         neuron_config = getattr(config, "neuron", None)
@@ -181,7 +193,7 @@ def _from_pretrained(
 
         exporter = get_exporter(config, task)
 
-        model_path, checkpoint_path, compiled_path = cls._get_neuron_paths(model_id)
+        model_path, checkpoint_path, compiled_path = cls._get_neuron_paths(model_id, use_auth_token)
 
         neuronx_model = exporter.neuronx_class.from_pretrained(
             checkpoint_path, batch_size=batch_size, tp_degree=num_cores, amp=auto_cast_type, **neuron_kwargs
@@ -224,3 +236,39 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
         self.model_path = save_directory
 
         self.generation_config.save_pretrained(save_directory)
+
+    def push_to_hub(
+        self,
+        save_directory: str,
+        repository_id: str,
+        private: Optional[bool] = None,
+        use_auth_token: Union[bool, str] = True,
+        endpoint: Optional[str] = None,
+    ) -> str:
+        if isinstance(use_auth_token, str):
+            huggingface_token = use_auth_token
+        elif use_auth_token:
+            huggingface_token = HfFolder.get_token()
+        else:
+            raise ValueError("You need to provide `use_auth_token` to be able to push to the hub")
+        api = HfApi(endpoint=endpoint)
+
+        user = api.whoami(huggingface_token)
+        self.git_config_username_and_email(git_email=user["email"], git_user=user["fullname"])
+
+        api.create_repo(
+            token=huggingface_token,
+            repo_id=repository_id,
+            exist_ok=True,
+            private=private,
+        )
+        for path, subdirs, files in os.walk(save_directory):
+            for name in files:
+                local_file_path = os.path.join(path, name)
+                hub_file_path = os.path.relpath(local_file_path, save_directory)
+                api.upload_file(
+                    token=huggingface_token,
+                    repo_id=repository_id,
+                    path_or_fileobj=os.path.join(os.getcwd(), local_file_path),
+                    path_in_repo=hub_file_path,
+                )
diff --git a/tests/test_modeling_decoder.py b/tests/test_modeling_decoder.py
index 1d8db825a..8d65a776a 100644
--- a/tests/test_modeling_decoder.py
+++ b/tests/test_modeling_decoder.py
@@ -12,17 +12,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from tempfile import TemporaryDirectory
 
 import pytest
 import torch
+from huggingface_hub import HfApi
 from transformers import AutoTokenizer
+from transformers.testing_utils import ENDPOINT_STAGING
 
 from optimum.neuron import NeuronModelForCausalLM
 from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx
 from optimum.utils import logging
+from optimum.utils.testing_utils import TOKEN, USER
 
-from .exporters.exporters_utils import EXPORT_MODELS_TINY as MODEL_NAMES
+from .exporters.exporters_utils import EXPORT_MODELS_TINY
 
 
 logger = logging.get_logger()
@@ -31,20 +35,20 @@
 DECODER_MODEL_ARCHITECTURES = ["gpt2"]
 
 
-@pytest.fixture(scope="module", params=[MODEL_NAMES[model_arch] for model_arch in DECODER_MODEL_ARCHITECTURES])
-def model_id(request):
+@pytest.fixture(scope="module", params=[EXPORT_MODELS_TINY[model_arch] for model_arch in DECODER_MODEL_ARCHITECTURES])
+def export_model_id(request):
     return request.param
 
 
 @pytest.fixture(scope="module")
-def neuron_model_path(model_id):
+def neuron_model_path(export_model_id):
     # For now we need to use a batch_size of 2 because it fails with batch_size == 1
-    model = NeuronModelForCausalLM.from_pretrained(model_id, export=True, batch_size=2)
+    model = NeuronModelForCausalLM.from_pretrained(export_model_id, export=True, batch_size=2)
     model_dir = TemporaryDirectory()
     model_path = model_dir.name
     model.save_pretrained(model_path)
     del model
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer = AutoTokenizer.from_pretrained(export_model_id)
     tokenizer.save_pretrained(model_path)
     del tokenizer
     # Yield instead of returning to keep a reference to the temporary directory.
@@ -53,6 +57,13 @@ def neuron_model_path(export_model_id):
     yield model_path
 
 
+@pytest.fixture(scope="module")
+def neuron_push_id(export_model_id):
+    model_name = export_model_id.split("/")[-1]
+    repo_id = f"{USER}/{model_name}-neuronx"
+    return repo_id
+
+
 def _check_neuron_model(neuron_model, batch_size=None, num_cores=None, auto_cast_type=None):
     neuron_config = getattr(neuron_model.config, "neuron", None)
     assert neuron_config
@@ -74,9 +85,9 @@
         [2, 2, "bf16"],
     ],
 )
-def test_model_from_hub(model_id, batch_size, num_cores, auto_cast_type):
+def test_model_export(export_model_id, batch_size, num_cores, auto_cast_type):
     model = NeuronModelForCausalLM.from_pretrained(
-        model_id, export=True, batch_size=batch_size, num_cores=num_cores, auto_cast_type=auto_cast_type
+        export_model_id, export=True, batch_size=batch_size, num_cores=num_cores, auto_cast_type=auto_cast_type
     )
     _check_neuron_model(model, batch_size, num_cores, auto_cast_type)
 
@@ -88,6 +99,13 @@ def test_model_from_path(neuron_model_path):
     _check_neuron_model(model)
 
 
+@is_inferentia_test
+@requires_neuronx
+def test_model_from_hub():
+    model = NeuronModelForCausalLM.from_pretrained("dacorvo/tiny-random-gpt2-neuronx")
+    _check_neuron_model(model)
+
+
 def _test_model_generation(model, tokenizer, batch_size, max_length, **gen_kwargs):
     prompt_text = "Hello, I'm a language model,"
     prompts = [prompt_text for _ in range(batch_size)]
@@ -117,3 +135,21 @@ def test_model_generation(neuron_model_path, gen_kwargs):
     # Using an incompatible generation length
     with pytest.raises(ValueError, match="The current sequence length"):
         _test_model_generation(model, tokenizer, model.batch_size, model.max_length * 2, **gen_kwargs)
+
+
+@is_inferentia_test
+@requires_neuronx
+def test_push_to_hub(neuron_model_path, neuron_push_id):
+    model = NeuronModelForCausalLM.from_pretrained(neuron_model_path)
+    model.push_to_hub(neuron_model_path, neuron_push_id, use_auth_token=TOKEN, endpoint=ENDPOINT_STAGING)
+    api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)
+    try:
+        hub_files_info = api.list_files_info(neuron_push_id)
+        hub_files_path = [info.rfilename for info in hub_files_info]
+        for path, _, files in os.walk(neuron_model_path):
+            for name in files:
+                local_file_path = os.path.join(path, name)
+                hub_file_path = os.path.relpath(local_file_path, neuron_model_path)
+                assert hub_file_path in hub_files_path
+    finally:
+        api.delete_repo(neuron_push_id)
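
For reference, below is a minimal usage sketch of the round trip this patch enables (export, push, then fetch the pre-compiled model). It is not part of the patch itself: the model id "gpt2", the output directory "gpt2-neuronx" and the repository id "my-user/gpt2-neuronx" are placeholders, and the keyword arguments mirror the values used in the tests above.

    from transformers import AutoTokenizer

    from optimum.neuron import NeuronModelForCausalLM

    # Export a vanilla Hub checkpoint to Neuron format (compilation happens here).
    # batch_size=2 mirrors the test fixture, which notes that batch_size=1 currently fails.
    model = NeuronModelForCausalLM.from_pretrained(
        "gpt2", export=True, batch_size=2, num_cores=2, auto_cast_type="bf16"
    )
    model.save_pretrained("gpt2-neuronx")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.save_pretrained("gpt2-neuronx")

    # Push the whole directory (checkpoint and compiled-artifact subdirectories included)
    # to the Hub using the overridden push_to_hub. The repository id is a placeholder.
    model.push_to_hub("gpt2-neuronx", "my-user/gpt2-neuronx", use_auth_token=True)

    # Fetch the pre-compiled model later: its config.json carries a "neuron" section,
    # so from_pretrained downloads a snapshot instead of re-exporting.
    model = NeuronModelForCausalLM.from_pretrained("my-user/gpt2-neuronx")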