Commit

applied comments of David
JingyaHuang committed Mar 27, 2024
1 parent e62ee54 commit a26eeaa
Showing 5 changed files with 24 additions and 54 deletions.
9 changes: 2 additions & 7 deletions optimum/exporters/neuron/convert.py
@@ -34,8 +34,7 @@
 from ...neuron.utils.hub_neuronx_cache import (
     ModelCacheEntry,
     build_cache_config,
-    cache_aot_neuron_artifacts,
-    hub_neuronx_cache,
+    cache_traced_neuron_artifacts,
 )
 from ...neuron.utils.version_utils import get_neuroncc_version, get_neuronxcc_version
 from ...utils import (
@@ -402,11 +401,7 @@ def export_models(
         model_id = get_model_name_or_path(model_config) if model_name_or_path is None else model_name_or_path
         cache_config = build_cache_config(compile_configs)
         cache_entry = ModelCacheEntry(model_id=model_id, config=cache_config)
-
-        # Use the context manager just for creating registry, AOT compilation won't leverage `create_compile_cache`
-        # in `libneuronxla`, so we will need to cache compiled artifacts to local manually.
-        with hub_neuronx_cache("inference", entry=cache_entry):
-            cache_aot_neuron_artifacts(neuron_dir=output_dir, cache_config_hash=cache_entry.hash)
+        cache_traced_neuron_artifacts(neuron_dir=output_dir, cache_entry=cache_entry)

     # remove models failed to export
     for i, model_name in failed_models:
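Net effect on the exporter: registry creation and the artifact upload now sit behind a single call. A minimal sketch of the new call shape, assuming an Inferentia host with optimum-neuron installed (the model id, config dict, and output directory are hypothetical placeholders):

    from pathlib import Path

    from optimum.neuron.utils.hub_neuronx_cache import (
        ModelCacheEntry,
        build_cache_config,
        cache_traced_neuron_artifacts,
    )

    # Placeholder inputs: in convert.py, `compile_configs` is assembled from the
    # compiled submodels during export.
    compile_configs = {"model": {"input_shapes": {"batch_size": 1, "sequence_length": 128}}}
    cache_config = build_cache_config(compile_configs)
    cache_entry = ModelCacheEntry(model_id="bert-base-uncased", config=cache_config)

    # One call now creates the registry entry and copies the traced artifacts
    # into the local compile cache.
    cache_traced_neuron_artifacts(neuron_dir=Path("neuron_artifacts"), cache_entry=cache_entry)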
6 changes: 2 additions & 4 deletions optimum/neuron/modeling_base.py
@@ -38,8 +38,7 @@
     replace_weights,
     store_compilation_config,
 )
-from .utils.cache_utils import load_custom_cache_repo_name_from_hf_home
-from .utils.hub_neuronx_cache import ModelCacheEntry, _create_hub_compile_cache_proxy, build_cache_config
+from .utils.hub_neuronx_cache import ModelCacheEntry, build_cache_config, create_hub_compile_cache_proxy
 from .utils.import_utils import is_neuronx_available
 from .utils.misc import maybe_load_preprocessors
 from .utils.version_utils import check_compiler_compatibility, get_neuroncc_version, get_neuronxcc_version
@@ -296,8 +295,7 @@ def _export(
             )
             cache_config = build_cache_config(compilation_config)
             cache_entry = ModelCacheEntry(model_id=model_id, config=cache_config)
-            cache_repo_id = load_custom_cache_repo_name_from_hf_home()
-            compile_cache = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id)
+            compile_cache = create_hub_compile_cache_proxy()
             model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(f"MODULE_{cache_entry.hash}")
             cache_available = compile_cache.download_folder(model_cache_dir, model_cache_dir)
         else:
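The lookup side mirrors this: the proxy is created with no explicit cache_repo_id, which is now resolved internally. A hedged sketch of the cache probe, where cache_entry is a ModelCacheEntry built as in the exporter sketch above:

    from optimum.neuron.utils.hub_neuronx_cache import create_hub_compile_cache_proxy

    compile_cache = create_hub_compile_cache_proxy()  # repo id resolved from the environment
    # The entry hash keys the MODULE_ folder holding the compiled artifacts.
    model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(
        f"MODULE_{cache_entry.hash}"
    )
    # True if the folder could be fetched from the Hub into the local cache.
    cache_available = compile_cache.download_folder(model_cache_dir, model_cache_dir)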
6 changes: 2 additions & 4 deletions optimum/neuron/modeling_diffusion.py
@@ -53,11 +53,10 @@
     replace_weights,
     store_compilation_config,
 )
-from .utils.cache_utils import load_custom_cache_repo_name_from_hf_home
 from .utils.hub_neuronx_cache import (
     ModelCacheEntry,
-    _create_hub_compile_cache_proxy,
     build_cache_config,
+    create_hub_compile_cache_proxy,
 )
 from .utils.require_utils import requires_torch_neuronx
 from .utils.version_utils import get_neuronxcc_version
@@ -739,8 +738,7 @@ def _export(
             # 3. Lookup cached config
             cache_config = build_cache_config(compilation_configs)
             cache_entry = ModelCacheEntry(model_id=model_id, config=cache_config)
-            cache_repo_id = load_custom_cache_repo_name_from_hf_home()
-            compile_cache = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id)
+            compile_cache = create_hub_compile_cache_proxy()
             model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(f"MODULE_{cache_entry.hash}")
             cache_exist = compile_cache.download_folder(model_cache_dir, model_cache_dir)
         else:
22 changes: 12 additions & 10 deletions optimum/neuron/utils/hub_neuronx_cache.py
@@ -213,7 +213,7 @@ def upload_file(self, cache_path: str, src_path: str):
         self.default_cache.upload_file(cache_path, src_path)

     def upload_folder(self, cache_dir: str, src_dir: str):
-        # Upload folder the default cache: use synchronize to populate the Hub cache
+        # Upload folder to the default cache: use synchronize to populate the Hub cache
         shutil.copytree(src_dir, cache_dir, dirs_exist_ok=True)

     def upload_string_to_file(self, cache_path: str, data: str):
@@ -241,7 +241,7 @@ def get_hub_cache():
     return os.getenv("CUSTOM_CACHE_REPO", HUB_CACHE)


-def _create_hub_compile_cache_proxy(
+def create_hub_compile_cache_proxy(
     cache_url: Optional[CacheUrl] = None,
     cache_repo_id: Optional[str] = None,
 ):
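Dropping the leading underscore promotes the factory to the module's public surface, and callers no longer pass a repo id: when cache_repo_id is omitted, the repo comes from get_hub_cache() above, i.e. the CUSTOM_CACHE_REPO environment variable with the library default as fallback. A small sketch (the repo name is hypothetical):

    import os

    from optimum.neuron.utils.hub_neuronx_cache import create_hub_compile_cache_proxy

    os.environ["CUSTOM_CACHE_REPO"] = "my-org/neuron-cache"  # hypothetical repo
    compile_cache = create_hub_compile_cache_proxy()  # no explicit cache_repo_id needed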
@@ -328,7 +328,7 @@ def hub_neuronx_cache(

     def hf_create_compile_cache(cache_url):
         try:
-            return _create_hub_compile_cache_proxy(cache_url, cache_repo_id=cache_repo_id)
+            return create_hub_compile_cache_proxy(cache_url, cache_repo_id=cache_repo_id)
         except Exception as e:
             logger.warning(f"Bypassing Hub cache because of the following error: {e}")
             return create_compile_cache(cache_url)
@@ -402,7 +402,7 @@ def synchronize_hub_cache(cache_path: Optional[Union[str, Path]] = None, cache_r
         cache_url = CacheUrl(cache_path_str, url_type="fs")
     else:
         cache_url = None
-    hub_cache_proxy = _create_hub_compile_cache_proxy(cache_url=cache_url, cache_repo_id=cache_repo_id)
+    hub_cache_proxy = create_hub_compile_cache_proxy(cache_url=cache_url, cache_repo_id=cache_repo_id)
     hub_cache_proxy.synchronize()
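synchronize_hub_cache is the push counterpart of the lookups above: it wraps the same proxy and mirrors a local compile cache to the Hub. A usage sketch with hypothetical path and repo values:

    from optimum.neuron.utils import synchronize_hub_cache

    synchronize_hub_cache(
        cache_path="/var/tmp/neuron-compile-cache",  # hypothetical local cache
        cache_repo_id="my-org/neuron-cache",  # hypothetical Hub repo
    )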


@@ -558,10 +558,12 @@ def build_cache_config(
     return next(iter(clean_configs.values()))


-def cache_aot_neuron_artifacts(neuron_dir: Path, cache_config_hash: str):
-    cache_repo_id = load_custom_cache_repo_name_from_hf_home()
-    compile_cache = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id)
-    model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(f"MODULE_{cache_config_hash}")
-    compile_cache.upload_folder(cache_dir=model_cache_dir, src_dir=neuron_dir)
+def cache_traced_neuron_artifacts(neuron_dir: Path, cache_entry: ModelCacheEntry):
+    # Use the context manager just for creating registry, AOT compilation won't leverage `create_compile_cache`
+    # in `libneuronxla`, so we will need to cache compiled artifacts to local manually.
+    with hub_neuronx_cache("inference", entry=cache_entry):
+        compile_cache = create_hub_compile_cache_proxy()
+        model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(f"MODULE_{cache_entry.hash}")
+        compile_cache.upload_folder(cache_dir=model_cache_dir, src_dir=neuron_dir)

-    logger.info(f"Model cached in: {model_cache_dir}.")
+    logger.info(f"Model cached in: {model_cache_dir}.")
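The design choice here is to colocate the two halves of artifact caching that were previously split between the exporter and this module: the hub_neuronx_cache("inference", entry=cache_entry) context manager only records the entry in the registry, while upload_folder copies the traced artifacts into the local cache by hand, since AOT tracing never goes through libneuronxla's create_compile_cache hook.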
35 changes: 6 additions & 29 deletions tests/cache/test_neuronx_cache.py
@@ -33,12 +33,6 @@
     NeuronStableDiffusionXLPipeline,
 )
 from optimum.neuron.utils import get_hub_cached_entries, synchronize_hub_cache
-from optimum.neuron.utils.cache_utils import (
-    CACHE_REPO_FILENAME,
-    HF_HOME,
-    load_custom_cache_repo_name_from_hf_home,
-    set_custom_cache_repo_name_in_hf_home,
-)
 from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx
 from optimum.utils.testing_utils import TOKEN

@@ -53,12 +47,7 @@ def cache_repos():
     if api.repo_exists(cache_repo_id):
         api.delete_repo(cache_repo_id)
     cache_repo_id = api.create_repo(cache_repo_id, private=True).repo_id
-    api.repo_info(cache_repo_id, repo_type="model")
     cache_dir = TemporaryDirectory()
-    set_custom_cache_repo_name_in_hf_home(
-        cache_repo_id, api=api
-    )  # The custom repo will be registered under `HF_HOME`, we need to restore the env by the end of each test.
-    assert load_custom_cache_repo_name_from_hf_home() == cache_repo_id
     cache_path = cache_dir.name
     # Modify environment to force neuronx cache to use temporary caches
     previous_env = {}
@@ -79,12 +68,6 @@
             os.environ[var] = previous_env[var]


-def unset_custom_cache_repo_name_in_hf_home(hf_home: str = HF_HOME):
-    hf_home_cache_repo_file = f"{hf_home}/{CACHE_REPO_FILENAME}"
-    if os.path.isfile(hf_home_cache_repo_file):
-        os.remove(hf_home_cache_repo_file)
-
-
 def export_decoder_model(model_id):
     batch_size = 2
     sequence_length = 512
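With set_custom_cache_repo_name_in_hf_home gone from the fixture, its unset_custom_cache_repo_name_in_hf_home cleanup counterpart becomes dead code and is deleted as well: the tests now select the cache repo purely through environment variables, which the fixture already saves and restores. A minimal sketch of the idea (values are hypothetical):

    import os

    # No state is written under HF_HOME anymore, so restoring the previous
    # environment at teardown is all the cleanup the fixture needs.
    os.environ["CUSTOM_CACHE_REPO"] = "optimum-internal-testing/neuron-testing-cache"  # hypothetical
    os.environ["NEURON_COMPILE_CACHE_URL"] = "/tmp/neuron-test-cache"  # hypothetical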
@@ -168,14 +151,14 @@ def get_local_cached_files(cache_path, extension="*"):
     return [link for link in links if os.path.isfile(link)]


-def check_jit_cache_entry(model, cache_path):
+def check_decoder_cache_entry(model, cache_path):
     local_files = get_local_cached_files(cache_path, "json")
     model_id = model.config.neuron["checkpoint_id"]
     model_configurations = [path for path in local_files if model_id in path]
     assert len(model_configurations) > 0


-def check_aot_cache_entry(cache_path):
+def check_traced_cache_entry(cache_path):
     local_files = get_local_cached_files(cache_path, "json")
     registry_path = [path for path in local_files if "REGISTRY" in path][0]
     registry_key = registry_path.split("/")[-1].replace(".json", "")
@@ -206,7 +189,7 @@ def test_decoder_cache(cache_repos):
     # Export the model a first time to populate the local cache
     model = export_decoder_model(model_id)
     check_decoder_generation(model)
-    check_jit_cache_entry(model, cache_path)
+    check_decoder_cache_entry(model, cache_path)
     # Synchronize the hub cache with the local cache
     synchronize_hub_cache(cache_repo_id=cache_repo_id)
     assert_local_and_hub_cache_sync(cache_path, cache_repo_id)
@@ -226,7 +209,6 @@
     check_decoder_generation(model)
     # Verify the local cache directory has not been populated
     assert len(get_local_cached_files(cache_path, "neff")) == 0
-    unset_custom_cache_repo_name_in_hf_home()


 @is_inferentia_test
@@ -239,7 +221,7 @@ def test_encoder_cache(cache_repos):
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     check_encoder_inference(model, tokenizer)
     # check registry
-    check_aot_cache_entry(cache_path)
+    check_traced_cache_entry(cache_path)
     # Synchronize the hub cache with the local cache
     synchronize_hub_cache(cache_repo_id=cache_repo_id)
     assert_local_and_hub_cache_sync(cache_path, cache_repo_id)
@@ -258,7 +240,6 @@
     check_encoder_inference(model, tokenizer)
     # Verify the local cache directory has not been populated
     assert len(get_local_cached_files(cache_path, ".neuron")) == 0
-    unset_custom_cache_repo_name_in_hf_home()


 @is_inferentia_test
@@ -270,7 +251,7 @@
     model = export_stable_diffusion_model(model_id)
     check_stable_diffusion_inference(model)
     # check registry
-    check_aot_cache_entry(cache_path)
+    check_traced_cache_entry(cache_path)
     # Synchronize the hub cache with the local cache
     synchronize_hub_cache(cache_repo_id=cache_repo_id)
     assert_local_and_hub_cache_sync(cache_path, cache_repo_id)
@@ -289,7 +270,6 @@
     check_stable_diffusion_inference(model)
     # Verify the local cache directory has not been populated
     assert len(get_local_cached_files(cache_path, ".neuron")) == 0
-    unset_custom_cache_repo_name_in_hf_home()


 @is_inferentia_test
@@ -301,7 +281,7 @@
     model = export_stable_diffusion_xl_model(model_id)
     check_stable_diffusion_inference(model)
     # check registry
-    check_aot_cache_entry(cache_path)
+    check_traced_cache_entry(cache_path)
     # Synchronize the hub cache with the local cache
     synchronize_hub_cache(cache_repo_id=cache_repo_id)
     assert_local_and_hub_cache_sync(cache_path, cache_repo_id)
@@ -320,7 +300,6 @@
     check_stable_diffusion_inference(model)
     # Verify the local cache directory has not been populated
     assert len(get_local_cached_files(cache_path, ".neuron")) == 0
-    unset_custom_cache_repo_name_in_hf_home()


 @is_inferentia_test
@@ -335,7 +314,6 @@
     ids=["invalid_repo", "invalid_endpoint", "invalid_token"],
 )
 def test_decoder_cache_unavailable(cache_repos, var, value, match):
-    unset_custom_cache_repo_name_in_hf_home()  # clean the repo set by cli since it's prioritized than env variable
     # Modify the specified environment variable to trigger an error
     os.environ[var] = value
     # Just exporting the model will only emit a warning
@@ -366,4 +344,3 @@ def test_optimum_neuron_cli_cache_synchronize(cache_repos):
     stdout = stdout.decode("utf-8")
     assert p.returncode == 0
     assert f"1 entrie(s) found in cache for {model_id}" in stdout
-    unset_custom_cache_repo_name_in_hf_home()
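Taken together, the tests assert a full round trip; a condensed sketch of that flow using helpers from this module (the model id is hypothetical, cache_path and cache_repo_id come from the cache_repos fixture, and an Inferentia host is assumed):

    # 1. First export: compiles the model and populates the local cache.
    model = export_decoder_model("hf-internal-testing/tiny-random-gpt2")  # hypothetical id
    check_decoder_cache_entry(model, cache_path)
    # 2. Mirror the local cache to the Hub repo.
    synchronize_hub_cache(cache_repo_id=cache_repo_id)
    # 3. After wiping the local cache, a second export fetches artifacts from
    #    the Hub instead of recompiling, so no new .neff files appear locally.
    model = export_decoder_model("hf-internal-testing/tiny-random-gpt2")
    assert len(get_local_cached_files(cache_path, "neff")) == 0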
