add tests
JingyaHuang committed Mar 22, 2024
1 parent f35fa16 commit 0fefd97
Showing 3 changed files with 42 additions and 15 deletions.
optimum/neuron/utils/cache_utils.py (8 changes: 6 additions & 2 deletions)
@@ -98,11 +98,15 @@ def load_custom_cache_repo_name_from_hf_home(
     return None


-def set_custom_cache_repo_name_in_hf_home(repo_id: str, hf_home: str = HF_HOME, check_repo: bool = True):
+def set_custom_cache_repo_name_in_hf_home(
+    repo_id: str, hf_home: str = HF_HOME, check_repo: bool = True, api: Optional[HfApi] = None
+):
     hf_home_cache_repo_file = f"{hf_home}/{CACHE_REPO_FILENAME}"
+    if api is None:
+        api = HfApi()
     if check_repo:
         try:
-            HfApi().repo_info(repo_id, repo_type="model")
+            api.repo_info(repo_id, repo_type="model")
         except Exception as e:
             raise ValueError(
                 f"Could not save the custom Neuron cache repo to be {repo_id} because it does not exist or is "
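Note: the new `api` parameter makes the Hub client injectable, which the updated `cache_repos` fixture below relies on to target the staging endpoint. A minimal sketch of the pattern (repo id and token are placeholders, not from this commit):

from huggingface_hub import HfApi
from transformers.testing_utils import ENDPOINT_STAGING

from optimum.neuron.utils.cache_utils import set_custom_cache_repo_name_in_hf_home

# Build a client bound to the staging Hub so the repo existence check runs
# against staging instead of the production endpoint.
staging_api = HfApi(endpoint=ENDPOINT_STAGING, token="<test-token>")
set_custom_cache_repo_name_in_hf_home("my-org/optimum-neuron-cache", api=staging_api)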
optimum/neuron/utils/hub_neuronx_cache.py (5 changes: 3 additions & 2 deletions)
@@ -424,7 +424,7 @@ def get_hub_cached_entries(
     try:
         config = AutoConfig.from_pretrained(model_id)
     except Exception:
-        config = get_multimodels_configs(api, model_id)  # Applied on SD, encoder-decoder models
+        config = get_multimodels_configs_from_hub(model_id)  # Applied on SD, encoder-decoder models
     target_entry = ModelCacheEntry(model_id, config)
     # Extract model type: it will be used as primary key for lookup
     model_type = target_entry.config["model_type"]
@@ -489,7 +489,8 @@ def _prepare_config_for_matching(entry_config, target_entry, model_type):
     return entry_config, target_entry_config, neuron_config


-def get_multimodels_configs(api, model_id):
+def get_multimodels_configs_from_hub(model_id):
+    api = HfApi()
     repo_files = api.list_repo_files(model_id)
     config_pattern = "/config.json"
     config_files = [path for path in repo_files if config_pattern in path]
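Note: `get_multimodels_configs_from_hub` now builds its own `HfApi` client instead of receiving one. A minimal sketch of the lookup it performs for repos that ship one config per component, such as Stable Diffusion checkpoints (repo id illustrative, output abbreviated):

from huggingface_hub import HfApi

# List every file in the repo and keep the per-component config files,
# mirroring the first steps of get_multimodels_configs_from_hub.
api = HfApi()
repo_files = api.list_repo_files("hf-internal-testing/tiny-stable-diffusion-torch")
config_files = [path for path in repo_files if "/config.json" in path]
# e.g. ["text_encoder/config.json", "unet/config.json", "vae/config.json"]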
tests/cache/test_neuronx_cache.py (44 changes: 33 additions & 11 deletions)
@@ -19,13 +19,19 @@
 import subprocess
 from tempfile import TemporaryDirectory

+import PIL
 import pytest
 import torch
 from huggingface_hub import HfApi
+from transformers import AutoTokenizer
 from transformers.testing_utils import ENDPOINT_STAGING

 from optimum.neuron import NeuronModelForCausalLM, NeuronModelForSequenceClassification, NeuronStableDiffusionPipeline
 from optimum.neuron.utils import get_hub_cached_entries, synchronize_hub_cache
+from optimum.neuron.utils.cache_utils import (
+    load_custom_cache_repo_name_from_hf_home,
+    set_custom_cache_repo_name_in_hf_home,
+)
 from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx
 from optimum.utils.testing_utils import TOKEN
@@ -40,7 +46,10 @@ def cache_repos():
     if api.repo_exists(cache_repo_id):
         api.delete_repo(cache_repo_id)
     cache_repo_id = api.create_repo(cache_repo_id, private=True).repo_id
+    api.repo_info(cache_repo_id, repo_type="model")
     cache_dir = TemporaryDirectory()
+    set_custom_cache_repo_name_in_hf_home(cache_repo_id, api=api)
+    assert load_custom_cache_repo_name_from_hf_home() == cache_repo_id
     cache_path = cache_dir.name
     # Modify environment to force neuronx cache to use temporary caches
     previous_env = {}
@@ -111,26 +120,40 @@ def check_decoder_generation(model):
     assert sample_output.shape[0] == batch_size


-def check_encoder_inference(model):
-    pass
+def check_encoder_inference(model, tokenizer):
+    text = ["This is a sample output"]
+    tokens = tokenizer(text, return_tensors="pt")
+    outputs = model(**tokens)
+    assert "logits" in outputs


 def check_stable_diffusion_inference(model):
-    pass
+    prompts = ["sailing ship in storm by Leonardo da Vinci"]
+    image = model(prompts, num_images_per_prompt=4).images[0]
+    assert isinstance(image, PIL.Image.Image)


 def get_local_cached_files(cache_path, extension="*"):
     links = glob.glob(f"{cache_path}/**/*/*.{extension}", recursive=True)
     return [link for link in links if os.path.isfile(link)]


-def check_cache_entry(model, cache_path):
+def check_jit_cache_entry(model, cache_path):
     local_files = get_local_cached_files(cache_path, "json")
     model_id = model.config.neuron["checkpoint_id"]
     model_configurations = [path for path in local_files if model_id in path]
     assert len(model_configurations) > 0


+def check_aot_cache_entry(cache_path):
+    local_files = get_local_cached_files(cache_path, "json")
+    registry_path = [path for path in local_files if "REGISTRY" in path][0]
+    registry_key = registry_path.split("/")[-1].replace(".json", "")
+    local_files.remove(registry_path)
+    hash_key = local_files[0].split("/")[-2].replace("MODULE_", "")
+    assert registry_key == hash_key


 def assert_local_and_hub_cache_sync(cache_path, cache_repo_id):
     api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)
     remote_files = api.list_repo_files(cache_repo_id)
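Note: `check_aot_cache_entry` pins down an invariant of the ahead-of-time cache layout: the REGISTRY entry is a JSON file named after the same hash that names the MODULE_<hash> directory holding the compiled artifacts. A self-contained sketch of that invariant (directory layout and hash made up for illustration):

import os
import tempfile

with tempfile.TemporaryDirectory() as cache_path:
    hash_key = "abc123"
    os.makedirs(f"{cache_path}/REGISTRY")
    os.makedirs(f"{cache_path}/MODULE_{hash_key}")
    registry_file = f"{cache_path}/REGISTRY/{hash_key}.json"
    open(registry_file, "w").close()
    # Recover both keys and compare them, mirroring the helper's assertion.
    registry_key = registry_file.split("/")[-1].replace(".json", "")
    module_key = f"{cache_path}/MODULE_{hash_key}".split("/")[-1].replace("MODULE_", "")
    assert registry_key == module_key == hash_key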
@@ -153,7 +176,7 @@ def test_decoder_cache(cache_repos):
     # Export the model a first time to populate the local cache
     model = export_decoder_model(model_id)
     check_decoder_generation(model)
-    check_cache_entry(model, cache_path)
+    check_jit_cache_entry(model, cache_path)
     # Synchronize the hub cache with the local cache
     synchronize_hub_cache(cache_repo_id=cache_repo_id)
     assert_local_and_hub_cache_sync(cache_path, cache_repo_id)
@@ -182,16 +205,16 @@ def test_encoder_cache(cache_repos):
     model_id = "hf-internal-testing/tiny-random-BertModel"
     # Export the model a first time to populate the local cache
     model = export_encoder_model(model_id)
-    check_encoder_inference(model)
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    check_encoder_inference(model, tokenizer)
     # check registry
-    check_cache_entry(model, cache_path)
+    check_aot_cache_entry(cache_path)
     # Synchronize the hub cache with the local cache
     synchronize_hub_cache(cache_repo_id=cache_repo_id)
     assert_local_and_hub_cache_sync(cache_path, cache_repo_id)
     # Verify we are able to fetch the cached entry for the model
     model_entries = get_hub_cached_entries(model_id, "inference", cache_repo_id=cache_repo_id)
     assert len(model_entries) == 1
-    assert model_entries[0] == model.config.neuron
     # Clear the local cache
     for root, dirs, files in os.walk(cache_path):
         for f in files:
@@ -201,7 +224,7 @@
     assert local_cache_size(cache_path) == 0
     # Export the model again: the compilation artifacts should be fetched from the Hub
     model = export_encoder_model(model_id)
-    check_encoder_inference(model)
+    check_encoder_inference(model, tokenizer)
     # Verify the local cache directory has not been populated
     assert len(get_local_cached_files(cache_path, ".neuron")) == 0

@@ -215,14 +238,13 @@ def test_stable_diffusion_cache(cache_repos):
     model = export_stable_diffusion_model(model_id)
     check_stable_diffusion_inference(model)
     # check registry
-    check_cache_entry(model, cache_path)
+    check_aot_cache_entry(cache_path)
     # Synchronize the hub cache with the local cache
     synchronize_hub_cache(cache_repo_id=cache_repo_id)
     assert_local_and_hub_cache_sync(cache_path, cache_repo_id)
     # Verify we are able to fetch the cached entry for the model
     model_entries = get_hub_cached_entries(model_id, "inference", cache_repo_id=cache_repo_id)
     assert len(model_entries) == 1
-    assert model_entries[0] == model.config.neuron
     # Clear the local cache
     for root, dirs, files in os.walk(cache_path):
         for f in files:
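Note: end to end, these tests encode the intended user flow. A hedged sketch of that flow, under the assumption that `from_pretrained(..., export=True)` accepts static input shapes as keyword arguments as elsewhere in optimum-neuron (repo ids are placeholders):

from optimum.neuron import NeuronModelForSequenceClassification
from optimum.neuron.utils import synchronize_hub_cache

# First export: compiles the model for fixed shapes and fills the local cache.
model = NeuronModelForSequenceClassification.from_pretrained(
    "hf-internal-testing/tiny-random-BertModel",
    export=True,
    batch_size=1,
    sequence_length=32,
)
# Push the compiled artifacts to the configured hub cache repo.
synchronize_hub_cache(cache_repo_id="my-org/optimum-neuron-cache")
# A later export of the same model and shapes can then fetch the artifacts
# from the hub instead of recompiling, which is what the tests assert.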
