all inf2 tests clear up
JingyaHuang committed Mar 26, 2024
1 parent 136a043 commit 59fe430
Showing 8 changed files with 134 additions and 40 deletions.
4 changes: 3 additions & 1 deletion optimum/exporters/neuron/__main__.py
@@ -217,6 +217,8 @@ def infer_stable_diffusion_shapes_from_diffusers(
     scaled_width = width // vae_scale_factor
 
     input_shapes["text_encoder"].update({"sequence_length": sequence_length})
+    if hasattr(model, "text_encoder_2"):
+        input_shapes["text_encoder_2"] = input_shapes["text_encoder"]
     input_shapes["unet"].update(
         {
             "sequence_length": sequence_length,
@@ -290,7 +292,7 @@ def _get_submodels_and_neuron_configs(
         task=task,
         library_name=library_name,
     )
-    check_mandatory_input_shapes(neuron_config_constructor, task, input_shapes)
+    input_shapes = check_mandatory_input_shapes(neuron_config_constructor, task, input_shapes)
     neuron_config = neuron_config_constructor(model.config, dynamic_batch_size=dynamic_batch_size, **input_shapes)
     model_name = getattr(model, "name_or_path", None) or model_name_or_path
     model_name = model_name.split("/")[-1] if model_name else model.config.model_type
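
Note on the second hunk: `check_mandatory_input_shapes` now returns a filtered dict rather than only validating, so the caller must rebind `input_shapes`. A minimal, self-contained sketch of that behavior (the signature is simplified here; the real helper derives the mandatory axes from `neuron_config_constructor` and `task`):

def check_mandatory_input_shapes(mandatory_shapes, input_shapes):
    # Validate: every mandatory axis must be provided.
    for name in mandatory_shapes:
        if input_shapes.get(name) is None:
            raise AttributeError(f"Cannot find the value of `{name}`, please set it explicitly.")
    # Filter: keep only the mandatory axes so that extra shape kwargs
    # (e.g. axes meant for another submodel) never reach the config constructor.
    return {axis: input_shapes[axis] for axis in mandatory_shapes}

shapes = check_mandatory_input_shapes(
    ["batch_size", "sequence_length"],
    {"batch_size": 1, "sequence_length": 128, "num_images_per_prompt": 4},
)
assert shapes == {"batch_size": 1, "sequence_length": 128}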
4 changes: 1 addition & 3 deletions optimum/exporters/neuron/convert.py
@@ -342,15 +342,13 @@ def export_models(
         output_path = output_dir / output_file_name
         output_path.parent.mkdir(parents=True, exist_ok=True)
 
-        compiler_workdir_path = compiler_workdir / model_name if compiler_workdir is not None else None
-
         try:
             start_time = time.time()
             neuron_inputs, neuron_outputs = export(
                 model=submodel,
                 config=sub_neuron_config,
                 output=output_path,
-                compiler_workdir=compiler_workdir_path,
+                compiler_workdir=compiler_workdir,
                 inline_weights_to_neff=inline_weights_to_neff,
                 optlevel=optlevel,
                 **compiler_kwargs,
32 changes: 25 additions & 7 deletions optimum/exporters/neuron/utils.py
@@ -262,11 +262,27 @@ def get_stable_diffusion_models_for_export(
 
 def _load_lora_weights_to_pipeline(
     pipeline: Union["StableDiffusionPipeline", "StableDiffusionXLImg2ImgPipeline"],
-    lora_model_ids: Optional[List[str]] = None,
-    weight_names: Optional[List[str]] = None,
-    adapter_names: Optional[List[str]] = None,
-    lora_scales: Optional[List[float]] = None,
+    lora_model_ids: Optional[Union[str, List[str]]] = None,
+    weight_names: Optional[Union[str, List[str]]] = None,
+    adapter_names: Optional[Union[str, List[str]]] = None,
+    lora_scales: Optional[Union[float, List[float]]] = None,
 ):
+    if isinstance(lora_model_ids, str):
+        lora_model_ids = [
+            lora_model_ids,
+        ]
+    if isinstance(weight_names, str):
+        weight_names = [
+            weight_names,
+        ]
+    if isinstance(adapter_names, str):
+        adapter_names = [
+            adapter_names,
+        ]
+    if isinstance(lora_scales, float):
+        lora_scales = [
+            lora_scales,
+        ]
     if lora_model_ids and weight_names:
         if len(lora_model_ids) == 1:
             pipeline.load_lora_weights(lora_model_ids[0], weight_name=weight_names[0])
@@ -288,9 +304,9 @@ def _load_lora_weights_to_pipeline(
 def get_submodels_for_export_stable_diffusion(
     pipeline: Union["StableDiffusionPipeline", "StableDiffusionXLImg2ImgPipeline"],
     task: str,
-    lora_model_ids: Optional[List[str]] = None,
-    lora_weight_names: Optional[List[str]] = None,
-    lora_adapter_names: Optional[List[str]] = None,
+    lora_model_ids: Optional[Union[str, List[str]]] = None,
+    lora_weight_names: Optional[Union[str, List[str]]] = None,
+    lora_adapter_names: Optional[Union[str, List[str]]] = None,
     lora_scales: Optional[List[float]] = None,
 ) -> Dict[str, Union["PreTrainedModel", "ModelMixin"]]:
     """
@@ -388,6 +404,8 @@ def check_mandatory_input_shapes(neuron_config_constructor, task, input_shapes):
             raise AttributeError(
                 f"Cannot find the value of `{name}` which is mandatory for exporting the model to the neuron format, please set the value explicitly."
             )
+    input_shapes = {axis: input_shapes[axis] for axis in mandatory_shapes}
+    return input_shapes
 
 
 def replace_stable_diffusion_submodels(pipeline, submodels):
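
The four `isinstance` blocks in the first hunk implement a scalar-to-list normalization, so a caller may now pass a single LoRA model id, weight name, adapter name, or scale instead of one-element lists. A condensed sketch of the same idea (the `_to_list` helper and the model ids are illustrative, not part of the codebase):

from typing import List, Optional, Union

def _to_list(value: Optional[Union[str, float, List]]) -> Optional[List]:
    # Wrap a bare str/float in a one-element list; lists and None pass through.
    return [value] if isinstance(value, (str, float)) else value

assert _to_list("my-org/my-lora") == ["my-org/my-lora"]  # hypothetical model id
assert _to_list(0.8) == [0.8]
assert _to_list(["lora-a", "lora-b"]) == ["lora-a", "lora-b"]
assert _to_list(None) is None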
58 changes: 35 additions & 23 deletions optimum/neuron/modeling_base.py
@@ -240,6 +240,7 @@ def _export(
         force_download: bool = False,
         cache_dir: Optional[str] = None,
         compiler_workdir: Optional[Union[str, Path]] = None,
+        disable_neuron_cache: Optional[bool] = False,
         inline_weights_to_neff: bool = False,
         optlevel: str = "2",
         subfolder: str = "",
@@ -277,7 +278,9 @@ def _export(
             "disable_fallback": disable_fallback,
         }
 
-        if not inline_weights_to_neff:
+        if (
+            not inline_weights_to_neff and not disable_neuron_cache and is_neuronx_available()
+        ):  # TODO: support caching of Inf1 as well
             # Check if the cache exists
             compilation_config = store_compilation_config(
                 config=config,
@@ -296,30 +299,38 @@ def _export(
             cache_repo_id = load_custom_cache_repo_name_from_hf_home()
             compile_cache = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id)
             model_cache_dir = compile_cache.default_cache.get_cache_dir_with_cache_key(f"MODULE_{cache_entry.hash}")
-            cache_exist = compile_cache.download_folder(model_cache_dir, model_cache_dir)
+            cache_available = compile_cache.download_folder(model_cache_dir, model_cache_dir)
         else:
-            cache_exist = False
+            cache_available = False
 
+        # load cache
+        if cache_available:
+            try:
+                neuron_model = cls.from_pretrained(model_cache_dir)
+                model = TasksManager.get_model_from_task(
+                    task=task,
+                    model_name_or_path=model_id,
+                    subfolder=subfolder,
+                    revision=revision,
+                    framework="pt",
+                    library_name=library_name,
+                    cache_dir=cache_dir,
+                    use_auth_token=use_auth_token,
+                    local_files_only=local_files_only,
+                    force_download=force_download,
+                    trust_remote_code=trust_remote_code,
+                )
+                # replace weights
+                neuron_model.replace_weights(weights=model)
+                return neuron_model
+            except Exception as e:
+                logger.warning(
+                    f"Found the cached artifacts but failed to re-load them with error: {e}. \n Falling back to recompilation."
+                )
+                cache_available = False
 
-        if cache_exist:
-            # load cache
-            neuron_model = cls.from_pretrained(model_cache_dir)
-            model = TasksManager.get_model_from_task(
-                task=task,
-                model_name_or_path=model_id,
-                subfolder=subfolder,
-                revision=revision,
-                framework="pt",
-                library_name=library_name,
-                cache_dir=cache_dir,
-                use_auth_token=use_auth_token,
-                local_files_only=local_files_only,
-                force_download=force_download,
-                trust_remote_code=trust_remote_code,
-            )
-            # replace weights
-            neuron_model.replace_weights(weights=model)
-            return neuron_model
-        else:
-            # compile
+        if not cache_available:
+            # compile
             save_dir = TemporaryDirectory()
             save_dir_path = Path(save_dir.name)
@@ -330,6 +341,7 @@ def _export(
                 task=task,
                 dynamic_batch_size=dynamic_batch_size,
                 cache_dir=cache_dir,
+                disable_neuron_cache=disable_neuron_cache,
                 compiler_workdir=compiler_workdir,
                 inline_weights_to_neff=inline_weights_to_neff,
                 optlevel=optlevel,
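
The rewritten block changes the control flow of `_export`: a cache hit is now loaded inside a try/except, and a failed re-load logs a warning and falls back to recompilation instead of raising. A schematic sketch of that pattern, using placeholder callables (`load_from_cache` and `compile_model` are not optimum-neuron APIs):

import logging

logger = logging.getLogger(__name__)

def load_or_compile(cache_available, load_from_cache, compile_model):
    if cache_available:
        try:
            return load_from_cache()
        except Exception as e:
            logger.warning(f"Failed to re-load cached artifacts: {e}. Falling back to recompilation.")
            cache_available = False
    if not cache_available:
        return compile_model()

def broken_load():
    raise OSError("corrupt cache entry")

# A failing cache load falls through to compilation instead of raising:
assert load_or_compile(True, broken_load, lambda: "recompiled") == "recompiled"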
11 changes: 10 additions & 1 deletion optimum/neuron/modeling_diffusion.py
@@ -59,6 +59,7 @@
     _create_hub_compile_cache_proxy,
     build_cache_config,
 )
+from .utils.require_utils import requires_torch_neuronx
 from .utils.version_utils import get_neuronxcc_version
 
 
@@ -265,6 +266,7 @@ def is_lcm(unet_config):
         return any(pattern in unet_name_or_path for pattern in patterns)
 
     @staticmethod
+    @requires_torch_neuronx
     def load_model(
         data_parallel_mode: Optional[str],
         text_encoder_path: Union[str, Path],
@@ -426,6 +428,7 @@ def _save_pretrained(
         self.feature_extractor.save_pretrained(save_directory.joinpath("feature_extractor"))
 
     @classmethod
+    @requires_torch_neuronx
     def _from_pretrained(
         cls,
         model_id: Union[str, Path],
@@ -558,11 +561,13 @@ def _from_pretrained(
         )
 
     @classmethod
+    @requires_torch_neuronx
     def _from_transformers(cls, *args, **kwargs):
         # Deprecate it when optimum uses `_export` as from_pretrained_method in a stable release.
         return cls._export(*args, **kwargs)
 
     @classmethod
+    @requires_torch_neuronx
     def _export(
         cls,
         model_id: Union[str, Path],
@@ -573,6 +578,7 @@ def _export(
         force_download: bool = True,
         cache_dir: Optional[str] = None,
         compiler_workdir: Optional[str] = None,
+        disable_neuron_cache: Optional[bool] = False,
         inline_weights_to_neff: bool = False,
         optlevel: str = "2",
         subfolder: str = "",
@@ -616,6 +622,8 @@
             standard cache should not be used.
         compiler_workdir (`Optional[str]`, defaults to `None`):
             Path to a directory in which the neuron compiler will store all intermediary files during the compilation(neff, weight, hlo graph...).
+        disable_neuron_cache (`bool`, defaults to `False`):
+            Whether to disable automatic caching of compiled models. If set to True, will not load neuron cache nor cache the compiled artifacts.
         inline_weights_to_neff (`bool`, defaults to `False`):
             Whether to inline the weights to the neff graph. If set to False, weights will be seperated from the neff.
         optlevel (`str`, defaults to `"2"`):
@@ -685,7 +693,7 @@ def _export(
         pipe = replace_stable_diffusion_submodels(pipe, submodels)
 
         # Check if the cache exists
-        if not inline_weights_to_neff:
+        if not inline_weights_to_neff and not disable_neuron_cache:
             # 1. Fetch all model configs
             models_for_export = get_submodels_for_export_stable_diffusion(
                 pipeline=pipe,
@@ -757,6 +765,7 @@ def _export(
             task=task,
             dynamic_batch_size=dynamic_batch_size,
             cache_dir=cache_dir,
+            disable_neuron_cache=disable_neuron_cache,
             compiler_workdir=compiler_workdir,
             inline_weights_to_neff=inline_weights_to_neff,
             optlevel=optlevel,
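
With `disable_neuron_cache` plumbed through the diffusion `_export` as well, an export can now bypass the neuron cache entirely. A usage sketch mirroring the updated tests (the tiny checkpoint and shapes come from the test suite; adjust them for real models):

from optimum.neuron import NeuronStableDiffusionXLPipeline

pipe = NeuronStableDiffusionXLPipeline.from_pretrained(
    "echarlaix/tiny-random-stable-diffusion-xl",
    export=True,
    disable_neuron_cache=True,  # neither load from nor populate the neuron cache
    batch_size=1,
    height=64,
    width=64,
    num_images_per_prompt=1,
)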
1 change: 0 additions & 1 deletion optimum/neuron/utils/hub_neuronx_cache.py
@@ -74,7 +74,6 @@ def create_compile_cache():
     "_commit_hash",
     "sample_size",
     "projection_dim",
-    "task",
     "_use_default_values",
 ]
 NEURON_CONFIG_WHITE_LIST = ["input_names", "output_names", "model_type"]
53 changes: 52 additions & 1 deletion tests/cache/test_neuronx_cache.py
@@ -26,7 +26,12 @@
 from transformers import AutoTokenizer
 from transformers.testing_utils import ENDPOINT_STAGING
 
-from optimum.neuron import NeuronModelForCausalLM, NeuronModelForSequenceClassification, NeuronStableDiffusionPipeline
+from optimum.neuron import (
+    NeuronModelForCausalLM,
+    NeuronModelForSequenceClassification,
+    NeuronStableDiffusionPipeline,
+    NeuronStableDiffusionXLPipeline,
+)
 from optimum.neuron.utils import get_hub_cached_entries, synchronize_hub_cache
 from optimum.neuron.utils.cache_utils import (
     CACHE_REPO_FILENAME,
@@ -122,6 +127,21 @@ def export_stable_diffusion_model(model_id):
     )
 
 
+def export_stable_diffusion_xl_model(model_id):
+    batch_size = 1
+    height = 64
+    width = 64
+    num_images_per_prompt = 4
+    return NeuronStableDiffusionXLPipeline.from_pretrained(
+        model_id,
+        export=True,
+        batch_size=batch_size,
+        height=height,
+        width=width,
+        num_images_per_prompt=num_images_per_prompt,
+    )
+
+
 def check_decoder_generation(model):
     batch_size = model.config.neuron["batch_size"]
     input_ids = torch.ones((batch_size, 20), dtype=torch.int64)
@@ -272,6 +292,37 @@ def test_stable_diffusion_cache(cache_repos):
     unset_custom_cache_repo_name_in_hf_home()
 
 
+@is_inferentia_test
+@requires_neuronx
+def test_stable_diffusion_xl_cache(cache_repos):
+    cache_path, cache_repo_id = cache_repos
+    model_id = "echarlaix/tiny-random-stable-diffusion-xl"
+    # Export the model a first time to populate the local cache
+    model = export_stable_diffusion_xl_model(model_id)
+    check_stable_diffusion_inference(model)
+    # check registry
+    check_aot_cache_entry(cache_path)
+    # Synchronize the hub cache with the local cache
+    synchronize_hub_cache(cache_repo_id=cache_repo_id)
+    assert_local_and_hub_cache_sync(cache_path, cache_repo_id)
+    # Verify we are able to fetch the cached entry for the model
+    model_entries = get_hub_cached_entries(model_id, "inference", cache_repo_id=cache_repo_id)
+    assert len(model_entries) == 1
+    # Clear the local cache
+    for root, dirs, files in os.walk(cache_path):
+        for f in files:
+            os.unlink(os.path.join(root, f))
+        for d in dirs:
+            shutil.rmtree(os.path.join(root, d))
+    assert local_cache_size(cache_path) == 0
+    # Export the model again: the compilation artifacts should be fetched from the Hub
+    model = export_stable_diffusion_xl_model(model_id)
+    check_stable_diffusion_inference(model)
+    # Verify the local cache directory has not been populated
+    assert len(get_local_cached_files(cache_path, ".neuron")) == 0
+    unset_custom_cache_repo_name_in_hf_home()
+
+
 @is_inferentia_test
 @requires_neuronx
 @pytest.mark.parametrize(
11 changes: 8 additions & 3 deletions tests/inference/test_modeling.py
@@ -139,9 +139,14 @@ def test_save_compiler_intermediary_files(self):
             save_path = f"{tempdir}/neff"
             neff_path = os.path.join(save_path, "graph.neff")
             _ = NeuronModelForSequenceClassification.from_pretrained(
-                self.MODEL_ID, export=True, compiler_workdir=save_path, **self.STATIC_INPUTS_SHAPES
+                self.MODEL_ID,
+                export=True,
+                compiler_workdir=save_path,
+                disable_neuron_cache=True,
+                **self.STATIC_INPUTS_SHAPES,
             )
             self.assertTrue(os.path.isdir(save_path))
+            os.listdir(save_path)
             self.assertTrue(os.path.exists(neff_path))
 
     @requires_neuronx
@@ -656,7 +661,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self):
                 "hf-internal-testing/tiny-random-t5", from_transformers=True, **self.STATIC_INPUTS_SHAPES
             )
 
-        self.assertIn("is not supported yet", str(context.exception))
+        self.assertIn("doesn't support", str(context.exception))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True)
     @requires_neuronx
@@ -862,7 +867,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self):
                 "hf-internal-testing/tiny-random-t5", from_transformers=True, **self.STATIC_INPUTS_SHAPES
             )
 
-        self.assertIn("is not supported yet", str(context.exception))
+        self.assertIn("doesn't support", str(context.exception))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True)
     @requires_neuronx
