diff --git a/docs/assets/guides/models/02-sdxl-image.jpeg b/docs/assets/guides/models/02-sdxl-image.jpeg new file mode 100644 index 000000000..11a17f7c3 Binary files /dev/null and b/docs/assets/guides/models/02-sdxl-image.jpeg differ diff --git a/docs/source/guides/export_model.mdx b/docs/source/guides/export_model.mdx index 5936b90cf..fb1d98a7a 100644 --- a/docs/source/guides/export_model.mdx +++ b/docs/source/guides/export_model.mdx @@ -239,7 +239,39 @@ optimum-cli export neuron --model stabilityai/stable-diffusion-2-1-base \ --batch_size 1 \ --height 512 `# height in pixels of generated image, eg. 512, 768` \ --width 512 `# width in pixels of generated image, eg. 512, 768` \ - --num_image_per_prompt 4 `# number of images to generate per prompt, defaults to 1` \ + --num_images_per_prompt 4 `# number of images to generate per prompt, defaults to 1` \ + --auto_cast matmul `# cast only matrix multiplication operations` \ + --auto_cast_type bf16 `# cast operations from FP32 to BF16` \ + sd_neuron/ +``` + +## Exporting Stable Diffusion XL to Neuron + +Similar to Stable Diffusion, you will be able to use Optimum CLI to compile components in the SDXL pipeline for inference on neuron devices. + +We support the export of following components in the pipeline to boost the speed: + +* Text encoder +* Second text encoder +* U-Net (a three times larger UNet than the one in Stable Diffusion pipeline) +* VAE encoder +* VAE decoder + + + +"Stable Diffusion XL works especially well with images between 768 and 1024." + + + +Exporting a SDXL checkpoint can be done using the CLI: + +```bash +optimum-cli export neuron --model stabilityai/stable-diffusion-xl-base-1.0 \ + --task stable-diffusion-xl \ + --batch_size 1 \ + --height 1024 `# height in pixels of generated image, eg. 768, 1024` \ + --width 1024 `# width in pixels of generated image, eg. 768, 1024` \ + --num_images_per_prompt 4 `# number of images to generate per prompt, defaults to 1` \ --auto_cast matmul `# cast only matrix multiplication operations` \ --auto_cast_type bf16 `# cast operations from FP32 to BF16` \ sd_neuron/ diff --git a/docs/source/guides/models.mdx b/docs/source/guides/models.mdx index f33f33316..8d3f9662a 100644 --- a/docs/source/guides/models.mdx +++ b/docs/source/guides/models.mdx @@ -67,7 +67,7 @@ And the next time when you want to run inference, just load your compiled model As you see, there is no need to pass the neuron arguments used during the export as they are saved in a `config.json` file, and will be restored automatically by `NeuronModelForXXX` class. -## Export and inference of Discriminative NLP models +## Discriminative NLP models As explained in the previous section, you will need only few modifications to your Transformers code to export and run NLP models: @@ -133,7 +133,7 @@ No worries, `NeuronModelForXXX` class will pad your inputs to an eligible shape. -## Export and inference of Generative NLP models +## Generative NLP models As explained before, you will need only a few modifications to your Transformers code to export and run NLP models: @@ -196,7 +196,7 @@ with torch.inference_mode(): print(outputs) ``` -## Inference of Stable Diffusion Models +## Stable Diffusion Optimum extends 🤗`Diffusers` to support inference on Neuron. To get started, make sure you have installed Diffusers: @@ -247,7 +247,40 @@ Now generate an image with a prompt on neuron: search ami + +## Stable Diffusion XL + +Similar to Stable Diffusion, you will be able to use `NeuronStableDiffusionXLPipeline` API to export and run inference on Neuron devices with SDXL models. + +```python +>>> from optimum.neuron import NeuronStableDiffusionXLPipeline + +>>> model_id = "stabilityai/stable-diffusion-xl-base-1.0" +>>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} +>>> input_shapes = {"batch_size": 1, "height": 1024, "width": 1024} + +>>> stable_diffusion_xl = NeuronStableDiffusionXLPipeline.from_pretrained(model_id, export=True, **compiler_args, **input_shapes) + +# Save locally or upload to the HuggingFace Hub +>>> save_directory = "sd_neuron_xl/" +>>> stable_diffusion_xl.save_pretrained(save_directory) +>>> stable_diffusion_xl.push_to_hub( +... save_directory, repository_id="my-neuron-repo", use_auth_token=True +... ) +``` + +Now generate an image with a prompt on neuron: + +```python +>>> prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" +>>> image = stable_diffusion_xl(prompt).images[0] +``` + +sdxl generated image diff --git a/optimum/exporters/neuron/__init__.py b/optimum/exporters/neuron/__init__.py index a23c8e230..c6ea726f0 100644 --- a/optimum/exporters/neuron/__init__.py +++ b/optimum/exporters/neuron/__init__.py @@ -13,7 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .__main__ import main_export, normalize_input_shapes, normalize_stable_diffusion_input_shapes +from .__main__ import ( + infer_stable_diffusion_shapes_from_diffusers, + main_export, + normalize_input_shapes, + normalize_stable_diffusion_input_shapes, +) from .base import NeuronConfig from .convert import export, export_models, validate_model_outputs, validate_models_outputs from .utils import ( diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py index 82b522313..f4c4d340e 100644 --- a/optimum/exporters/neuron/__main__.py +++ b/optimum/exporters/neuron/__main__.py @@ -116,21 +116,20 @@ def normalize_stable_diffusion_input_shapes( ) -> Dict[str, Dict[str, int]]: args = vars(args) if isinstance(args, argparse.Namespace) else args mandatory_axes = set(getattr(inspect.getfullargspec(build_stable_diffusion_components_mandatory_shapes), "args")) - # Remove `sequence_length` as diffusers will pad it to the max and remove number of channels . + # Remove `sequence_length` as diffusers will pad it to the max and remove number of channels. mandatory_axes = mandatory_axes - { "sequence_length", "unet_num_channels", "vae_encoder_num_channels", "vae_decoder_num_channels", + "num_images_per_prompt", # default to 1 } if not mandatory_axes.issubset(set(args.keys())): raise AttributeError( f"Shape of {mandatory_axes} are mandatory for neuron compilation, while {mandatory_axes.difference(args.keys())} are not given." ) mandatory_shapes = {name: args[name] for name in mandatory_axes} - if "num_images_per_prompt" in args and args["num_images_per_prompt"] > 1: - batch_size = args["num_images_per_prompt"] * args["batch_size"] - mandatory_shapes["batch_size"] = batch_size + mandatory_shapes["num_images_per_prompt"] = args.get("num_images_per_prompt", 1) input_shapes = build_stable_diffusion_components_mandatory_shapes(**mandatory_shapes) return input_shapes @@ -184,18 +183,19 @@ def main_export( task = TasksManager.map_from_synonym(task) - model = TasksManager.get_model_from_task( - task=task, - model_name_or_path=model_name_or_path, - subfolder=subfolder, - revision=revision, - cache_dir=cache_dir, - use_auth_token=use_auth_token, - local_files_only=local_files_only, - force_download=force_download, - trust_remote_code=trust_remote_code, - framework="pt", - ) + model_kwargs = { + "task": task, + "model_name_or_path": model_name_or_path, + "subfolder": subfolder, + "revision": revision, + "cache_dir": cache_dir, + "use_auth_token": use_auth_token, + "local_files_only": local_files_only, + "force_download": force_download, + "trust_remote_code": trust_remote_code, + "framework": "pt", + } + model = TasksManager.get_model_from_task(**model_kwargs) is_stable_diffusion = "stable-diffusion" in task if not is_stable_diffusion: @@ -216,12 +216,6 @@ def main_export( "Stable diffusion export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead." ) input_shapes = infer_stable_diffusion_shapes_from_diffusers(input_shapes, model) - models_and_neuron_configs = get_stable_diffusion_models_for_export( - pipeline=model, - task=task, - dynamic_batch_size=dynamic_batch_size, - **input_shapes, - ) # Saving the model config and preprocessor as this is needed sometimes. model.scheduler.save_pretrained(output.joinpath("scheduler")) @@ -232,6 +226,12 @@ def main_export( model.feature_extractor.save_pretrained(output.joinpath("feature_extractor")) model.save_config(output) + models_and_neuron_configs = get_stable_diffusion_models_for_export( + pipeline=model, + task=task, + dynamic_batch_size=dynamic_batch_size, + **input_shapes, + ) output_model_names = { DIFFUSION_MODEL_TEXT_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_TEXT_ENCODER_NAME, NEURON_FILE_NAME), DIFFUSION_MODEL_UNET_NAME: os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME), @@ -242,6 +242,7 @@ def main_export( output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = os.path.join( DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, NEURON_FILE_NAME ) + del model _, neuron_outputs = export_models( models_and_neuron_configs=models_and_neuron_configs, @@ -250,8 +251,6 @@ def main_export( compiler_kwargs=compiler_kwargs, ) - del model - # Validate compiled model if do_validation is True: if is_stable_diffusion: diff --git a/optimum/exporters/neuron/base.py b/optimum/exporters/neuron/base.py index 035c791be..d181e7c38 100644 --- a/optimum/exporters/neuron/base.py +++ b/optimum/exporters/neuron/base.py @@ -276,12 +276,28 @@ def generate_dummy_inputs( else: return dummy_inputs + @classmethod + def flatten_inputs(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: + """ + Flatten nested structure in dummy inputs, e.g `addition_embed_type` of unet model. + """ + flatten = {} + for name, value in inputs.items(): + if isinstance(value, dict): + for sub_name, sub_value in value.items(): + flatten[sub_name] = sub_value + else: + flatten[name] = value + return flatten + def check_model_inputs_order( self, model: "PreTrainedModel", - dummy_inputs: Dict[str, torch.Tensor], + dummy_inputs: Optional[Dict[str, torch.Tensor]] = None, forward_with_tuple: bool = False, eligible_outputs: Optional[List[Union[str, int]]] = None, + custom_model_wrapper: Optional[torch.nn.Module] = None, + custom_wrapper_kwargs: Optional[Dict] = None, ): """ Checks if inputs order of the model's forward pass correspond to the generated dummy inputs to ensure the dummy inputs tuple used for @@ -320,7 +336,14 @@ def forward(self, *input): return outputs - return ModelWrapper(model, list(dummy_inputs.keys())) + if custom_model_wrapper: + return ( + custom_model_wrapper(model) + if custom_wrapper_kwargs is None + else custom_model_wrapper(model, **custom_wrapper_kwargs) + ) + else: + return ModelWrapper(model, list(dummy_inputs.keys())) class NeuronDecoderConfig(ExportConfig): diff --git a/optimum/exporters/neuron/convert.py b/optimum/exporters/neuron/convert.py index 59aa6e1fc..a9deefa9b 100644 --- a/optimum/exporters/neuron/convert.py +++ b/optimum/exporters/neuron/convert.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Neuron compiled model check and export functions.""" - import copy +import os import time from collections import OrderedDict from pathlib import Path @@ -22,7 +22,6 @@ import numpy as np import torch -from packaging import version from transformers import PretrainedConfig from ...exporters.error_utils import OutputMatchError, ShapeError @@ -178,7 +177,7 @@ def validate_model_outputs( neuron_inputs = ref_inputs else: ref_outputs = reference_model(**ref_inputs) - neuron_inputs = tuple(ref_inputs.values()) + neuron_inputs = tuple(config.flatten_inputs(ref_inputs).values()) # Neuron outputs neuron_model = torch.jit.load(neuron_model_path) @@ -292,6 +291,7 @@ def export_models( ) failed_models = [] + total_compilation_time = 0 for i, model_name in enumerate(models_and_neuron_configs.keys()): logger.info(f"***** Compiling {model_name} *****") submodel, sub_neuron_config = models_and_neuron_configs[model_name] @@ -311,6 +311,7 @@ def export_models( **compiler_kwargs, ) compilation_time = time.time() - start_time + total_compilation_time += compilation_time logger.info(f"[Compilation Time] {np.round(compilation_time, 2)} seconds.") outputs.append((neuron_inputs, neuron_outputs)) # Add neuron specific configs to model components' original config @@ -347,6 +348,7 @@ def export_models( f"An error occured when trying to trace {model_name} with the error message: {e}.\n" f"The export is failed and {model_name} neuron model won't be stored." ) + logger.info(f"[Total compilation Time] {np.round(total_compilation_time, 2)} seconds.") # remove models failed to export for i, model_name in failed_models: @@ -421,6 +423,7 @@ def export_neuronx( input_shapes[axis] = getattr(config, axis) dummy_inputs = config.generate_dummy_inputs(**input_shapes) + dummy_inputs = config.flatten_inputs(dummy_inputs) dummy_inputs_tuple = tuple(dummy_inputs.values()) checked_model = config.check_model_inputs_order(model, dummy_inputs) @@ -435,6 +438,10 @@ def export_neuronx( else: compiler_args = ["--auto-cast", "none"] + # WARNING: Enabled experimental parallel compilation + compiler_args.extend(["--enable-experimental-O1"]) + compiler_args.extend(["--num-parallel-jobs", str(os.cpu_count())]) + # diffusers specific compiler_args = add_stable_diffusion_compiler_args(config, compiler_args) @@ -447,6 +454,9 @@ def export_neuronx( improve_stable_diffusion_loading(config, neuron_model) torch.jit.save(neuron_model, output) + del model + del checked_model + del dummy_inputs del neuron_model return config.inputs, config.outputs @@ -454,24 +464,26 @@ def export_neuronx( def add_stable_diffusion_compiler_args(config, compiler_args): if hasattr(config._config, "_name_or_path"): - sd_components = ["text_encoder", "unet", "vae", "vae_encoder", "vae_decoder"] + sd_components = ["text_encoder", "vae", "vae_encoder", "vae_decoder"] if any(component in config._config._name_or_path.lower() for component in sd_components): compiler_args.extend(["--enable-fast-loading-neuron-binaries"]) # unet if "unet" in config._config._name_or_path.lower(): + # SDXL unet doesn't support fast loading neuron binaries + if "stable-diffusion-xl" not in config._config._name_or_path.lower(): + compiler_args.extend(["--enable-fast-loading-neuron-binaries"]) compiler_args.extend(["--model-type=unet-inference"]) return compiler_args def improve_stable_diffusion_loading(config, neuron_model): - if version.parse(neuronx.__version__) >= version.parse("1.13.1.1.9.0"): - if hasattr(config._config, "_name_or_path"): - sd_components = ["text_encoder", "unet", "vae", "vae_encoder", "vae_decoder"] - if any(component in config._config._name_or_path.lower() for component in sd_components): - neuronx.async_load(neuron_model) - # unet - if "unet" in config._config._name_or_path.lower(): - neuronx.lazy_load(neuron_model) + if hasattr(config._config, "_name_or_path"): + sd_components = ["text_encoder", "unet", "vae", "vae_encoder", "vae_decoder"] + if any(component in config._config._name_or_path.lower() for component in sd_components): + neuronx.async_load(neuron_model) + # unet + if "unet" in config._config._name_or_path.lower(): + neuronx.lazy_load(neuron_model) def export_neuron( @@ -537,6 +549,9 @@ def export_neuron( fallback=not disable_fallback, ) torch.jit.save(neuron_model, output) + del model + del checked_model + del dummy_inputs del neuron_model return config.inputs, config.outputs diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py index a42572410..3412902f8 100644 --- a/optimum/exporters/neuron/model_configs.py +++ b/optimum/exporters/neuron/model_configs.py @@ -277,6 +277,28 @@ def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs): else: return dummy_inputs + class ModelWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, sample, timestep, encoder_hidden_states, text_embeds=None, time_ids=None): + out_tuple = self.model( + sample, + timestep.float().expand((sample.shape[0],)), + encoder_hidden_states, + added_cond_kwargs={"text_embeds": text_embeds, "time_ids": time_ids}, + return_dict=False, + ) + + return out_tuple + + def check_model_inputs_order(self, model, dummy_inputs): + return super().check_model_inputs_order( + model=model, + custom_model_wrapper=self.ModelWrapper, + ) + @register_in_tasks_manager("vae-encoder", *["semantic-segmentation"]) class VaeEncoderNeuronConfig(VisionNeuronConfig): diff --git a/optimum/exporters/neuron/utils.py b/optimum/exporters/neuron/utils.py index 2d0d21660..b7dbb3be2 100644 --- a/optimum/exporters/neuron/utils.py +++ b/optimum/exporters/neuron/utils.py @@ -19,8 +19,6 @@ from collections import OrderedDict from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union -import torch -from packaging import version from transformers import PretrainedConfig from ...neuron.utils import ( @@ -92,22 +90,23 @@ def build_stable_diffusion_components_mandatory_shapes( vae_decoder_num_channels: Optional[int] = None, height: Optional[int] = None, width: Optional[int] = None, + num_images_per_prompt: Optional[int] = 1, ): text_encoder_input_shapes = {"batch_size": batch_size, "sequence_length": sequence_length} vae_encoder_input_shapes = { - "batch_size": batch_size, + "batch_size": batch_size * num_images_per_prompt, "num_channels": vae_encoder_num_channels, "height": height, "width": width, } vae_decoder_input_shapes = { - "batch_size": batch_size, + "batch_size": batch_size * num_images_per_prompt, "num_channels": vae_decoder_num_channels, "height": height, "width": width, } unet_input_shapes = { - "batch_size": batch_size, + "batch_size": batch_size * num_images_per_prompt, "sequence_length": sequence_length, "num_channels": unet_num_channels, "height": height, @@ -285,21 +284,18 @@ def _get_submodels_for_export_stable_diffusion( # VAE Encoder vae_encoder = copy.deepcopy(pipeline.vae) - if not version.parse(torch.__version__) >= version.parse("2.1.0"): - vae_encoder = override_diffusers_2_0_attn_processors(vae_encoder) vae_encoder.forward = lambda sample: {"latent_sample": vae_encoder.encode(x=sample)["latent_dist"].sample()} models_for_export.append((DIFFUSION_MODEL_VAE_ENCODER_NAME, vae_encoder)) # VAE Decoder vae_decoder = copy.deepcopy(pipeline.vae) - if not version.parse(torch.__version__) >= version.parse("2.1.0"): - vae_decoder = override_diffusers_2_0_attn_processors(vae_decoder) vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample) models_for_export.append((DIFFUSION_MODEL_VAE_DECODER_NAME, vae_decoder)) return OrderedDict(models_for_export) +# Using xformers or torch_2_0 can avoid overflow on float16, do not apply this unless compilation error. def override_diffusers_2_0_attn_processors(model): for _, submodule in model.named_modules(): if isinstance(submodule, Attention): diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py index f19556011..b76bb9e39 100644 --- a/optimum/neuron/__init__.py +++ b/optimum/neuron/__init__.py @@ -34,8 +34,7 @@ ], "modeling_diffusion": [ "NeuronStableDiffusionPipeline", - "NeuronStableDiffusionImg2ImgPipeline", - "NeuronStableDiffusionInpaintPipeline", + "NeuronStableDiffusionXLPipeline", ], "modeling_decoder": ["NeuronDecoderModel"], "accelerate": [ @@ -61,9 +60,8 @@ from .modeling_base import NeuronBaseModel from .modeling_decoder import NeuronDecoderModel from .modeling_diffusion import ( - NeuronStableDiffusionImg2ImgPipeline, - NeuronStableDiffusionInpaintPipeline, NeuronStableDiffusionPipeline, + NeuronStableDiffusionXLPipeline, ) from .pipelines import pipeline from .trainers import NeuronTrainer, Seq2SeqNeuronTrainer diff --git a/optimum/neuron/modeling_base.py b/optimum/neuron/modeling_base.py index 461ad858c..7e6e3da3f 100644 --- a/optimum/neuron/modeling_base.py +++ b/optimum/neuron/modeling_base.py @@ -99,10 +99,11 @@ def load_model(path: Union[str, Path]) -> torch.jit._script.ScriptModule: path (`Union[str, Path]`): Path of the compiled model. """ - if not isinstance(path, str): - path = str(path) + if not isinstance(path, Path): + path = Path(path) - return torch.jit.load(path) + if path.is_file(): + return torch.jit.load(path) def _save_pretrained(self, save_directory: Union[str, Path]): """ diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index 23eb0b55e..dfbed7286 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -33,7 +33,9 @@ from ..utils import is_diffusers_available from .modeling_base import NeuronBaseModel from .pipelines.diffusers.pipeline_stable_diffusion import StableDiffusionPipelineMixin +from .pipelines.diffusers.pipeline_stable_diffusion_xl import StableDiffusionXLPipelineMixin from .utils import ( + DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, DIFFUSION_MODEL_TEXT_ENCODER_NAME, DIFFUSION_MODEL_UNET_NAME, DIFFUSION_MODEL_VAE_DECODER_NAME, @@ -48,10 +50,16 @@ if is_diffusers_available(): - from diffusers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, StableDiffusionPipeline + from diffusers import ( + DDIMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionPipeline, + StableDiffusionXLImg2ImgPipeline, + ) from diffusers.image_processor import VaeImageProcessor from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME - from diffusers.utils import CONFIG_NAME + from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available if TYPE_CHECKING: @@ -71,11 +79,13 @@ def __init__( self, text_encoder: torch.jit._script.ScriptModule, unet: torch.jit._script.ScriptModule, - vae_encoder: torch.jit._script.ScriptModule, vae_decoder: torch.jit._script.ScriptModule, config: Dict[str, Any], tokenizer: CLIPTokenizer, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + vae_encoder: Optional[torch.jit._script.ScriptModule] = None, + text_encoder_2: Optional[torch.jit._script.ScriptModule] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, feature_extractor: Optional[CLIPFeatureExtractor] = None, device_ids: Optional[List[int]] = None, configs: Optional[Dict[str, "PretrainedConfig"]] = None, @@ -89,8 +99,6 @@ def __init__( The Neuron TorchScript module associated to the text encoder. unet (`torch.jit._script.ScriptModule`): The Neuron TorchScript module associated to the U-NET. - vae_encoder (`torch.jit._script.ScriptModule`): - The Neuron TorchScript module associated to the VAE encoder. vae_decoder (`torch.jit._script.ScriptModule`): The Neuron TorchScript module associated to the VAE decoder. config (`Dict[str, Any]`): @@ -101,6 +109,13 @@ def __init__( [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). scheduler (`Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]`): A scheduler to be used in combination with the U-NET component to denoise the encoded image latents. + vae_encoder (`Optional[torch.jit._script.ScriptModule]`, defaults to `None`): + The Neuron TorchScript module associated to the VAE encoder. + text_encoder_2 (`Optional[torch.jit._script.ScriptModule]`, defaults to `None`): + The Neuron TorchScript module associated to the second frozen text encoder. Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. + tokenizer_2 (`Optional[CLIPTokenizer]`, defaults to `None`): + Second tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). feature_extractor (`Optional[CLIPFeatureExtractor]`, defaults to `None`): A model extracting features from generated images to be used as inputs for the `safety_checker` device_ids (Optional[List[int]], defaults to `None`): @@ -129,14 +144,28 @@ def __init__( self.configs[DIFFUSION_MODEL_TEXT_ENCODER_NAME], self.neuron_configs[DIFFUSION_MODEL_TEXT_ENCODER_NAME], ) + self.text_encoder_2 = ( + NeuronModelTextEncoder( + text_encoder_2, + self, + self.configs[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME], + self.neuron_configs[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME], + ) + if text_encoder_2 is not None + else None + ) self.unet = NeuronModelUnet( unet, self, self.configs[DIFFUSION_MODEL_UNET_NAME], self.neuron_configs[DIFFUSION_MODEL_UNET_NAME] ) - self.vae_encoder = NeuronModelVaeEncoder( - vae_encoder, - self, - self.configs[DIFFUSION_MODEL_VAE_ENCODER_NAME], - self.neuron_configs[DIFFUSION_MODEL_VAE_ENCODER_NAME], + self.vae_encoder = ( + NeuronModelVaeEncoder( + vae_encoder, + self, + self.configs[DIFFUSION_MODEL_VAE_ENCODER_NAME], + self.neuron_configs[DIFFUSION_MODEL_VAE_ENCODER_NAME], + ) + if vae_encoder is not None + else None ) self.vae_decoder = NeuronModelVaeDecoder( vae_decoder, @@ -146,15 +175,20 @@ def __init__( ) self.tokenizer = tokenizer + self.tokenizer_2 = tokenizer_2 self.scheduler = scheduler self.feature_extractor = feature_extractor self.safety_checker = None sub_models = { DIFFUSION_MODEL_TEXT_ENCODER_NAME: self.text_encoder, DIFFUSION_MODEL_UNET_NAME: self.unet, - DIFFUSION_MODEL_VAE_ENCODER_NAME: self.vae_encoder, DIFFUSION_MODEL_VAE_DECODER_NAME: self.vae_decoder, } + if self.text_encoder_2 is not None: + sub_models[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = self.text_encoder_2 + if self.vae_encoder is not None: + sub_models[DIFFUSION_MODEL_VAE_ENCODER_NAME] = self.vae_encoder + for name in sub_models.keys(): self._internal_dict[name] = ("optimum", sub_models[name].__class__.__name__) self._internal_dict.pop("vae", None) @@ -163,18 +197,69 @@ def __init__( self.model_and_config_save_paths = model_and_config_save_paths if model_and_config_save_paths else None if hasattr(self.vae_decoder.config, "block_out_channels"): - self.vae_scale_factor = 2 ** ( - len(self.vae_decoder.config.block_out_channels) - 1 - ) # not working for tiny test models, need to remove `block_out_channels` in `config.json`. + self.vae_scale_factor = 2 ** (len(self.vae_decoder.config.block_out_channels) - 1) else: self.vae_scale_factor = 8 + self.num_images_per_prompt = ( + self.neuron_configs["unet"].batch_size // self.neuron_configs["text_encoder"].batch_size + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + @staticmethod + def load_model( + text_encoder_path: Union[str, Path], + unet_path: Union[str, Path], + vae_decoder_path: Union[str, Path], + vae_encoder_path: Optional[Union[str, Path]] = None, + text_encoder_2_path: Optional[Union[str, Path]] = None, + device_ids: Optional[List[int]] = None, + dynamic_batch_size: bool = False, + ): + """ + Loads Stable Diffusion TorchScript modules compiled by neuron(x)-cc compiler. It will be first loaded onto CPU and then moved to + one or multiple [NeuronCore](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/neuroncores-arch.html). + + Args: + text_encoder_path (`Union[str, Path]`): + Path of the compiled text encoder. + unet_path (`Union[str, Path]`): + Path of the compiled U-NET. + vae_decoder_path (`Union[str, Path]`): + Path of the compiled VAE decoder. + vae_encoder_path (`Optional[Union[str, Path]]`, defaults to `None`): + Path of the compiled VAE encoder. It is optional, only used for tasks taking images as input. + text_encoder_2_path (`Optional[Union[str, Path]]`, defaults to `None`): + Path of the compiled second frozen text encoder. SDXL only. + device_ids (`Optional[List[int]]`, defaults to `None`): + The ID of neuron cores to load a model, in the case of stable diffusion, it is only used for loading unet, and by default unet will be loaded onto both neuron cores of a device. + dynamic_batch_size (`bool`, defaults to `False`): + Whether enable dynamic batch size for neuron compiled model. If `True`, the input batch size can be a multiple of the batch size during the compilation. + """ + if device_ids is None: + device_ids = [0, 1] + + text_encoder = NeuronBaseModel.load_model(text_encoder_path) + if len(device_ids) > 1: + unet = torch_neuronx.DataParallel( + torch.jit.load(unet_path), + device_ids, + set_dynamic_batching=dynamic_batch_size, + ) + else: + unet = NeuronBaseModel.load_model(unet_path) + vae_decoder = NeuronBaseModel.load_model(vae_decoder_path) + vae_encoder = NeuronBaseModel.load_model(vae_encoder_path) + text_encoder_2 = NeuronBaseModel.load_model(text_encoder_2_path) + + return text_encoder, unet, vae_decoder, vae_encoder, text_encoder_2 + def _save_pretrained( self, save_directory: Union[str, Path], text_encoder_file_name: str = NEURON_FILE_NAME, + text_encoder_2_file_name: str = NEURON_FILE_NAME, unet_file_name: str = NEURON_FILE_NAME, vae_encoder_file_name: str = NEURON_FILE_NAME, vae_decoder_file_name: str = NEURON_FILE_NAME, @@ -183,6 +268,12 @@ def _save_pretrained( Saves the model to the serialized format optimized for Neuron devices. """ save_directory = Path(save_directory) + if not self.model_and_config_save_paths.get(DIFFUSION_MODEL_VAE_ENCODER_NAME)[0].is_file(): + self.model_and_config_save_paths.pop(DIFFUSION_MODEL_VAE_ENCODER_NAME) + + if not self.model_and_config_save_paths.get(DIFFUSION_MODEL_TEXT_ENCODER_2_NAME)[0].is_file(): + self.model_and_config_save_paths.pop(DIFFUSION_MODEL_TEXT_ENCODER_2_NAME) + if self.model_and_config_save_paths is None: logger.warning( "`model_save_paths` is None which means that no path of Neuron model is defined. Nothing will be saved." @@ -192,10 +283,19 @@ def _save_pretrained( logger.info(f"Saving the {tuple(self.model_and_config_save_paths.keys())}...") dst_paths = { - "text_encoder": save_directory / DIFFUSION_MODEL_TEXT_ENCODER_NAME / text_encoder_file_name, - "unet": save_directory / DIFFUSION_MODEL_UNET_NAME / unet_file_name, - "vae_encoder": save_directory / DIFFUSION_MODEL_VAE_ENCODER_NAME / vae_encoder_file_name, - "vae_decoder": save_directory / DIFFUSION_MODEL_VAE_DECODER_NAME / vae_decoder_file_name, + DIFFUSION_MODEL_TEXT_ENCODER_NAME: save_directory + / DIFFUSION_MODEL_TEXT_ENCODER_NAME + / text_encoder_file_name, + DIFFUSION_MODEL_TEXT_ENCODER_2_NAME: save_directory + / DIFFUSION_MODEL_TEXT_ENCODER_2_NAME + / text_encoder_2_file_name, + DIFFUSION_MODEL_UNET_NAME: save_directory / DIFFUSION_MODEL_UNET_NAME / unet_file_name, + DIFFUSION_MODEL_VAE_ENCODER_NAME: save_directory + / DIFFUSION_MODEL_VAE_ENCODER_NAME + / vae_encoder_file_name, + DIFFUSION_MODEL_VAE_DECODER_NAME: save_directory + / DIFFUSION_MODEL_VAE_DECODER_NAME + / vae_decoder_file_name, } model_src_to_dst_path = { self.model_and_config_save_paths[model_name][0]: dst_paths[model_name] @@ -212,9 +312,12 @@ def _save_pretrained( for src_path, dst_path in zip(src_paths, dst_paths): dst_path.parent.mkdir(parents=True, exist_ok=True) - shutil.copyfile(src_path, dst_path) + if src_path.is_file(): + shutil.copyfile(src_path, dst_path) self.tokenizer.save_pretrained(save_directory.joinpath("tokenizer")) + if self.tokenizer_2 is not None: + self.tokenizer_2.save_pretrained(save_directory.joinpath("tokenizer_2")) self.scheduler.save_pretrained(save_directory.joinpath("scheduler")) if self.feature_extractor is not None: self.feature_extractor.save_pretrained(save_directory.joinpath("feature_extractor")) @@ -228,6 +331,7 @@ def _from_pretrained( revision: Optional[str] = None, cache_dir: Optional[str] = None, text_encoder_file_name: Optional[str] = NEURON_FILE_NAME, + text_encoder_2_file_name: Optional[str] = NEURON_FILE_NAME, unet_file_name: Optional[str] = NEURON_FILE_NAME, vae_encoder_file_name: Optional[str] = NEURON_FILE_NAME, vae_decoder_file_name: Optional[str] = NEURON_FILE_NAME, @@ -237,17 +341,16 @@ def _from_pretrained( **kwargs, # To share kwargs only available for `_from_transformers` ): model_id = str(model_id) - sub_models_to_load, _, _ = cls.extract_init_dict(config) - sub_models_names = set(sub_models_to_load.keys()).intersection({"feature_extractor", "tokenizer", "scheduler"}) - sub_models = {} + patterns = set(config.keys()) + sub_models_to_load = patterns.intersection({"feature_extractor", "tokenizer", "tokenizer_2", "scheduler"}) if not os.path.isdir(model_id): - patterns = set(config.keys()) patterns.update({DIFFUSION_MODEL_VAE_ENCODER_NAME, DIFFUSION_MODEL_VAE_DECODER_NAME}) allow_patterns = {os.path.join(k, "*") for k in patterns if not k.startswith("_")} allow_patterns.update( { text_encoder_file_name, + text_encoder_2_file_name, unet_file_name, vae_encoder_file_name, vae_decoder_file_name, @@ -268,8 +371,9 @@ def _from_pretrained( ) new_model_save_dir = Path(model_id) - for name in sub_models_names: - library_name, library_classes = sub_models_to_load[name] + sub_models = {} + for name in sub_models_to_load: + library_name, library_classes = config[name] if library_classes is not None: library = importlib.import_module(library_name) class_obj = getattr(library, library_classes) @@ -285,6 +389,10 @@ def _from_pretrained( new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_NAME / text_encoder_file_name, new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_NAME / cls.sub_component_config_name, ), + "text_encoder_2": ( + new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_NAME / text_encoder_2_file_name, + new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_NAME / cls.sub_component_config_name, + ), "unet": ( new_model_save_dir / DIFFUSION_MODEL_UNET_NAME / unet_file_name, new_model_save_dir / DIFFUSION_MODEL_UNET_NAME / cls.sub_component_config_name, @@ -300,26 +408,22 @@ def _from_pretrained( } # Re-build pretrained configs and neuron configs - configs = { - name: DiffusersPretrainedConfig.from_json_file(model_config[1]) - for name, model_config in model_and_config_save_paths.items() - } - neuron_configs = {name: cls._neuron_config_init(model_config) for name, model_config in configs.items()} - - text_encoder = cls.load_model(model_and_config_save_paths["text_encoder"][0]) - if device_ids is None: - device_ids = [0, 1] - if len(device_ids) > 1: - # Load the compiled UNet onto multiple neuron cores - unet = torch_neuronx.DataParallel( - torch.jit.load(model_and_config_save_paths["unet"][0]), - device_ids, - set_dynamic_batching=neuron_configs[DIFFUSION_MODEL_UNET_NAME].dynamic_batch_size, - ) - else: - unet = cls.load_model(model_and_config_save_paths["unet"][0]) - vae_encoder = cls.load_model(model_and_config_save_paths["vae_encoder"][0]) - vae_decoder = cls.load_model(model_and_config_save_paths["vae_decoder"][0]) + configs, neuron_configs = {}, {} + for name, file_paths in model_and_config_save_paths.items(): + if file_paths[1].is_file(): + model_config = DiffusersPretrainedConfig.from_json_file(file_paths[1]) + configs[name] = model_config + neuron_configs[name] = cls._neuron_config_init(model_config) + + text_encoder, unet, vae_decoder, vae_encoder, text_encoder_2 = cls.load_model( + text_encoder_path=model_and_config_save_paths["text_encoder"][0], + unet_path=model_and_config_save_paths["unet"][0], + vae_decoder_path=model_and_config_save_paths["vae_decoder"][0], + vae_encoder_path=model_and_config_save_paths["vae_encoder"][0], + text_encoder_2_path=model_and_config_save_paths["text_encoder_2"][0], + device_ids=device_ids, + dynamic_batch_size=neuron_configs[DIFFUSION_MODEL_UNET_NAME].dynamic_batch_size, + ) if model_save_dir is None: model_save_dir = new_model_save_dir @@ -327,11 +431,13 @@ def _from_pretrained( return cls( text_encoder=text_encoder, unet=unet, - vae_encoder=vae_encoder, vae_decoder=vae_decoder, config=config, tokenizer=sub_models["tokenizer"], scheduler=sub_models["scheduler"], + vae_encoder=vae_encoder, + text_encoder_2=text_encoder_2, + tokenizer_2=sub_models.pop("tokenizer_2", None), feature_extractor=sub_models.pop("feature_extractor", None), device_ids=device_ids, configs=configs, @@ -468,16 +574,25 @@ def __init__( if hasattr(self.model, "device"): self.device = self.model.device - def forward(self, sample: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor): + def forward( + self, + sample: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, Any]] = None, + ): timestep = timestep.float().expand((sample.shape[0],)) inputs = { "sample": sample, "timestep": timestep, "encoder_hidden_states": encoder_hidden_states, } - outputs = self.model(*tuple(inputs.values())) + if added_cond_kwargs is not None: + inputs["text_embeds"] = added_cond_kwargs.pop("text_embeds", None) + inputs["time_ids"] = added_cond_kwargs.pop("time_ids", None) - return tuple(output for output in outputs.values()) + outputs = self.model(*tuple(inputs.values())) + return outputs class NeuronModelVaeEncoder(_NeuronDiffusionModelPart): @@ -515,5 +630,63 @@ def forward(self, latent_sample: torch.Tensor): class NeuronStableDiffusionPipeline(NeuronStableDiffusionPipelineBase, StableDiffusionPipelineMixin): - def __call__(self, *args, **kwargs): - return StableDiffusionPipelineMixin.__call__(self, *args, **kwargs) + __call__ = StableDiffusionPipelineMixin.__call__ + + +class NeuronStableDiffusionXLPipelineBase(NeuronStableDiffusionPipelineBase): + # `TasksManager` registered img2ime pipeline for `stable-diffusion-xl`: https://github.com/huggingface/optimum/blob/v1.12.0/optimum/exporters/tasks.py#L174 + auto_model_class = StableDiffusionXLImg2ImgPipeline + + def __init__( + self, + text_encoder: torch.jit._script.ScriptModule, + unet: torch.jit._script.ScriptModule, + vae_decoder: torch.jit._script.ScriptModule, + config: Dict[str, Any], + tokenizer: CLIPTokenizer, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + vae_encoder: Optional[torch.jit._script.ScriptModule] = None, + text_encoder_2: Optional[torch.jit._script.ScriptModule] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, + feature_extractor: Optional[CLIPFeatureExtractor] = None, + device_ids: Optional[List[int]] = None, + configs: Optional[Dict[str, "PretrainedConfig"]] = None, + neuron_configs: Optional[Dict[str, "NeuronConfig"]] = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + model_and_config_save_paths: Optional[Dict[str, Tuple[str, Path]]] = None, + add_watermarker: Optional[bool] = None, + ): + super().__init__( + text_encoder=text_encoder, + unet=unet, + vae_decoder=vae_decoder, + config=config, + tokenizer=tokenizer, + scheduler=scheduler, + vae_encoder=vae_encoder, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + feature_extractor=feature_extractor, + device_ids=device_ids, + configs=configs, + neuron_configs=neuron_configs, + model_save_dir=model_save_dir, + model_and_config_save_paths=model_and_config_save_paths, + ) + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + if not is_invisible_watermark_available(): + raise ImportError( + "`add_watermarker` requires invisible-watermark to be installed, which can be installed with `pip install invisible-watermark`." + ) + from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + +class NeuronStableDiffusionXLPipeline(NeuronStableDiffusionXLPipelineBase, StableDiffusionXLPipelineMixin): + __call__ = StableDiffusionXLPipelineMixin.__call__ diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py index bb3eec38a..d971f9192 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py @@ -30,7 +30,7 @@ class StableDiffusionPipelineMixin(StableDiffusionPipeline): # Adapted from https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L302 - def _encode_prompt( + def encode_prompt( self, prompt, num_images_per_prompt, @@ -173,7 +173,7 @@ def __call__( num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, + num_images_per_prompt: int = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, @@ -189,6 +189,12 @@ def __call__( # 0. Height and width to unet (static shapes) height = self.unet.config.neuron["static_height"] * self.vae_scale_factor width = self.unet.config.neuron["static_width"] * self.vae_scale_factor + if self.num_images_per_prompt != num_images_per_prompt and not self.dynamic_batch_size: + logger.warning( + f"Overriding `num_images_per_prompt({num_images_per_prompt})` to {self.num_images_per_prompt} used for the compilation. Please recompile the models with your " + f"custom `num_images_per_prompt` or turn on `dynamic_batch_size`, if you wish generating {num_images_per_prompt} per prompt." + ) + num_images_per_prompt = self.num_images_per_prompt # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -216,7 +222,7 @@ def __call__( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self._encode_prompt( + prompt_embeds = self.encode_prompt( prompt, num_images_per_prompt, do_classifier_free_guidance, diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py new file mode 100644 index 000000000..9c0bc5739 --- /dev/null +++ b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py @@ -0,0 +1,524 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from diffusers import StableDiffusionXLPipeline +from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg +from diffusers.utils import randn_tensor + + +logger = logging.getLogger(__name__) + + +class StableDiffusionXLPipelineMixin(StableDiffusionXLPipeline): + # Adapted from https://github.com/huggingface/diffusers/blob/v0.20.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L219 + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`Optional[str]`, defaults to `None`): + prompt to be encoded + prompt_2 (`Optional[str]`, defaults to `None`): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + num_images_per_prompt (`int`, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`Optional[str]`, defaults to `None`): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`Optional[str]`, defaults to `None`): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`Optional[float]`, defaults to `None`): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(input_ids=text_input_ids) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds[-1][-2] # hidden_states + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and getattr( + self.config, "force_zeros_for_empty_prompt", False + ) + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = "" if isinstance(prompt, str) else [""] * batch_size + else: + negative_prompt = negative_prompt + # negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt, negative_prompt_2] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds = text_encoder(input_ids=uncond_input.input_ids) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds[-1][-2] # hidden_states + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds = prompt_embeds + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Adapted from https://github.com/huggingface/diffusers/blob/v0.20.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L502 + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, dtype=dtype) + elif latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Adapted from https://github.com/huggingface/diffusers/blob/v0.20.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L557 + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet (static shapes) + height = self.unet.config.neuron["static_height"] * self.vae_scale_factor + width = self.unet.config.neuron["static_width"] * self.vae_scale_factor + if self.num_images_per_prompt != num_images_per_prompt and not self.dynamic_batch_size: + logger.warning( + f"Overriding `num_images_per_prompt({num_images_per_prompt})` to {self.num_images_per_prompt} used for the compilation. Please recompile the models with your " + f"custom `num_images_per_prompt` or turn on `dynamic_batch_size`, if you wish generating {num_images_per_prompt} per prompt." + ) + num_images_per_prompt = self.num_images_per_prompt + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 and (self.dynamic_batch_size or len(self.device_ids) == 2) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = (original_size + crops_coords_top_left + target_size,) + add_time_ids = torch.tensor(add_time_ids, dtype=prompt_embeds.dtype) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # import pdb + # pdb.set_trace() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae_decoder(latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215))[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/tests/inference/inference_utils.py b/tests/inference/inference_utils.py index 3f9f1d406..607eae1af 100644 --- a/tests/inference/inference_utils.py +++ b/tests/inference/inference_utils.py @@ -42,6 +42,7 @@ "roberta": "hf-internal-testing/tiny-random-RobertaModel", "roformer": "hf-internal-testing/tiny-random-RoFormerModel", "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", + "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", "xlm": "hf-internal-testing/tiny-random-XLMModel", "xlm-roberta": "hf-internal-testing/tiny-xlm-roberta", } diff --git a/tests/inference/test_stable_diffusion_pipeline.py b/tests/inference/test_stable_diffusion_pipeline.py index 2b30d9924..c4d1fb3b5 100644 --- a/tests/inference/test_stable_diffusion_pipeline.py +++ b/tests/inference/test_stable_diffusion_pipeline.py @@ -19,7 +19,7 @@ import PIL from parameterized import parameterized -from optimum.neuron import NeuronStableDiffusionPipeline +from optimum.neuron import NeuronStableDiffusionPipeline, NeuronStableDiffusionXLPipeline from optimum.neuron.modeling_diffusion import ( NeuronModelTextEncoder, NeuronModelUnet, @@ -66,13 +66,8 @@ def test_export_and_inference_non_dyn(self, model_arch): self.assertIsInstance(neuron_pipeline.vae_encoder, NeuronModelVaeEncoder) self.assertIsInstance(neuron_pipeline.vae_decoder, NeuronModelVaeDecoder) - prompt = "sailing ship in storm by Leonardo da Vinci" - with self.assertRaises(Exception) as context: - image = neuron_pipeline(prompt).images[0] - self.assertIn("pipeline were compiled with", str(context.exception)) - - prompts = ["sailing ship in storm by Leonardo da Vinci"] * num_images_per_prompt - image = neuron_pipeline(prompts).images[0] + prompts = ["sailing ship in storm by Leonardo da Vinci"] + image = neuron_pipeline(prompts, num_images_per_prompt=num_images_per_prompt).images[0] self.assertIsInstance(image, PIL.Image.Image) @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) @@ -89,3 +84,73 @@ def test_export_and_inference_dyn(self, model_arch): prompts = ["sailing ship in storm by Leonardo da Vinci"] * 2 image = neuron_pipeline(prompts, num_images_per_prompt=2).images[0] self.assertIsInstance(image, PIL.Image.Image) + + +@is_inferentia_test +@requires_neuronx +@require_diffusers +class NeuronStableDiffusionXLPipelineIntegrationTest(unittest.TestCase): + NEURON_MODEL_CLASS = NeuronStableDiffusionXLPipeline + STATIC_INPUTS_SHAPES = {"batch_size": 1, "height": 64, "width": 64} + COMPILER_ARGS = {"auto_cast": "all", "auto_cast_type": "bf16"} + SUPPORTED_ARCHITECTURES = [ + "stable-diffusion-xl", + ] + ATOL_FOR_VALIDATION = 1e-3 + + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + def test_export_and_inference_non_dyn(self, model_arch): + num_images_per_prompt = 4 + input_shapes = copy.deepcopy(self.STATIC_INPUTS_SHAPES) + input_shapes.update({"num_images_per_prompt": num_images_per_prompt}) + neuron_pipeline = self.NEURON_MODEL_CLASS.from_pretrained( + MODEL_NAMES[model_arch], + export=True, + dynamic_batch_size=False, + **input_shapes, + **self.COMPILER_ARGS, + device_ids=[0, 1], + ) + self.assertIsInstance(neuron_pipeline.text_encoder, NeuronModelTextEncoder) + self.assertIsInstance(neuron_pipeline.text_encoder_2, NeuronModelTextEncoder) + self.assertIsInstance(neuron_pipeline.unet, NeuronModelUnet) + self.assertIsInstance(neuron_pipeline.vae_encoder, NeuronModelVaeEncoder) + self.assertIsInstance(neuron_pipeline.vae_decoder, NeuronModelVaeDecoder) + + prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" + prompt_2 = "Van Gogh painting" + negative_prompt_1 = "low quality, low resolution" + negative_prompt_2 = "low quality, low resolution" + + image = neuron_pipeline( + prompt=prompt, + prompt_2=prompt_2, + negative_prompt=negative_prompt_1, + negative_prompt_2=negative_prompt_2, + num_images_per_prompt=num_images_per_prompt, + ).images[0] + self.assertIsInstance(image, PIL.Image.Image) + + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + def test_export_and_inference_dyn(self, model_arch): + neuron_pipeline = self.NEURON_MODEL_CLASS.from_pretrained( + MODEL_NAMES[model_arch], + export=True, + dynamic_batch_size=True, + **self.STATIC_INPUTS_SHAPES, + **self.COMPILER_ARGS, + device_ids=[0, 1], + ) + + prompt = ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"] * 2 + prompt_2 = ["Van Gogh painting"] * 2 + negative_prompt_1 = ["low quality, low resolution"] * 2 + negative_prompt_2 = ["low quality, low resolution"] * 2 + image = neuron_pipeline( + prompt=prompt, + prompt_2=prompt_2, + negative_prompt=negative_prompt_1, + negative_prompt_2=negative_prompt_2, + num_images_per_prompt=2, + ).images[0] + self.assertIsInstance(image, PIL.Image.Image)