Skip to content

Commit

Permalink
Enable caching for inlined models (#604)
Browse files — browse the repository at this point in the history
* enable caching for inlined models

* try fix
  • Loading branch information
JingyaHuang authored May 29, 2024
1 parent ad9e51b commit 1efb43d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion optimum/exporters/neuron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def export_models(
logger.info(f"[Total compilation Time] {np.round(total_compilation_time, 2)} seconds.")

# cache neuronx model
if not disable_neuron_cache and is_neuronx_available() and not inline_weights_to_neff:
if not disable_neuron_cache and is_neuronx_available():
model_id = get_model_name_or_path(model_config) if model_name_or_path is None else model_name_or_path
cache_config = build_cache_config(compile_configs)
cache_entry = ModelCacheEntry(model_id=model_id, config=cache_config)
Expand Down
6 changes: 4 additions & 2 deletions optimum/neuron/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ def _export(
pipe = replace_stable_diffusion_submodels(pipe, submodels)

# Check if the cache exists
if not inline_weights_to_neff and not disable_neuron_cache:
if not disable_neuron_cache:
save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)
# 1. Fetch all model configs
Expand Down Expand Up @@ -790,7 +790,8 @@ def _export(
# load cache
neuron_model = cls.from_pretrained(model_cache_dir, data_parallel_mode=data_parallel_mode)
# replace weights
neuron_model.replace_weights(weights=pipe)
if not inline_weights_to_neff:
neuron_model.replace_weights(weights=pipe)
return neuron_model
else:
# compile
Expand Down Expand Up @@ -822,6 +823,7 @@ def _export(
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
library_name=cls.library_name,
checked_config=cache_config,
**input_shapes,
)

Expand Down
9 changes: 4 additions & 5 deletions optimum/neuron/modeling_traced.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,7 @@ def _export(
"disable_fallback": disable_fallback,
}

if (
not inline_weights_to_neff and not disable_neuron_cache and is_neuronx_available()
): # TODO: support caching of Inf1 as well
if not disable_neuron_cache and is_neuronx_available(): # TODO: support caching of Inf1 as well
# Check if the cache exists
compilation_config = store_compilation_config(
config=config,
Expand Down Expand Up @@ -331,8 +329,9 @@ def _export(
force_download=force_download,
trust_remote_code=trust_remote_code,
)
# replace weights
neuron_model.replace_weights(weights=model)
if not inline_weights_to_neff:
# replace weights
neuron_model.replace_weights(weights=model)
return neuron_model
except Exception as e:
logger.warning(
Expand Down

0 comments on commit 1efb43d

Please sign in to comment.