diff --git a/optimum/exporters/neuron/convert.py b/optimum/exporters/neuron/convert.py
index ac2f6ab8a..c02e14578 100644
--- a/optimum/exporters/neuron/convert.py
+++ b/optimum/exporters/neuron/convert.py
@@ -404,7 +404,7 @@ def export_models(
     logger.info(f"[Total compilation Time] {np.round(total_compilation_time, 2)} seconds.")
 
     # cache neuronx model
-    if not disable_neuron_cache and is_neuronx_available() and not inline_weights_to_neff:
+    if not disable_neuron_cache and is_neuronx_available():
         model_id = get_model_name_or_path(model_config) if model_name_or_path is None else model_name_or_path
         cache_config = build_cache_config(compile_configs)
         cache_entry = ModelCacheEntry(model_id=model_id, config=cache_config)
diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py
index 437f0b383..6a85e738f 100644
--- a/optimum/neuron/modeling_diffusion.py
+++ b/optimum/neuron/modeling_diffusion.py
@@ -723,7 +723,7 @@ def _export(
         pipe = replace_stable_diffusion_submodels(pipe, submodels)
 
         # Check if the cache exists
-        if not inline_weights_to_neff and not disable_neuron_cache:
+        if not disable_neuron_cache:
             save_dir = TemporaryDirectory()
             save_dir_path = Path(save_dir.name)
             # 1. Fetch all model configs
@@ -790,7 +790,8 @@ def _export(
            # load cache
            neuron_model = cls.from_pretrained(model_cache_dir, data_parallel_mode=data_parallel_mode)
            # replace weights
-            neuron_model.replace_weights(weights=pipe)
+            if not inline_weights_to_neff:
+                neuron_model.replace_weights(weights=pipe)
            return neuron_model
        else:
            # compile
@@ -822,6 +823,7 @@ def _export(
                lora_adapter_names=lora_adapter_names,
                lora_scales=lora_scales,
                library_name=cls.library_name,
+                checked_config=cache_config,
                **input_shapes,
            )
 
diff --git a/optimum/neuron/modeling_traced.py b/optimum/neuron/modeling_traced.py
index 53398d6c4..1b9668c44 100644
--- a/optimum/neuron/modeling_traced.py
+++ b/optimum/neuron/modeling_traced.py
@@ -290,9 +290,7 @@ def _export(
             "disable_fallback": disable_fallback,
         }
 
-        if (
-            not inline_weights_to_neff and not disable_neuron_cache and is_neuronx_available()
-        ):  # TODO: support caching of Inf1 as well
+        if not disable_neuron_cache and is_neuronx_available():  # TODO: support caching of Inf1 as well
            # Check if the cache exists
            compilation_config = store_compilation_config(
                config=config,
@@ -331,8 +329,9 @@ def _export(
                    force_download=force_download,
                    trust_remote_code=trust_remote_code,
                )
-                # replace weights
-                neuron_model.replace_weights(weights=model)
+                if not inline_weights_to_neff:
+                    # replace weights
+                    neuron_model.replace_weights(weights=model)
                return neuron_model
            except Exception as e:
                logger.warning(