feat(generate): add support for Neuron models download from the Hub
dacorvo authored and michaelbenayoun committed Jul 13, 2023
1 parent fde1a26 commit ed8a0be
Showing 3 changed files with 41 additions and 23 deletions.
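
In short, NeuronModelForCausalLM.from_pretrained can now be pointed at a Hub repository that already contains Neuron artifacts, not only at a local directory. A minimal usage sketch of both paths (the first repo id is the tiny pre-exported test model used in the new test; any pre-exported Neuron repo would do):

    from optimum.neuron import NeuronModelForCausalLM

    # Pre-exported Neuron repo: the artifacts are downloaded from the Hub,
    # so no export/compilation step is needed.
    model = NeuronModelForCausalLM.from_pretrained("dacorvo/tiny-random-gpt2-neuronx")

    # Plain transformers checkpoint: still requires an explicit export.
    model = NeuronModelForCausalLM.from_pretrained("gpt2", export=True, batch_size=2)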
6 changes: 3 additions & 3 deletions examples/text-generation/run_generation.py
@@ -1,15 +1,15 @@
 import argparse
-import os
 import time
 
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 from optimum.neuron import NeuronModelForCausalLM
 
 
 def load_llm_optimum(model_id_or_path, batch_size, seq_length, num_cores, auto_cast_type):
-    export = not os.path.isdir(model_id_or_path)
+    config = AutoConfig.from_pretrained(model_id_or_path)
+    export = getattr(config, "neuron", None) is None
 
     # Load and convert the Hub model to Neuron format
     return NeuronModelForCausalLM.from_pretrained(
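
The new export check relies on the fact that optimum-neuron exports store their compilation parameters under a "neuron" entry of the model config (the same entry that _from_pretrained verifies below). A small sketch of the detection, using the pre-exported test repo:

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("dacorvo/tiny-random-gpt2-neuronx")
    # False for a Neuron export, True for a plain checkpoint that still needs exporting.
    needs_export = getattr(config, "neuron", None) is None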
23 changes: 17 additions & 6 deletions optimum/neuron/modeling_decoder.py
@@ -22,7 +22,7 @@
 from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import torch
-from huggingface_hub import HfApi, HfFolder
+from huggingface_hub import HfApi, HfFolder, snapshot_download
 from transformers import GenerationConfig
 
 from ..exporters.neuron.model_configs import *  # noqa: F403
@@ -144,24 +144,35 @@ def _from_transformers(
         return cls._from_pretrained(checkpoint_dir, config)
 
     @classmethod
-    def _get_neuron_paths(cls, model_dir: Union[str, Path, TemporaryDirectory]) -> Tuple[str, str, str]:
+    def _get_neuron_paths(
+        cls, model_dir: Union[str, Path, TemporaryDirectory], token: Optional[str] = None
+    ) -> Tuple[str, str, str]:
         if isinstance(model_dir, TemporaryDirectory):
             model_path = model_dir.name
             # We are in the middle of an export: the checkpoint is in the temporary model directory
             checkpoint_path = model_path
             # There are no compiled artifacts yet
             compiled_path = None
         else:
-            model_path = model_dir
-            # The model has already been exported, the checkpoint is in a subdirectory
+            # The model has already been exported
+            if os.path.isdir(model_dir):
+                model_path = model_dir
+            else:
+                # Download the neuron model from the Hub
+                model_path = snapshot_download(model_dir, token=token)
+            # The checkpoint is in a subdirectory
             checkpoint_path = os.path.join(model_path, cls.CHECKPOINT_DIR)
             # So are the compiled artifacts
             compiled_path = os.path.join(model_path, cls.COMPILED_DIR)
         return model_path, checkpoint_path, compiled_path
 
     @classmethod
     def _from_pretrained(
-        cls, model_id: Union[str, Path, TemporaryDirectory], config: "PretrainedConfig", **kwargs
+        cls,
+        model_id: Union[str, Path, TemporaryDirectory],
+        config: "PretrainedConfig",
+        use_auth_token: Optional[str] = None,
+        **kwargs,
     ) -> "NeuronDecoderModel":
         # Verify we are actually trying to load a neuron model
         neuron_config = getattr(config, "neuron", None)
@@ -182,7 +193,7 @@ def _from_pretrained(
 
         exporter = get_exporter(config, task)
 
-        model_path, checkpoint_path, compiled_path = cls._get_neuron_paths(model_id)
+        model_path, checkpoint_path, compiled_path = cls._get_neuron_paths(model_id, use_auth_token)
 
         neuronx_model = exporter.neuronx_class.from_pretrained(
             checkpoint_path, batch_size=batch_size, tp_degree=num_cores, amp=auto_cast_type, **neuron_kwargs
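
snapshot_download is what makes the Hub case transparent: it downloads (or reuses from the local cache) every file of the repository and returns the corresponding local directory, so the CHECKPOINT_DIR and COMPILED_DIR subdirectories can be joined onto model_path exactly as in the local case. A sketch, assuming the public test repo (the cache location shown is only typical):

    from huggingface_hub import snapshot_download

    # Returns a local cache path, e.g.
    # ~/.cache/huggingface/hub/models--dacorvo--tiny-random-gpt2-neuronx/snapshots/<revision>
    model_path = snapshot_download("dacorvo/tiny-random-gpt2-neuronx")
    # A private repo would additionally need token=<your Hub token>, which is
    # what the new use_auth_token argument forwards.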
35 changes: 21 additions & 14 deletions tests/test_modeling_decoder.py
@@ -26,7 +26,7 @@
 from optimum.utils import logging
 from optimum.utils.testing_utils import TOKEN, USER
 
-from .exporters.exporters_utils import EXPORT_MODELS_TINY as MODEL_NAMES
+from .exporters.exporters_utils import EXPORT_MODELS_TINY
 
 
 logger = logging.get_logger()
@@ -35,20 +35,20 @@
 DECODER_MODEL_ARCHITECTURES = ["gpt2"]
 
 
-@pytest.fixture(scope="module", params=[MODEL_NAMES[model_arch] for model_arch in DECODER_MODEL_ARCHITECTURES])
-def model_id(request):
+@pytest.fixture(scope="module", params=[EXPORT_MODELS_TINY[model_arch] for model_arch in DECODER_MODEL_ARCHITECTURES])
+def export_model_id(request):
     return request.param
 
 
 @pytest.fixture(scope="module")
-def neuron_model_path(model_id):
+def neuron_model_path(export_model_id):
     # For now we need to use a batch_size of 2 because it fails with batch_size == 1
-    model = NeuronModelForCausalLM.from_pretrained(model_id, export=True, batch_size=2)
+    model = NeuronModelForCausalLM.from_pretrained(export_model_id, export=True, batch_size=2)
     model_dir = TemporaryDirectory()
     model_path = model_dir.name
     model.save_pretrained(model_path)
     del model
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer = AutoTokenizer.from_pretrained(export_model_id)
     tokenizer.save_pretrained(model_path)
     del tokenizer
     # Yield instead of returning to keep a reference to the temporary directory.
@@ -58,8 +58,8 @@ def neuron_model_path(model_id):
 
 @pytest.fixture(scope="module")
-def neuron_repo_id(model_id):
-    model_name = model_id.split("/")[-1]
+def neuron_push_id(export_model_id):
+    model_name = export_model_id.split("/")[-1]
     repo_id = f"{USER}/{model_name}-neuronx"
     return repo_id
@@ -85,9 +85,9 @@ def _check_neuron_model(neuron_model, batch_size=None, num_cores=None, auto_cast
         [2, 2, "bf16"],
     ],
 )
-def test_model_from_hub(model_id, batch_size, num_cores, auto_cast_type):
+def test_model_export(export_model_id, batch_size, num_cores, auto_cast_type):
     model = NeuronModelForCausalLM.from_pretrained(
-        model_id, export=True, batch_size=batch_size, num_cores=num_cores, auto_cast_type=auto_cast_type
+        export_model_id, export=True, batch_size=batch_size, num_cores=num_cores, auto_cast_type=auto_cast_type
     )
     _check_neuron_model(model, batch_size, num_cores, auto_cast_type)
@@ -99,6 +99,13 @@ def test_model_from_path(neuron_model_path):
     _check_neuron_model(model)
 
 
+@is_inferentia_test
+@requires_neuronx
+def test_model_from_hub():
+    model = NeuronModelForCausalLM.from_pretrained("dacorvo/tiny-random-gpt2-neuronx")
+    _check_neuron_model(model)
+
+
 def _test_model_generation(model, tokenizer, batch_size, max_length, **gen_kwargs):
     prompt_text = "Hello, I'm a language model,"
     prompts = [prompt_text for _ in range(batch_size)]
@@ -132,17 +139,17 @@ def test_model_generation(neuron_model_path, gen_kwargs):
 
 @is_inferentia_test
 @requires_neuronx
-def test_push_to_hub(neuron_model_path, neuron_repo_id):
+def test_push_to_hub(neuron_model_path, neuron_push_id):
     model = NeuronModelForCausalLM.from_pretrained(neuron_model_path)
-    model.push_to_hub(neuron_model_path, neuron_repo_id, use_auth_token=TOKEN, endpoint=ENDPOINT_STAGING)
+    model.push_to_hub(neuron_model_path, neuron_push_id, use_auth_token=TOKEN, endpoint=ENDPOINT_STAGING)
     api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)
     try:
-        hub_files_info = api.list_files_info(neuron_repo_id)
+        hub_files_info = api.list_files_info(neuron_push_id)
         hub_files_path = [info.rfilename for info in hub_files_info]
         for path, _, files in os.walk(neuron_model_path):
             for name in files:
                 local_file_path = os.path.join(path, name)
                 hub_file_path = os.path.relpath(local_file_path, neuron_model_path)
                 assert hub_file_path in hub_files_path
     finally:
-        api.delete_repo(neuron_repo_id)
+        api.delete_repo(neuron_push_id)
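
Taken together, the tests now cover the full round trip: export a checkpoint, save it, push it to the Hub, and reload it from there. A condensed sketch of that flow (the repo id and token below are hypothetical; the test targets the staging endpoint, which normal use can omit):

    from optimum.neuron import NeuronModelForCausalLM

    model = NeuronModelForCausalLM.from_pretrained("gpt2", export=True, batch_size=2)
    model.save_pretrained("gpt2-neuronx")
    # Hypothetical repo id and token.
    model.push_to_hub("gpt2-neuronx", "my-user/gpt2-neuronx", use_auth_token="hf_xxx")
    reloaded = NeuronModelForCausalLM.from_pretrained("my-user/gpt2-neuronx")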
