sd caching cli support
JingyaHuang committed Mar 15, 2024
1 parent 6cf9af6 commit b357e52
Showing 9 changed files with 183 additions and 100 deletions.
2 changes: 0 additions & 2 deletions optimum/exporters/neuron/__init__.py
@@ -26,7 +26,6 @@
"base": ["NeuronDefaultConfig"],
"convert": ["export", "export_models", "validate_model_outputs", "validate_models_outputs"],
"utils": [
"DiffusersPretrainedConfig",
"build_stable_diffusion_components_mandatory_shapes",
"get_stable_diffusion_models_for_export",
],
@@ -41,7 +40,6 @@
from .base import NeuronDefaultConfig
from .convert import export, export_models, validate_model_outputs, validate_models_outputs
from .utils import (
DiffusersPretrainedConfig,
build_stable_diffusion_components_mandatory_shapes,
get_stable_diffusion_models_for_export,
)
1 change: 1 addition & 0 deletions optimum/exporters/neuron/__main__.py
@@ -498,6 +498,7 @@ def main_export(
optlevel=optlevel,
output_file_names=output_model_names,
compiler_kwargs=compiler_kwargs,
model_name_or_path=model_name_or_path,
)

# Validate compiled model
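The single added line threads `model_name_or_path` from `main_export` through to `export_models`, so the cache entry can be keyed by the original Hub model id instead of whatever `get_model_name_or_path` recovers from the config. A minimal sketch of an export that exercises this path, where the checkpoint id and shape values are illustrative assumptions rather than code from this commit:

from optimum.exporters.neuron import main_export

# Hypothetical checkpoint and typical stable-diffusion input shapes.
main_export(
    model_name_or_path="runwayml/stable-diffusion-v1-5",
    output="sd_neuron/",
    task="stable-diffusion",
    batch_size=1,
    height=512,
    width=512,
    num_images_per_prompt=1,
)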
42 changes: 19 additions & 23 deletions optimum/exporters/neuron/convert.py
@@ -25,20 +25,20 @@

from ...exporters.error_utils import OutputMatchError, ShapeError
from ...neuron.utils import (
DiffusersPretrainedConfig,
convert_neuronx_compiler_args_to_neuron,
is_neuron_available,
is_neuronx_available,
store_compilation_config,
)
from ...neuron.utils.version_utils import get_neuroncc_version, get_neuronxcc_version
from ...neuron.utils.hub_neuronx_cache import ModelCacheEntry, hub_neuronx_cache, _create_hub_compile_cache_proxy, build_cache_config
from ...neuron.utils.hub_neuronx_cache import ModelCacheEntry, hub_neuronx_cache, _create_hub_compile_cache_proxy, build_cache_config, cache_aot_neuron_artifacts
from ...neuron.utils.cache_utils import get_model_name_or_path, load_custom_cache_repo_name_from_hf_home
from ...utils import (
is_diffusers_available,
is_sentence_transformers_available,
logging,
)
from .utils import DiffusersPretrainedConfig


if TYPE_CHECKING:
@@ -282,6 +282,7 @@ def export_models(
output_file_names: Optional[Dict[str, str]] = None,
compiler_kwargs: Optional[Dict[str, Any]] = {},
configs: Optional[Dict[str, Any]] = {},
model_name_or_path: Optional[str] = None,
) -> Tuple[List[List[str]], List[List[str]]]:
"""
Exports a PyTorch model with multiple component models to separate files.
@@ -309,6 +310,8 @@
Arguments to pass to the Neuron(x) compiler for exporting Neuron models.
configs (`Optional[Dict[str, Any]]`, defaults to `None`):
A dictionary of pretrained model configs.
model_name_or_path (`Optional[str]`, defaults to `None`):
Path to pretrained model or model identifier from huggingface.co/models.
Returns:
`Tuple[List[List[str]], List[List[str]]]`: A tuple with an ordered list of the model's inputs, and the named
outputs from the Neuron configuration.
@@ -324,6 +327,7 @@

failed_models = []
total_compilation_time = 0
compile_configs = dict()
for i, model_name in enumerate(models_and_neuron_configs.keys()):
logger.info(f"***** Compiling {model_name} *****")
submodel, sub_neuron_config = models_and_neuron_configs[model_name]
@@ -380,26 +384,27 @@
output_hidden_states=getattr(sub_neuron_config, "output_hidden_states", False),
)
model_config.save_pretrained(output_path.parent)

# cache neuronx model
if not disable_neuron_cache and is_neuronx_available() and not model_config.neuron["inline_weights_to_neff"]:
model_id = get_model_name_or_path(model_config)
cache_config = build_cache_config(copy.deepcopy(model_config).to_diff_dict())
cache_entry = ModelCacheEntry(model_id=model_id, config=cache_config)

# Use the context manager just to create the registry; AOT compilation won't leverage `create_compile_cache`
# in `libneuronxla`, so we need to cache the compiled artifacts locally ourselves.
with hub_neuronx_cache("inference", entry=cache_entry):
cache_aot_neuron_artifacts(neuron_dir=output_path.parent, cache_config_hash=cache_entry.hash)

compile_configs[model_name] = model_config
except Exception as e:
failed_models.append((i, model_name))
output_path.parent.rmdir()
logger.error(
f"An error occured when trying to trace {model_name} with the error message: {e}.\n"
f"The export is failed and {model_name} neuron model won't be stored."
)

logger.info(f"[Total compilation Time] {np.round(total_compilation_time, 2)} seconds.")

# cache neuronx model
if not disable_neuron_cache and is_neuronx_available() and not inline_weights_to_neff:
model_id = get_model_name_or_path(model_config) if model_name_or_path is None else model_name_or_path
cache_config = build_cache_config(compile_configs)
cache_entry = ModelCacheEntry(model_id=model_id, config=cache_config)

# Use the context manager just to create the registry; AOT compilation won't leverage `create_compile_cache`
# in `libneuronxla`, so we need to cache the compiled artifacts locally ourselves.
with hub_neuronx_cache("inference", entry=cache_entry):
cache_aot_neuron_artifacts(neuron_dir=output_path.parent, cache_config_hash=cache_entry.hash)

# remove models failed to export
for i, model_name in failed_models:
@@ -410,15 +415,6 @@
return outputs


def cache_aot_neuron_artifacts(neuron_dir: Path, cache_config_hash: str):
cache_repo_id = load_custom_cache_repo_name_from_hf_home()
compile_cache = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id)
model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(f"MODULE_{cache_config_hash}")
compile_cache.upload_folder(cache_dir=model_cache_dir, src_dir=neuron_dir)

logger.info(f"Model cached in: {model_cache_dir}.")


def export(
model: "PreTrainedModel",
config: "NeuronDefaultConfig",
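Taken together, the convert.py changes move caching out of the per-submodel loop: each successfully compiled submodel contributes its config to `compile_configs`, and a single pipeline-level cache entry is created once all submodels have compiled. A condensed sketch of that flow, where `cache_compiled_pipeline` is an illustrative helper rather than code from this commit:

from pathlib import Path

from optimum.neuron.utils.hub_neuronx_cache import (
    ModelCacheEntry,
    build_cache_config,
    cache_aot_neuron_artifacts,
    hub_neuronx_cache,
)


def cache_compiled_pipeline(compile_configs: dict, model_id: str, neuron_dir: Path):
    # One hash is derived from the combined configs of all submodels.
    cache_config = build_cache_config(compile_configs)
    entry = ModelCacheEntry(model_id=model_id, config=cache_config)
    # The context manager only creates the registry entry; AOT compilation
    # doesn't go through libneuronxla's create_compile_cache, so the compiled
    # artifacts are uploaded manually.
    with hub_neuronx_cache("inference", entry=entry):
        cache_aot_neuron_artifacts(neuron_dir=neuron_dir, cache_config_hash=entry.hash)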
13 changes: 0 additions & 13 deletions optimum/exporters/neuron/utils.py
@@ -73,19 +73,6 @@
from diffusers import ModelMixin, StableDiffusionPipeline, StableDiffusionXLImg2ImgPipeline


class DiffusersPretrainedConfig(PretrainedConfig):
# override to update `model_type`
def to_dict(self):
"""
Serializes this instance to a Python dictionary.
Returns:
:obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
return output


def build_stable_diffusion_components_mandatory_shapes(
batch_size: Optional[int] = None,
sequence_length: Optional[int] = None,
4 changes: 1 addition & 3 deletions optimum/neuron/modeling_base.py
@@ -279,7 +279,7 @@ def _export(
"disable_fallback": disable_fallback,
}

# CHECK IF CACHED
# Check if the cache exists
compilation_config = store_compilation_config(
config=config,
input_shapes=kwargs_shapes,
@@ -296,8 +296,6 @@
cache_entry = ModelCacheEntry(model_id=model_id, config=cache_config)
cache_repo_id = load_custom_cache_repo_name_from_hf_home()
compile_cache = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id)

# check if cache exists
model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(f"MODULE_{cache_entry.hash}")
cache_exist = compile_cache.download_folder(model_cache_dir, model_cache_dir)

69 changes: 39 additions & 30 deletions optimum/neuron/modeling_diffusion.py
@@ -27,7 +27,7 @@
from huggingface_hub import snapshot_download
from transformers import CLIPFeatureExtractor, CLIPTokenizer, PretrainedConfig

from ..exporters.neuron import DiffusersPretrainedConfig, main_export, normalize_stable_diffusion_input_shapes
from ..exporters.neuron import main_export, normalize_stable_diffusion_input_shapes
from ..exporters.neuron.model_configs import * # noqa: F403
from ..exporters.tasks import TasksManager
from ..utils import is_diffusers_available
@@ -40,6 +40,7 @@
DIFFUSION_MODEL_VAE_ENCODER_NAME,
NEURON_FILE_NAME,
is_neuronx_available,
DiffusersPretrainedConfig,
)


@@ -644,35 +645,43 @@ def _export(
"disable_fast_relayout": disable_fast_relayout,
"disable_fallback": disable_fallback,
}

save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)

main_export(
model_name_or_path=model_id,
output=save_dir_path,
compiler_kwargs=compiler_kwargs,
task=task,
dynamic_batch_size=dynamic_batch_size,
cache_dir=cache_dir,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
trust_remote_code=trust_remote_code,
subfolder=subfolder,
revision=revision,
force_download=force_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
do_validation=False,
submodels={"unet": unet_id},
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
library_name=cls.library_name,
**input_shapes,
)

# TODO: Check if the cache exists
cache_exist = False

if cache_exist:
# load cache
pass
else:
# compile
save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)

main_export(
model_name_or_path=model_id,
output=save_dir_path,
compiler_kwargs=compiler_kwargs,
task=task,
dynamic_batch_size=dynamic_batch_size,
cache_dir=cache_dir,
compiler_workdir=compiler_workdir,
inline_weights_to_neff=inline_weights_to_neff,
optlevel=optlevel,
trust_remote_code=trust_remote_code,
subfolder=subfolder,
revision=revision,
force_download=force_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
do_validation=False,
submodels={"unet": unet_id},
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
library_name=cls.library_name,
**input_shapes,
)

return cls._from_pretrained(
model_id=save_dir_path,
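Note that the diffusion `_export` still hardcodes `cache_exist = False` behind a TODO, so every call takes the compile branch. A sketch of what the lookup could look like, assuming it mirrors the existing check in modeling_base.py above (illustrative only, not code from this commit):

from optimum.neuron.utils.cache_utils import load_custom_cache_repo_name_from_hf_home
from optimum.neuron.utils.hub_neuronx_cache import (
    ModelCacheEntry,
    _create_hub_compile_cache_proxy,
    build_cache_config,
)


def sd_cache_hit(compile_configs: dict, model_id: str) -> bool:
    # Build the same hash the export side would produce for this pipeline.
    cache_config = build_cache_config(compile_configs)
    cache_entry = ModelCacheEntry(model_id=model_id, config=cache_config)
    cache_repo_id = load_custom_cache_repo_name_from_hf_home()
    compile_cache = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id)
    model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(
        f"MODULE_{cache_entry.hash}"
    )
    # download_folder doubles as the existence check: it returns whether the
    # cached folder could be fetched.
    return compile_cache.download_folder(model_cache_dir, model_cache_dir)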
2 changes: 1 addition & 1 deletion optimum/neuron/utils/__init__.py
@@ -35,7 +35,7 @@
is_transformers_neuronx_available,
)
from .input_generators import DummyBeamValuesGenerator
from .misc import check_if_weights_replacable, replace_weights
from .misc import check_if_weights_replacable, replace_weights, DiffusersPretrainedConfig
from .optimization_utils import get_attention_scores_sd, get_attention_scores_sdxl
from .patching import DynamicPatch, ModelPatcher, Patcher, patch_everywhere, patch_within_function
from .training_utils import (
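With `DiffusersPretrainedConfig` relocated from optimum/exporters/neuron/utils.py to optimum/neuron/utils/misc.py (re-exported here), downstream code switches to the new import path:

# After this commit:
from optimum.neuron.utils import DiffusersPretrainedConfig

# Before this commit (removed above):
# from optimum.exporters.neuron import DiffusersPretrainedConfig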
(Diffs for the remaining 2 changed files were not loaded.)
