From b357e529420feb61734dbac24a90dea6571cce27 Mon Sep 17 00:00:00 2001 From: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com> Date: Fri, 15 Mar 2024 21:30:06 +0000 Subject: [PATCH] sd caching cli support --- optimum/exporters/neuron/__init__.py | 2 - optimum/exporters/neuron/__main__.py | 1 + optimum/exporters/neuron/convert.py | 42 ++++--- optimum/exporters/neuron/utils.py | 13 --- optimum/neuron/modeling_base.py | 4 +- optimum/neuron/modeling_diffusion.py | 69 +++++++----- optimum/neuron/utils/__init__.py | 2 +- optimum/neuron/utils/hub_neuronx_cache.py | 131 ++++++++++++++++++---- optimum/neuron/utils/misc.py | 19 +++- 9 files changed, 183 insertions(+), 100 deletions(-) diff --git a/optimum/exporters/neuron/__init__.py b/optimum/exporters/neuron/__init__.py index 563c970f6..fb6c05709 100644 --- a/optimum/exporters/neuron/__init__.py +++ b/optimum/exporters/neuron/__init__.py @@ -26,7 +26,6 @@ "base": ["NeuronDefaultConfig"], "convert": ["export", "export_models", "validate_model_outputs", "validate_models_outputs"], "utils": [ - "DiffusersPretrainedConfig", "build_stable_diffusion_components_mandatory_shapes", "get_stable_diffusion_models_for_export", ], @@ -41,7 +40,6 @@ from .base import NeuronDefaultConfig from .convert import export, export_models, validate_model_outputs, validate_models_outputs from .utils import ( - DiffusersPretrainedConfig, build_stable_diffusion_components_mandatory_shapes, get_stable_diffusion_models_for_export, ) diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py index 1b16d36ab..67d95811c 100644 --- a/optimum/exporters/neuron/__main__.py +++ b/optimum/exporters/neuron/__main__.py @@ -498,6 +498,7 @@ def main_export( optlevel=optlevel, output_file_names=output_model_names, compiler_kwargs=compiler_kwargs, + model_name_or_path=model_name_or_path, ) # Validate compiled model diff --git a/optimum/exporters/neuron/convert.py b/optimum/exporters/neuron/convert.py index 2d54841b1..831b150db 100644 --- a/optimum/exporters/neuron/convert.py +++ b/optimum/exporters/neuron/convert.py @@ -25,20 +25,20 @@ from ...exporters.error_utils import OutputMatchError, ShapeError from ...neuron.utils import ( + DiffusersPretrainedConfig, convert_neuronx_compiler_args_to_neuron, is_neuron_available, is_neuronx_available, store_compilation_config, ) from ...neuron.utils.version_utils import get_neuroncc_version, get_neuronxcc_version -from ...neuron.utils.hub_neuronx_cache import ModelCacheEntry, hub_neuronx_cache, _create_hub_compile_cache_proxy, build_cache_config +from ...neuron.utils.hub_neuronx_cache import ModelCacheEntry, hub_neuronx_cache, _create_hub_compile_cache_proxy, build_cache_config, cache_aot_neuron_artifacts from ...neuron.utils.cache_utils import get_model_name_or_path, load_custom_cache_repo_name_from_hf_home from ...utils import ( is_diffusers_available, is_sentence_transformers_available, logging, ) -from .utils import DiffusersPretrainedConfig if TYPE_CHECKING: @@ -282,6 +282,7 @@ def export_models( output_file_names: Optional[Dict[str, str]] = None, compiler_kwargs: Optional[Dict[str, Any]] = {}, configs: Optional[Dict[str, Any]] = {}, + model_name_or_path: Optional[str] = None, ) -> Tuple[List[List[str]], List[List[str]]]: """ Exports a Pytorch model with multiple component models to separate files. @@ -309,6 +310,8 @@ def export_models( Arguments to pass to the Neuron(x) compiler for exporting Neuron models. configs (`Optional[Dict[str, Any]]`, defaults to `None`): A list of pretrained model configs. 
+ model_name_or_path (`Optional[str]`, defaults to `None`): + Path to pretrained model or model identifier from huggingface.co/models. Returns: `Tuple[List[List[str]], List[List[str]]]`: A tuple with an ordered list of the model's inputs, and the named outputs from the Neuron configuration. @@ -324,6 +327,7 @@ def export_models( failed_models = [] total_compilation_time = 0 + compile_configs = dict() for i, model_name in enumerate(models_and_neuron_configs.keys()): logger.info(f"***** Compiling {model_name} *****") submodel, sub_neuron_config = models_and_neuron_configs[model_name] @@ -380,18 +384,7 @@ def export_models( output_hidden_states=getattr(sub_neuron_config, "output_hidden_states", False), ) model_config.save_pretrained(output_path.parent) - - # cache neuronx model - if not disable_neuron_cache and is_neuronx_available() and not model_config.neuron["inline_weights_to_neff"]: - model_id = get_model_name_or_path(model_config) - cache_config = build_cache_config(copy.deepcopy(model_config).to_diff_dict()) - cache_entry = ModelCacheEntry(model_id=model_id, config=cache_config) - - # Use the context manager just for creating registry, AOT compilation won't leverage `create_compile_cache` - # in `libneuronxla`, so we will need to cache compiled artifacts to local manually. - with hub_neuronx_cache("inference", entry=cache_entry): - cache_aot_neuron_artifacts(neuron_dir=output_path.parent, cache_config_hash=cache_entry.hash) - + compile_configs[model_name] = model_config except Exception as e: failed_models.append((i, model_name)) output_path.parent.rmdir() logger.error( f"An error occurred when trying to trace {model_name} with the error message: {e}.\n" f"The export failed, and the {model_name} neuron model won't be stored." ) + logger.info(f"[Total compilation Time] {np.round(total_compilation_time, 2)} seconds.") + + # cache neuronx model + if not disable_neuron_cache and is_neuronx_available() and not inline_weights_to_neff: + model_id = get_model_name_or_path(model_config) if model_name_or_path is None else model_name_or_path + cache_config = build_cache_config(compile_configs) + cache_entry = ModelCacheEntry(model_id=model_id, config=cache_config) + + # Use the context manager just for creating the registry; AOT compilation won't go through `create_compile_cache` + # in `libneuronxla`, so we need to cache the compiled artifacts manually.
+ with hub_neuronx_cache("inference", entry=cache_entry): + cache_aot_neuron_artifacts(neuron_dir=output_path.parent, cache_config_hash=cache_entry.hash) # remove models failed to export for i, model_name in failed_models: @@ -410,15 +415,6 @@ def export_models( return outputs -def cache_aot_neuron_artifacts(neuron_dir: Path, cache_config_hash: str): - cache_repo_id = load_custom_cache_repo_name_from_hf_home() - compile_cache = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id) - model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(f"MODULE_{cache_config_hash}") - compile_cache.upload_folder(cache_dir=model_cache_dir, src_dir=neuron_dir) - - logger.info(f"Model cached in: {model_cache_dir}.") - - def export( model: "PreTrainedModel", config: "NeuronDefaultConfig", diff --git a/optimum/exporters/neuron/utils.py b/optimum/exporters/neuron/utils.py index 7e49381df..67a0eadea 100644 --- a/optimum/exporters/neuron/utils.py +++ b/optimum/exporters/neuron/utils.py @@ -73,19 +73,6 @@ from diffusers import ModelMixin, StableDiffusionPipeline, StableDiffusionXLImg2ImgPipeline -class DiffusersPretrainedConfig(PretrainedConfig): - # override to update `model_type` - def to_dict(self): - """ - Serializes this instance to a Python dictionary. - - Returns: - :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance. - """ - output = copy.deepcopy(self.__dict__) - return output - - def build_stable_diffusion_components_mandatory_shapes( batch_size: Optional[int] = None, sequence_length: Optional[int] = None, diff --git a/optimum/neuron/modeling_base.py b/optimum/neuron/modeling_base.py index a32560994..fb4b1dec6 100644 --- a/optimum/neuron/modeling_base.py +++ b/optimum/neuron/modeling_base.py @@ -279,7 +279,7 @@ def _export( "disable_fallback": disable_fallback, } - # CHECK IF CACHED + # Check if the cache exists compilation_config = store_compilation_config( config=config, input_shapes=kwargs_shapes, @@ -296,8 +296,6 @@ def _export( cache_entry = ModelCacheEntry(model_id=model_id, config=cache_config) cache_repo_id = load_custom_cache_repo_name_from_hf_home() compile_cache = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id) - - # check if cache exists model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(f"MODULE_{cache_entry.hash}") cache_exist = compile_cache.download_folder(model_cache_dir, model_cache_dir) diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index 06e1b558f..21424158e 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -27,7 +27,7 @@ from huggingface_hub import snapshot_download from transformers import CLIPFeatureExtractor, CLIPTokenizer, PretrainedConfig -from ..exporters.neuron import DiffusersPretrainedConfig, main_export, normalize_stable_diffusion_input_shapes +from ..exporters.neuron import main_export, normalize_stable_diffusion_input_shapes from ..exporters.neuron.model_configs import * # noqa: F403 from ..exporters.tasks import TasksManager from ..utils import is_diffusers_available @@ -40,6 +40,7 @@ DIFFUSION_MODEL_VAE_ENCODER_NAME, NEURON_FILE_NAME, is_neuronx_available, + DiffusersPretrainedConfig, ) @@ -644,35 +645,43 @@ def _export( "disable_fast_relayout": disable_fast_relayout, "disable_fallback": disable_fallback, } - - save_dir = TemporaryDirectory() - save_dir_path = Path(save_dir.name) - - main_export( - model_name_or_path=model_id, - output=save_dir_path, - 
compiler_kwargs=compiler_kwargs, - task=task, - dynamic_batch_size=dynamic_batch_size, - cache_dir=cache_dir, - compiler_workdir=compiler_workdir, - inline_weights_to_neff=inline_weights_to_neff, - optlevel=optlevel, - trust_remote_code=trust_remote_code, - subfolder=subfolder, - revision=revision, - force_download=force_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - do_validation=False, - submodels={"unet": unet_id}, - lora_model_ids=lora_model_ids, - lora_weight_names=lora_weight_names, - lora_adapter_names=lora_adapter_names, - lora_scales=lora_scales, - library_name=cls.library_name, - **input_shapes, - ) + + # TODO: Check if the cache exists + cache_exist = False + + if cache_exist: + # load cache + pass + else: + # compile + save_dir = TemporaryDirectory() + save_dir_path = Path(save_dir.name) + + main_export( + model_name_or_path=model_id, + output=save_dir_path, + compiler_kwargs=compiler_kwargs, + task=task, + dynamic_batch_size=dynamic_batch_size, + cache_dir=cache_dir, + compiler_workdir=compiler_workdir, + inline_weights_to_neff=inline_weights_to_neff, + optlevel=optlevel, + trust_remote_code=trust_remote_code, + subfolder=subfolder, + revision=revision, + force_download=force_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + do_validation=False, + submodels={"unet": unet_id}, + lora_model_ids=lora_model_ids, + lora_weight_names=lora_weight_names, + lora_adapter_names=lora_adapter_names, + lora_scales=lora_scales, + library_name=cls.library_name, + **input_shapes, + ) return cls._from_pretrained( model_id=save_dir_path, diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py index 8cebeb893..b580a2334 100644 --- a/optimum/neuron/utils/__init__.py +++ b/optimum/neuron/utils/__init__.py @@ -35,7 +35,7 @@ is_transformers_neuronx_available, ) from .input_generators import DummyBeamValuesGenerator -from .misc import check_if_weights_replacable, replace_weights +from .misc import check_if_weights_replacable, replace_weights, DiffusersPretrainedConfig from .optimization_utils import get_attention_scores_sd, get_attention_scores_sdxl from .patching import DynamicPatch, ModelPatcher, Patcher, patch_everywhere, patch_within_function from .training_utils import ( diff --git a/optimum/neuron/utils/hub_neuronx_cache.py b/optimum/neuron/utils/hub_neuronx_cache.py index 2336145c0..0017d0d7d 100644 --- a/optimum/neuron/utils/hub_neuronx_cache.py +++ b/optimum/neuron/utils/hub_neuronx_cache.py @@ -16,6 +16,7 @@ import json import logging import os +import copy import shutil from contextlib import contextmanager from enum import Enum @@ -32,7 +33,6 @@ from .require_utils import requires_torch_neuronx, requires_torch_xla from .cache_utils import load_custom_cache_repo_name_from_hf_home - if is_neuronx_available(): from libneuronxla.neuron_cc_cache import ( CacheUrl, @@ -61,6 +61,9 @@ def create_compile_cache(): logger = logging.getLogger(__name__) +CACHE_WHITE_LIST = ["_name_or_path", "transformers_version", "_diffusers_version", "eos_token_id", "bos_token_id", "pad_token_id", "torchscript", "torch_dtype", "vocab_size", "_commit_hash", "sample_size"] +NEURON_CONFIG_WHITE_LIST = ["input_names", "output_names"] + class CompileCacheHfProxy(CompileCache): """A HuggingFace Hub proxy cache implementing the CompileCache API. 
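The two whitelists above are the knobs of the whole caching scheme: every config field listed there is stripped before a model config is hashed into a cache key, so that volatile fields (library versions, token ids, traced input/output names) do not cause spurious cache misses. A minimal sketch of the intended effect, assuming this branch is installed — the field names and values below are illustrative placeholders, not taken from a real checkpoint:

from optimum.neuron.utils.hub_neuronx_cache import build_cache_config

# Two exports that differ only in whitelisted, volatile fields should
# normalize to the same cache config, and therefore to the same
# MODULE_<hash> folder in the cache repo.
config_a = {
    "model_type": "bert",
    "hidden_size": 768,
    "transformers_version": "4.36.0",
    "neuron": {"input_names": ["input_ids"], "static_batch_size": 1},
}
config_b = {
    "model_type": "bert",
    "hidden_size": 768,
    "transformers_version": "4.38.1",
    "neuron": {"input_names": ["input_ids", "attention_mask"], "static_batch_size": 1},
}
assert build_cache_config({"model": config_a}) == build_cache_config({"model": config_b})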
@@ -162,8 +165,15 @@ def download_folder(self, folder_path: str, dst_path: str): return True else: rel_folder_path = self._rel_path(folder_path) - folder_info = list(self.api.list_repo_tree(self.repo_id, rel_folder_path)) - folder_exists = len(folder_info) > 1 + try: + folder_info = list(self.api.list_repo_tree(self.repo_id, rel_folder_path)) + folder_exists = len(folder_info) > 1 + except Exception as e: + logger.warning( + f"{rel_folder_path} not found in {self.repo_id}: {e}\nThe model will be recompiled." + ) + folder_exists = False + if folder_exists: # cached remotely for repo_content in folder_info: @@ -174,13 +184,8 @@ dst_path.mkdir(parents=True, exist_ok=True) os.symlink(local_path, dst_path / filename) logger.info(f"Fetched cached {rel_folder_path} from {self.repo_id}") - return True - else: - logger.warning( - f"{rel_folder_path} not found in {self.repo_id}: the corresponding graph will be recompiled." - " This may take up to one hour for large models." - ) - return False + + return folder_exists def synchronize(self): if isinstance(self.default_cache, CompileCacheS3): @@ -403,12 +408,17 @@ def get_hub_cached_entries( api = HfApi(endpoint=endpoint, token=token) repo_files = api.list_repo_files(cache_repo_id) # Get the config corresponding to the model - target_entry = ModelCacheEntry(model_id, (AutoConfig.from_pretrained(model_id))) + try: + config = AutoConfig.from_pretrained(model_id) + except Exception: + config = get_multimodels_configs(api, model_id) # Applies to SD and encoder-decoder models + target_entry = ModelCacheEntry(model_id, config) # Extract model type: it will be used as primary key for lookup model_type = target_entry.config["model_type"] registry_folder = get_registry_folder_for_mode(mode) registry_pattern = registry_folder + "/" + model_type model_files = [path for path in repo_files if registry_pattern in path] + white_list = CACHE_WHITE_LIST + ["task"] # All parameters except those in the whitelist must match model_entries = [] with TemporaryDirectory() as tmpdir: for model_path in model_files: @@ -416,28 +426,101 @@ with open(local_path) as f: entry_config = json.load(f) # Remove neuron config for comparison as the target does not have it - neuron_config = entry_config.pop("neuron") - # All parameters except those in the whitelist must match - white_list = ["_name_or_path", "transformers_version", "eos_token_id", "bos_token_id", "pad_token_id", "torchscript", "torch_dtype", "task"] + if model_type == "stable-diffusion": + model_entries = lookup_match_entries_for_stable_diffusion(entry_config, target_entry, white_list, model_entries) + else: + neuron_config = entry_config.pop("neuron") + for param in white_list: + entry_config.pop(param, None) + target_entry.config.pop(param, None) + if entry_config == target_entry.config: + model_entries.append(neuron_config) + + return model_entries + + +def lookup_match_entries_for_stable_diffusion(entry_config, target_entry, white_list, model_entries): + neuron_config = entry_config["unet"].pop("neuron") + non_checked_components = ["vae", "vae_encoder", "vae_decoder"] + is_matched = True + for param in non_checked_components: + entry_config.pop(param, None) + target_entry.config.pop(param, None) + for name, value in entry_config.items(): + if isinstance(value, Dict): + for param in white_list: + value.pop(param, None) + target_entry.config[name].pop(param, None) + for term in set(entry_config[name]).intersection(set(target_entry.config[name])):
+ if entry_config[name][term] != target_entry.config[name][term]: + is_matched = False + if is_matched: + model_entries.append(neuron_config) + + return model_entries + + +def get_multimodels_configs(api, model_id): + repo_files = api.list_repo_files(model_id) + config_pattern = "/config.json" + config_files = [path for path in repo_files if config_pattern in path] + lookup_configs = dict() + with TemporaryDirectory() as tmpdir: + for model_path in config_files: + local_path = api.hf_hub_download(model_id, model_path, local_dir=tmpdir) + with open(local_path) as f: + entry_config = json.load(f) + white_list = CACHE_WHITE_LIST for param in white_list: entry_config.pop(param, None) - target_entry.config.pop(param, None) - if entry_config == target_entry.config: - model_entries.append(neuron_config) - return model_entries + lookup_configs[model_path.split("/")[-2]] = entry_config + + if "unet" in lookup_configs: + lookup_configs["model_type"] = "stable-diffusion" + return lookup_configs -# Only applied on traced TorchScript models -def build_cache_config(config: Dict, white_list: Optional[List] = None): - # TODO: consider case with multiple models thus multiple configs, eg. stable diffusion. Maybe concatenate. +def exclude_white_list_from_config(config: Dict, white_list: Optional[List] = None, neuron_white_list: Optional[List] = None): if white_list is None: - white_list = ["_name_or_path", "transformers_version", "eos_token_id", "bos_token_id", "pad_token_id", "torchscript", "torch_dtype"] + white_list = CACHE_WHITE_LIST + + if neuron_white_list is None: + neuron_white_list = NEURON_CONFIG_WHITE_LIST - neuron_white_list = ["input_names", "output_names"] for param in white_list: config.pop(param, None) for param in neuron_white_list: config["neuron"].pop(param, None) - return config \ No newline at end of file + return config + + +# Only applies to traced TorchScript models +def build_cache_config( + configs: Dict[str, PretrainedConfig], + white_list: Optional[List] = None, + neuron_white_list: Optional[List] = None, +): + clean_configs = dict() + for name, config in configs.items(): + config = copy.deepcopy(config).to_diff_dict() if isinstance(config, PretrainedConfig) else config + config = exclude_white_list_from_config(config, white_list, neuron_white_list) + clean_configs[name] = config + + if "unet" in configs: + clean_configs["model_type"] = "stable-diffusion" + if len(clean_configs) > 1: + return clean_configs + else: + return next(iter(clean_configs.values())) + + +def cache_aot_neuron_artifacts(neuron_dir: Path, cache_config_hash: str): + cache_repo_id = load_custom_cache_repo_name_from_hf_home() + compile_cache = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id) + model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(f"MODULE_{cache_config_hash}") + compile_cache.upload_folder(cache_dir=model_cache_dir, src_dir=neuron_dir) + + logger.info(f"Model cached in: {model_cache_dir}.") \ No newline at end of file diff --git a/optimum/neuron/utils/misc.py b/optimum/neuron/utils/misc.py index c619b2627..0aa6211e0 100644 --- a/optimum/neuron/utils/misc.py +++ b/optimum/neuron/utils/misc.py @@ -16,12 +16,13 @@ import inspect import os +import copy import re from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch -from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPProcessor +from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPProcessor, PretrainedConfig
from transformers.modeling_utils import _add_variant from transformers.utils import ( FLAX_WEIGHTS_NAME, @@ -43,9 +44,6 @@ from .require_utils import requires_safetensors -if TYPE_CHECKING: - from transformers import PretrainedConfig - logger = logging.get_logger() @@ -606,3 +604,16 @@ def maybe_save_preprocessors( src_name_or_path, subfolder=src_subfolder, trust_remote_code=trust_remote_code ): preprocessor.save_pretrained(dest_dir) + + +class DiffusersPretrainedConfig(PretrainedConfig): + # Overridden so that instance attributes (including `model_type`) are serialized as-is + def to_dict(self): + """ + Serializes this instance to a Python dictionary. + + Returns: + :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + output = copy.deepcopy(self.__dict__) + return output
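Taken together: export_models() now collects one config per submodel, builds a single combined cache config, registers it under the model type via hub_neuronx_cache(), and uploads the compiled artifacts with cache_aot_neuron_artifacts(); get_hub_cached_entries() can then match a stable-diffusion model against the registry by comparing the whitelist-filtered configs. A hedged end-to-end sketch of that flow — the model id and cache repo are placeholders, and the CLI flags assume the existing `optimum-cli export neuron` interface:

# Point the cache at a (placeholder) custom repo, then compile. With this
# patch, the stable-diffusion export uploads its artifacts under a single
# MODULE_<config-hash>/ folder in that repo:
#
#   optimum-cli neuron cache set my-org/neuron-cache
#   optimum-cli export neuron --model runwayml/stable-diffusion-v1-5 \
#       --task stable-diffusion --batch_size 1 --height 512 --width 512 sd_neuron/

from optimum.neuron.utils.hub_neuronx_cache import get_hub_cached_entries

# On the lookup side, matching entries are returned as the unet's neuron
# config, i.e. the static shapes and compiler flags of each cached build.
entries = get_hub_cached_entries("runwayml/stable-diffusion-v1-5", "inference")
for neuron_config in entries:
    print(neuron_config)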