diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py
index bc5145f6f..ac4a08732 100644
--- a/optimum/exporters/neuron/__main__.py
+++ b/optimum/exporters/neuron/__main__.py
@@ -217,6 +217,8 @@ def infer_stable_diffusion_shapes_from_diffusers(
     scaled_width = width // vae_scale_factor

     input_shapes["text_encoder"].update({"sequence_length": sequence_length})
+    if hasattr(model, "text_encoder_2"):
+        input_shapes["text_encoder_2"] = input_shapes["text_encoder"]
     input_shapes["unet"].update(
         {
             "sequence_length": sequence_length,
@@ -290,7 +292,7 @@ def _get_submodels_and_neuron_configs(
         task=task,
         library_name=library_name,
     )
-    check_mandatory_input_shapes(neuron_config_constructor, task, input_shapes)
+    input_shapes = check_mandatory_input_shapes(neuron_config_constructor, task, input_shapes)
     neuron_config = neuron_config_constructor(model.config, dynamic_batch_size=dynamic_batch_size, **input_shapes)
     model_name = getattr(model, "name_or_path", None) or model_name_or_path
     model_name = model_name.split("/")[-1] if model_name else model.config.model_type
diff --git a/optimum/exporters/neuron/convert.py b/optimum/exporters/neuron/convert.py
index e787ce16c..2baf9cccb 100644
--- a/optimum/exporters/neuron/convert.py
+++ b/optimum/exporters/neuron/convert.py
@@ -342,15 +342,13 @@ def export_models(
         output_path = output_dir / output_file_name
         output_path.parent.mkdir(parents=True, exist_ok=True)

-        compiler_workdir_path = compiler_workdir / model_name if compiler_workdir is not None else None
-
         try:
             start_time = time.time()
             neuron_inputs, neuron_outputs = export(
                 model=submodel,
                 config=sub_neuron_config,
                 output=output_path,
-                compiler_workdir=compiler_workdir_path,
+                compiler_workdir=compiler_workdir,
                 inline_weights_to_neff=inline_weights_to_neff,
                 optlevel=optlevel,
                 **compiler_kwargs,
diff --git a/optimum/exporters/neuron/utils.py b/optimum/exporters/neuron/utils.py
index a111034da..47574e5af 100644
--- a/optimum/exporters/neuron/utils.py
+++ b/optimum/exporters/neuron/utils.py
@@ -262,11 +262,27 @@ def get_stable_diffusion_models_for_export(

 def _load_lora_weights_to_pipeline(
     pipeline: Union["StableDiffusionPipeline", "StableDiffusionXLImg2ImgPipeline"],
-    lora_model_ids: Optional[List[str]] = None,
-    weight_names: Optional[List[str]] = None,
-    adapter_names: Optional[List[str]] = None,
-    lora_scales: Optional[List[float]] = None,
+    lora_model_ids: Optional[Union[str, List[str]]] = None,
+    weight_names: Optional[Union[str, List[str]]] = None,
+    adapter_names: Optional[Union[str, List[str]]] = None,
+    lora_scales: Optional[Union[float, List[float]]] = None,
 ):
+    if isinstance(lora_model_ids, str):
+        lora_model_ids = [
+            lora_model_ids,
+        ]
+    if isinstance(weight_names, str):
+        weight_names = [
+            weight_names,
+        ]
+    if isinstance(adapter_names, str):
+        adapter_names = [
+            adapter_names,
+        ]
+    if isinstance(lora_scales, float):
+        lora_scales = [
+            lora_scales,
+        ]
     if lora_model_ids and weight_names:
         if len(lora_model_ids) == 1:
             pipeline.load_lora_weights(lora_model_ids[0], weight_name=weight_names[0])
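With the normalization above, `_load_lora_weights_to_pipeline` now accepts a bare string or scalar wherever it previously required a list. A hedged sketch of a call this is meant to support (the helper is internal; the model id, weight file, and adapter name below are hypothetical):

    # Single LoRA adapter passed as plain strings/scalar; the isinstance
    # checks above wrap each value into a one-element list.
    _load_lora_weights_to_pipeline(
        pipeline=pipe,  # an already-instantiated StableDiffusionPipeline
        lora_model_ids="my-org/my-sd-lora",  # hypothetical model id
        weight_names="pytorch_lora_weights.safetensors",  # hypothetical file
        adapter_names="style",  # hypothetical adapter name
        lora_scales=0.9,  # scalar float, normalized to [0.9]
    )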
@@ -288,9 +304,9 @@
 def get_submodels_for_export_stable_diffusion(
     pipeline: Union["StableDiffusionPipeline", "StableDiffusionXLImg2ImgPipeline"],
     task: str,
-    lora_model_ids: Optional[List[str]] = None,
-    lora_weight_names: Optional[List[str]] = None,
-    lora_adapter_names: Optional[List[str]] = None,
+    lora_model_ids: Optional[Union[str, List[str]]] = None,
+    lora_weight_names: Optional[Union[str, List[str]]] = None,
+    lora_adapter_names: Optional[Union[str, List[str]]] = None,
     lora_scales: Optional[List[float]] = None,
 ) -> Dict[str, Union["PreTrainedModel", "ModelMixin"]]:
     """
@@ -388,6 +404,8 @@ def check_mandatory_input_shapes(neuron_config_constructor, task, input_shapes):
             raise AttributeError(
                 f"Cannot find the value of `{name}` which is mandatory for exporting the model to the neuron format, please set the value explicitly."
             )
+    input_shapes = {axis: input_shapes[axis] for axis in mandatory_shapes}
+    return input_shapes


 def replace_stable_diffusion_submodels(pipeline, submodels):
diff --git a/optimum/neuron/modeling_base.py b/optimum/neuron/modeling_base.py
index 682864d06..a3c01a35d 100644
--- a/optimum/neuron/modeling_base.py
+++ b/optimum/neuron/modeling_base.py
@@ -240,6 +240,7 @@ def _export(
         force_download: bool = False,
         cache_dir: Optional[str] = None,
         compiler_workdir: Optional[Union[str, Path]] = None,
+        disable_neuron_cache: Optional[bool] = False,
         inline_weights_to_neff: bool = False,
         optlevel: str = "2",
         subfolder: str = "",
@@ -277,7 +278,9 @@
             "disable_fallback": disable_fallback,
         }

-        if not inline_weights_to_neff:
+        if (
+            not inline_weights_to_neff and not disable_neuron_cache and is_neuronx_available()
+        ):  # TODO: support caching of Inf1 as well
             # Check if the cache exists
             compilation_config = store_compilation_config(
                 config=config,
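For reference, the new `disable_neuron_cache` flag short-circuits this entire cache lookup and goes straight to compilation. A hedged usage sketch (the checkpoint and input shapes are illustrative); the updated test in tests/inference/test_modeling.py exercises the same flag:

    from optimum.neuron import NeuronModelForSequenceClassification

    # Skip the hub cache lookup and force a fresh compilation.
    model = NeuronModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased-finetuned-sst-2-english",  # illustrative checkpoint
        export=True,
        disable_neuron_cache=True,  # new flag introduced by this diff
        batch_size=1,
        sequence_length=128,
    )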
@@ -296,30 +299,38 @@
             cache_repo_id = load_custom_cache_repo_name_from_hf_home()
             compile_cache = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id)
             model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(f"MODULE_{cache_entry.hash}")
-            cache_exist = compile_cache.download_folder(model_cache_dir, model_cache_dir)
+            cache_available = compile_cache.download_folder(model_cache_dir, model_cache_dir)
         else:
-            cache_exist = False
+            cache_available = False
+
+        # load cache
+        if cache_available:
+            try:
+                neuron_model = cls.from_pretrained(model_cache_dir)
+                model = TasksManager.get_model_from_task(
+                    task=task,
+                    model_name_or_path=model_id,
+                    subfolder=subfolder,
+                    revision=revision,
+                    framework="pt",
+                    library_name=library_name,
+                    cache_dir=cache_dir,
+                    use_auth_token=use_auth_token,
+                    local_files_only=local_files_only,
+                    force_download=force_download,
+                    trust_remote_code=trust_remote_code,
+                )
+                # replace weights
+                neuron_model.replace_weights(weights=model)
+                return neuron_model
+            except Exception as e:
+                logger.warning(
+                    f"Found the cached artifacts but failed to re-load them with error: {e}. \n Falling back to recompilation."
+                )
+                cache_available = False

-        if cache_exist:
-            # load cache
-            neuron_model = cls.from_pretrained(model_cache_dir)
-            model = TasksManager.get_model_from_task(
-                task=task,
-                model_name_or_path=model_id,
-                subfolder=subfolder,
-                revision=revision,
-                framework="pt",
-                library_name=library_name,
-                cache_dir=cache_dir,
-                use_auth_token=use_auth_token,
-                local_files_only=local_files_only,
-                force_download=force_download,
-                trust_remote_code=trust_remote_code,
-            )
-            # replace weights
-            neuron_model.replace_weights(weights=model)
-            return neuron_model
-        else:
+        # compile
+        if not cache_available:
             # compile
             save_dir = TemporaryDirectory()
             save_dir_path = Path(save_dir.name)
@@ -330,6 +341,7 @@
                 task=task,
                 dynamic_batch_size=dynamic_batch_size,
                 cache_dir=cache_dir,
+                disable_neuron_cache=disable_neuron_cache,
                 compiler_workdir=compiler_workdir,
                 inline_weights_to_neff=inline_weights_to_neff,
                 optlevel=optlevel,
diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py
index 6f8e27530..823041ca9 100644
--- a/optimum/neuron/modeling_diffusion.py
+++ b/optimum/neuron/modeling_diffusion.py
@@ -59,6 +59,7 @@
     _create_hub_compile_cache_proxy,
     build_cache_config,
 )
+from .utils.require_utils import requires_torch_neuronx
 from .utils.version_utils import get_neuronxcc_version


@@ -265,6 +266,7 @@ def is_lcm(unet_config):
         return any(pattern in unet_name_or_path for pattern in patterns)

     @staticmethod
+    @requires_torch_neuronx
     def load_model(
         data_parallel_mode: Optional[str],
         text_encoder_path: Union[str, Path],
@@ -426,6 +428,7 @@ def _save_pretrained(
         self.feature_extractor.save_pretrained(save_directory.joinpath("feature_extractor"))

     @classmethod
+    @requires_torch_neuronx
     def _from_pretrained(
         cls,
         model_id: Union[str, Path],
@@ -558,11 +561,13 @@
         )

     @classmethod
+    @requires_torch_neuronx
     def _from_transformers(cls, *args, **kwargs):
         # Deprecate it when optimum uses `_export` as from_pretrained_method in a stable release.
         return cls._export(*args, **kwargs)

     @classmethod
+    @requires_torch_neuronx
     def _export(
         cls,
         model_id: Union[str, Path],
@@ -573,6 +578,7 @@
         force_download: bool = True,
         cache_dir: Optional[str] = None,
         compiler_workdir: Optional[str] = None,
+        disable_neuron_cache: Optional[bool] = False,
         inline_weights_to_neff: bool = False,
         optlevel: str = "2",
         subfolder: str = "",
@@ -616,6 +622,8 @@
                 standard cache should not be used.
             compiler_workdir (`Optional[str]`, defaults to `None`):
                 Path to a directory in which the neuron compiler will store all intermediary files during the compilation (neff, weight, hlo graph...).
+            disable_neuron_cache (`bool`, defaults to `False`):
+                Whether to disable automatic caching of compiled models. If set to `True`, the export will neither load the Neuron cache nor cache the compiled artifacts.
             inline_weights_to_neff (`bool`, defaults to `False`):
                 Whether to inline the weights to the neff graph. If set to False, weights will be separated from the neff.
             optlevel (`str`, defaults to `"2"`):
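The diffusion pipelines expose the same flag through their `_export` path. A hedged sketch of an SDXL export that bypasses the cache (shapes are illustrative; the export helper added to the cache tests later in this diff uses the same entry point):

    from optimum.neuron import NeuronStableDiffusionXLPipeline

    # Compile an SDXL pipeline without reading from or writing to the Neuron cache.
    pipe = NeuronStableDiffusionXLPipeline.from_pretrained(
        "echarlaix/tiny-random-stable-diffusion-xl",  # tiny test checkpoint
        export=True,
        disable_neuron_cache=True,
        batch_size=1,
        height=64,
        width=64,
        num_images_per_prompt=1,
    )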
@@ -685,7 +693,7 @@
         pipe = replace_stable_diffusion_submodels(pipe, submodels)

         # Check if the cache exists
-        if not inline_weights_to_neff:
+        if not inline_weights_to_neff and not disable_neuron_cache:
             # 1. Fetch all model configs
             models_for_export = get_submodels_for_export_stable_diffusion(
                 pipeline=pipe,
@@ -757,6 +765,7 @@
             task=task,
             dynamic_batch_size=dynamic_batch_size,
             cache_dir=cache_dir,
+            disable_neuron_cache=disable_neuron_cache,
             compiler_workdir=compiler_workdir,
             inline_weights_to_neff=inline_weights_to_neff,
             optlevel=optlevel,
diff --git a/optimum/neuron/utils/hub_neuronx_cache.py b/optimum/neuron/utils/hub_neuronx_cache.py
index 99daa944f..083a8f27a 100644
--- a/optimum/neuron/utils/hub_neuronx_cache.py
+++ b/optimum/neuron/utils/hub_neuronx_cache.py
@@ -74,7 +74,6 @@ def create_compile_cache():
     "_commit_hash",
     "sample_size",
     "projection_dim",
-    "task",
     "_use_default_values",
 ]
 NEURON_CONFIG_WHITE_LIST = ["input_names", "output_names", "model_type"]
diff --git a/tests/cache/test_neuronx_cache.py b/tests/cache/test_neuronx_cache.py
index 84ee0d669..6d607ee4e 100644
--- a/tests/cache/test_neuronx_cache.py
+++ b/tests/cache/test_neuronx_cache.py
@@ -26,7 +26,12 @@
 from transformers import AutoTokenizer
 from transformers.testing_utils import ENDPOINT_STAGING

-from optimum.neuron import NeuronModelForCausalLM, NeuronModelForSequenceClassification, NeuronStableDiffusionPipeline
+from optimum.neuron import (
+    NeuronModelForCausalLM,
+    NeuronModelForSequenceClassification,
+    NeuronStableDiffusionPipeline,
+    NeuronStableDiffusionXLPipeline,
+)
 from optimum.neuron.utils import get_hub_cached_entries, synchronize_hub_cache
 from optimum.neuron.utils.cache_utils import (
     CACHE_REPO_FILENAME,
@@ -122,6 +127,21 @@ def export_stable_diffusion_model(model_id):
     )


+def export_stable_diffusion_xl_model(model_id):
+    batch_size = 1
+    height = 64
+    width = 64
+    num_images_per_prompt = 4
+    return NeuronStableDiffusionXLPipeline.from_pretrained(
+        model_id,
+        export=True,
+        batch_size=batch_size,
+        height=height,
+        width=width,
+        num_images_per_prompt=num_images_per_prompt,
+    )
+
+
 def check_decoder_generation(model):
     batch_size = model.config.neuron["batch_size"]
     input_ids = torch.ones((batch_size, 20), dtype=torch.int64)
@@ -272,6 +292,37 @@ def test_stable_diffusion_cache(cache_repos):
     unset_custom_cache_repo_name_in_hf_home()


+@is_inferentia_test
+@requires_neuronx
+def test_stable_diffusion_xl_cache(cache_repos):
+    cache_path, cache_repo_id = cache_repos
+    model_id = "echarlaix/tiny-random-stable-diffusion-xl"
+    # Export the model a first time to populate the local cache
+    model = export_stable_diffusion_xl_model(model_id)
+    check_stable_diffusion_inference(model)
+    # check registry
+    check_aot_cache_entry(cache_path)
+    # Synchronize the hub cache with the local cache
+    synchronize_hub_cache(cache_repo_id=cache_repo_id)
+    assert_local_and_hub_cache_sync(cache_path, cache_repo_id)
+    # Verify we are able to fetch the cached entry for the model
+    model_entries = get_hub_cached_entries(model_id, "inference", cache_repo_id=cache_repo_id)
+    assert len(model_entries) == 1
+    # Clear the local cache
+    for root, dirs, files in os.walk(cache_path):
+        for f in files:
+            os.unlink(os.path.join(root, f))
+        for d in dirs:
+            shutil.rmtree(os.path.join(root, d))
+    assert local_cache_size(cache_path) == 0
+    # Export the model again: the compilation artifacts should be fetched from the Hub
+    model = export_stable_diffusion_xl_model(model_id)
+    check_stable_diffusion_inference(model)
+    # Verify the local cache directory has not been populated
+    assert len(get_local_cached_files(cache_path, ".neuron")) == 0
+    unset_custom_cache_repo_name_in_hf_home()
+
+
 @is_inferentia_test
 @requires_neuronx
 @pytest.mark.parametrize(
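Outside the test harness, the same hub cache round-trip can be driven manually with the utilities the test imports; a hedged sketch (the cache repo id is illustrative):

    from optimum.neuron.utils import get_hub_cached_entries, synchronize_hub_cache

    # Push local compilation artifacts to the hub cache repo, then confirm
    # an entry was registered for the exported model.
    synchronize_hub_cache(cache_repo_id="my-org/neuron-cache")  # illustrative repo
    entries = get_hub_cached_entries(
        "echarlaix/tiny-random-stable-diffusion-xl",
        "inference",
        cache_repo_id="my-org/neuron-cache",
    )
    assert len(entries) == 1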
diff --git a/tests/inference/test_modeling.py b/tests/inference/test_modeling.py
index 42cbb2152..efb888982 100644
--- a/tests/inference/test_modeling.py
+++ b/tests/inference/test_modeling.py
@@ -139,9 +139,14 @@ def test_save_compiler_intermediary_files(self):
         save_path = f"{tempdir}/neff"
         neff_path = os.path.join(save_path, "graph.neff")
         _ = NeuronModelForSequenceClassification.from_pretrained(
-            self.MODEL_ID, export=True, compiler_workdir=save_path, **self.STATIC_INPUTS_SHAPES
+            self.MODEL_ID,
+            export=True,
+            compiler_workdir=save_path,
+            disable_neuron_cache=True,
+            **self.STATIC_INPUTS_SHAPES,
         )
         self.assertTrue(os.path.isdir(save_path))
+        os.listdir(save_path)
         self.assertTrue(os.path.exists(neff_path))

     @requires_neuronx
@@ -656,7 +661,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self):
                 "hf-internal-testing/tiny-random-t5", from_transformers=True, **self.STATIC_INPUTS_SHAPES
             )
-        self.assertIn("is not supported yet", str(context.exception))
+        self.assertIn("doesn't support", str(context.exception))

     @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True)
     @requires_neuronx
@@ -862,7 +867,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self):
                 "hf-internal-testing/tiny-random-t5", from_transformers=True, **self.STATIC_INPUTS_SHAPES
             )
-        self.assertIn("is not supported yet", str(context.exception))
+        self.assertIn("doesn't support", str(context.exception))

     @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True)
     @requires_neuronx
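Taken together with the convert.py change earlier in this diff (the per-model `compiler_workdir / model_name` suffix was removed), compiler intermediaries now land directly in the directory the caller passes. A hedged sketch of the resulting layout check, mirroring the updated test (the path is illustrative):

    import os

    # With compiler_workdir=save_path, the compiled graph is now written to
    # save_path/graph.neff instead of save_path/<model_name>/graph.neff.
    save_path = "/tmp/neff"  # illustrative path
    neff_path = os.path.join(save_path, "graph.neff")
    assert os.path.isdir(save_path) and os.path.exists(neff_path)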