diff --git a/docs/source/guides/models.mdx b/docs/source/guides/models.mdx
index 8d3f9662a..a4b8ebfb5 100644
--- a/docs/source/guides/models.mdx
+++ b/docs/source/guides/models.mdx
@@ -211,7 +211,11 @@ You can also accelerate the inference of stable diffusion on neuronx devices (in
 * VAE encoder
 * VAE decoder
-The export can be done either with the CLI or with `NeuronStableDiffusionPipeline` API. Here is an example of exporting stable diffusion components with `NeuronStableDiffusionPipeline`:
+### Text-to-Image
+
+The `NeuronStableDiffusionPipeline` class allows you to generate images from a text prompt on Neuron devices, similar to the experience with `diffusers`.
+
+As with other tasks, you need to compile the models before running inference. The export can be done either via the CLI or via the `NeuronStableDiffusionPipeline` API. Here is an example of exporting stable diffusion components with `NeuronStableDiffusionPipeline`:
@@ -247,9 +251,75 @@ Now generate an image with a prompt on neuron:
 stable diffusion generated image
+### Image-to-Image
+
+With the `NeuronStableDiffusionImg2ImgPipeline` class, you can generate a new image conditioned on a text prompt and an initial image.
+
+```python
+import requests
+from PIL import Image
+from io import BytesIO
+from optimum.neuron import NeuronStableDiffusionImg2ImgPipeline
+
+model_id = "nitrosocke/Ghibli-Diffusion"
+input_shapes = {"batch_size": 1, "height": 512, "width": 512}
+pipeline = NeuronStableDiffusionImg2ImgPipeline.from_pretrained(model_id, export=True, **input_shapes, device_ids=[0, 1])
+pipeline.save_pretrained("sd_img2img/")
+
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = init_image.resize((512, 512))
+
+prompt = "ghibli style, a fantasy landscape with snowcapped mountains, trees, lake with detailed reflection. sunlight and cloud in the sky, warm colors, 8K"
+
+image = pipeline(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0]
+image.save("fantasy_landscape.png")
+```
+`image` | `prompt` | output |
+:-------------------------:|:-------------------------:|:-------------------------:|
+landscape photo | ***ghibli style, a fantasy landscape with snowcapped mountains, trees, lake with detailed reflection. warm colors, 8K*** | drawing |
+
+### Inpaint
+
+With the `NeuronStableDiffusionInpaintPipeline` class, you can edit specific parts of an image by providing a mask and a text prompt.
+
+```python
+import requests
+from PIL import Image
+from io import BytesIO
+from optimum.neuron import NeuronStableDiffusionInpaintPipeline
+
+model_id = "runwayml/stable-diffusion-inpainting"
+input_shapes = {"batch_size": 1, "height": 512, "width": 512}
+pipeline = NeuronStableDiffusionInpaintPipeline.from_pretrained(model_id, export=True, **input_shapes, device_ids=[0, 1])
+pipeline.save_pretrained("sd_inpaint/")
+
+def download_image(url):
+    response = requests.get(url)
+    return Image.open(BytesIO(response.content)).convert("RGB")
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = download_image(img_url).resize((512, 512))
+mask_image = download_image(mask_url).resize((512, 512))
+
+prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+image.save("cat_on_bench.png")
+```
+
+`image` | `mask_image` | `prompt` | output |
+:-------------------------:|:-------------------------:|:-------------------------:|:-------------------------:|
+drawing | drawing | ***Face of a yellow cat, high resolution, sitting on a park bench*** | drawing |
+
 ## Stable Diffusion XL
 
 Similar to Stable Diffusion, you will be able to use `NeuronStableDiffusionXLPipeline` API to export and run inference on Neuron devices with SDXL models.
@@ -280,6 +350,8 @@ Now generate an image with a prompt on neuron:
 sdxl generated image
diff --git a/docs/source/package_reference/export.mdx b/docs/source/package_reference/export.mdx
index e35548f00..cb8c0064d 100644
--- a/docs/source/package_reference/export.mdx
+++ b/docs/source/package_reference/export.mdx
@@ -70,7 +70,8 @@ Since many architectures share similar properties for their Neuron configuration
 | RoFormer | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
 | XLM | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
 | XLM-RoBERTa | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
-| Stable Diffusion | text-to-image |
+| Stable Diffusion | text-to-image, image-to-image, inpaint |
+| Stable Diffusion XL | text-to-image |
diff --git a/docs/source/package_reference/modeling.mdx b/docs/source/package_reference/modeling.mdx
index 19aa2643c..55a6b613f 100644
--- a/docs/source/package_reference/modeling.mdx
+++ b/docs/source/package_reference/modeling.mdx
@@ -71,4 +71,20 @@ The following Neuron model classes are available for natural language processing
 ### NeuronStableDiffusionPipeline
-[[autodoc]] modeling_diffusion.NeuronStableDiffusionPipeline
\ No newline at end of file
+[[autodoc]] modeling_diffusion.NeuronStableDiffusionPipeline
+    - __call__
+
+### NeuronStableDiffusionImg2ImgPipeline
+
+[[autodoc]] modeling_diffusion.NeuronStableDiffusionImg2ImgPipeline
+    - __call__
+
+### NeuronStableDiffusionInpaintPipeline
+
+[[autodoc]] modeling_diffusion.NeuronStableDiffusionInpaintPipeline
+    - __call__
+
+### NeuronStableDiffusionXLPipeline
+
+[[autodoc]] modeling_diffusion.NeuronStableDiffusionXLPipeline
+    - __call__
\ No newline at end of file
diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py index 
f88bae011..04934884a 100644 --- a/optimum/exporters/neuron/__main__.py +++ b/optimum/exporters/neuron/__main__.py @@ -143,18 +143,25 @@ def infer_stable_diffusion_shapes_from_diffusers( vae_encoder_num_channels = model.vae.config.in_channels vae_decoder_num_channels = model.vae.config.latent_channels vae_scale_factor = 2 ** (len(model.vae.config.block_out_channels) - 1) or 8 - height = input_shapes["unet_input_shapes"]["height"] // vae_scale_factor - width = input_shapes["unet_input_shapes"]["width"] // vae_scale_factor + height = input_shapes["unet_input_shapes"]["height"] + scaled_height = height // vae_scale_factor + width = input_shapes["unet_input_shapes"]["width"] + scaled_width = width // vae_scale_factor input_shapes["text_encoder_input_shapes"].update({"sequence_length": sequence_length}) input_shapes["unet_input_shapes"].update( - {"sequence_length": sequence_length, "num_channels": unet_num_channels, "height": height, "width": width} + { + "sequence_length": sequence_length, + "num_channels": unet_num_channels, + "height": scaled_height, + "width": scaled_width, + } ) input_shapes["vae_encoder_input_shapes"].update( {"num_channels": vae_encoder_num_channels, "height": height, "width": width} ) input_shapes["vae_decoder_input_shapes"].update( - {"num_channels": vae_decoder_num_channels, "height": height, "width": width} + {"num_channels": vae_decoder_num_channels, "height": scaled_height, "width": scaled_width} ) return input_shapes diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py index fa0aa7b77..babc09f5c 100644 --- a/optimum/exporters/neuron/model_configs.py +++ b/optimum/exporters/neuron/model_configs.py @@ -256,6 +256,7 @@ def outputs(self) -> List[str]: def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs): # For neuron, we use static shape for compiling the unet. Unlike `optimum`, we use the given `height` and `width` instead of the `sample_size`. + # TODO: Modify optimum.utils.DummyVisionInputGenerator to enable unequal height and width (it prioritize `image_size` to custom h/w now) if self.height == self.width: self._normalized_config.image_size = self.height else: @@ -302,7 +303,7 @@ def check_model_inputs_order(self, model, dummy_inputs): @register_in_tasks_manager("vae-encoder", *["semantic-segmentation"]) class VaeEncoderNeuronConfig(VisionNeuronConfig): - ATOL_FOR_VALIDATION = 1e-2 + ATOL_FOR_VALIDATION = 1e-3 MODEL_TYPE = "vae-encoder" NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( @@ -319,6 +320,22 @@ def inputs(self) -> List[str]: def outputs(self) -> List[str]: return ["latent_sample"] + def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs): + # For neuron, we use static shape for compiling the unet. Unlike `optimum`, we use the given `height` and `width` instead of the `sample_size`. + # TODO: Modify optimum.utils.DummyVisionInputGenerator to enable unequal height and width (it prioritize `image_size` to custom h/w now) + if self.height == self.width: + self._normalized_config.image_size = self.height + else: + raise ValueError( + "You need to input the same value for `self.height({self.height})` and `self.width({self.width})`." 
+ ) + dummy_inputs = super().generate_dummy_inputs(**kwargs) + + if return_tuple is True: + return tuple(dummy_inputs.values()) + else: + return dummy_inputs + @register_in_tasks_manager("vae-decoder", *["semantic-segmentation"]) class VaeDecoderNeuronConfig(VisionNeuronConfig): diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py index b76bb9e39..0797ed6fc 100644 --- a/optimum/neuron/__init__.py +++ b/optimum/neuron/__init__.py @@ -34,6 +34,8 @@ ], "modeling_diffusion": [ "NeuronStableDiffusionPipeline", + "NeuronStableDiffusionImg2ImgPipeline", + "NeuronStableDiffusionInpaintPipeline", "NeuronStableDiffusionXLPipeline", ], "modeling_decoder": ["NeuronDecoderModel"], @@ -60,6 +62,8 @@ from .modeling_base import NeuronBaseModel from .modeling_decoder import NeuronDecoderModel from .modeling_diffusion import ( + NeuronStableDiffusionImg2ImgPipeline, + NeuronStableDiffusionInpaintPipeline, NeuronStableDiffusionPipeline, NeuronStableDiffusionXLPipeline, ) diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index b1ae8e8b5..b8b58cfd2 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -59,8 +59,12 @@ from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available - from .pipelines.diffusers.pipeline_stable_diffusion import StableDiffusionPipelineMixin - from .pipelines.diffusers.pipeline_stable_diffusion_xl import StableDiffusionXLPipelineMixin + from .pipelines import ( + NeuronStableDiffusionImg2ImgPipelineMixin, + NeuronStableDiffusionInpaintPipelineMixin, + NeuronStableDiffusionPipelineMixin, + NeuronStableDiffusionXLPipelineMixin, + ) if TYPE_CHECKING: @@ -158,16 +162,16 @@ def __init__( self.unet = NeuronModelUnet( unet, self, self.configs[DIFFUSION_MODEL_UNET_NAME], self.neuron_configs[DIFFUSION_MODEL_UNET_NAME] ) - self.vae_encoder = ( - NeuronModelVaeEncoder( + if vae_encoder is not None: + self.vae_encoder = NeuronModelVaeEncoder( vae_encoder, self, self.configs[DIFFUSION_MODEL_VAE_ENCODER_NAME], self.neuron_configs[DIFFUSION_MODEL_VAE_ENCODER_NAME], ) - if vae_encoder is not None - else None - ) + else: + self.vae_encoder = None + self.vae_decoder = NeuronModelVaeDecoder( vae_decoder, self, @@ -623,15 +627,36 @@ def __init__( ): super().__init__(model, parent_model, config, neuron_config, DIFFUSION_MODEL_VAE_DECODER_NAME) - def forward(self, latent_sample: torch.Tensor): + def forward( + self, + latent_sample: torch.Tensor, + image: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + ): inputs = (latent_sample,) + if image is not None: + inputs += (image,) + if mask is not None: + inputs += (mask,) outputs = self.model(*inputs) return tuple(output for output in outputs.values()) -class NeuronStableDiffusionPipeline(NeuronStableDiffusionPipelineBase, StableDiffusionPipelineMixin): - __call__ = StableDiffusionPipelineMixin.__call__ +class NeuronStableDiffusionPipeline(NeuronStableDiffusionPipelineBase, NeuronStableDiffusionPipelineMixin): + __call__ = NeuronStableDiffusionPipelineMixin.__call__ + + +class NeuronStableDiffusionImg2ImgPipeline( + NeuronStableDiffusionPipelineBase, NeuronStableDiffusionImg2ImgPipelineMixin +): + __call__ = NeuronStableDiffusionImg2ImgPipelineMixin.__call__ + + +class NeuronStableDiffusionInpaintPipeline( + NeuronStableDiffusionPipelineBase, NeuronStableDiffusionInpaintPipelineMixin +): + __call__ = NeuronStableDiffusionInpaintPipelineMixin.__call__ class 
NeuronStableDiffusionXLPipelineBase(NeuronStableDiffusionPipelineBase): @@ -689,5 +714,5 @@ def __init__( self.watermark = None -class NeuronStableDiffusionXLPipeline(NeuronStableDiffusionXLPipelineBase, StableDiffusionXLPipelineMixin): - __call__ = StableDiffusionXLPipelineMixin.__call__ +class NeuronStableDiffusionXLPipeline(NeuronStableDiffusionXLPipelineBase, NeuronStableDiffusionXLPipelineMixin): + __call__ = NeuronStableDiffusionXLPipelineMixin.__call__ diff --git a/optimum/neuron/pipelines/__init__.py b/optimum/neuron/pipelines/__init__.py index fa325ae82..52eb5ef1e 100644 --- a/optimum/neuron/pipelines/__init__.py +++ b/optimum/neuron/pipelines/__init__.py @@ -20,9 +20,21 @@ _import_structure = { "transformers": ["pipeline"], + "diffusers": [ + "NeuronStableDiffusionPipelineMixin", + "NeuronStableDiffusionImg2ImgPipelineMixin", + "NeuronStableDiffusionInpaintPipelineMixin", + "NeuronStableDiffusionXLPipelineMixin", + ], } if TYPE_CHECKING: + from .diffusers import ( + NeuronStableDiffusionImg2ImgPipelineMixin, + NeuronStableDiffusionInpaintPipelineMixin, + NeuronStableDiffusionPipelineMixin, + NeuronStableDiffusionXLPipelineMixin, + ) from .transformers import ( pipeline, ) diff --git a/optimum/neuron/pipelines/diffusers/__init__.py b/optimum/neuron/pipelines/diffusers/__init__.py new file mode 100644 index 000000000..0d20f7fb3 --- /dev/null +++ b/optimum/neuron/pipelines/diffusers/__init__.py @@ -0,0 +1,19 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .pipeline_stable_diffusion import NeuronStableDiffusionPipelineMixin +from .pipeline_stable_diffusion_img2img import NeuronStableDiffusionImg2ImgPipelineMixin +from .pipeline_stable_diffusion_inpaint import NeuronStableDiffusionInpaintPipelineMixin +from .pipeline_stable_diffusion_xl import NeuronStableDiffusionXLPipelineMixin diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py index 99a053c2f..ba7d5ee18 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py @@ -19,123 +19,17 @@ import torch from diffusers import StableDiffusionPipeline -from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg from diffusers.utils.torch_utils import randn_tensor +from .pipeline_utils import StableDiffusionPipelineMixin -logger = logging.getLogger(__name__) - - -class StableDiffusionPipelineMixin(StableDiffusionPipeline): - # Adapted from https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L302 - def encode_prompt( - self, - prompt, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - ): - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - # [Modified] Input and its dtype constraints - prompt_embeds = self.text_encoder(input_ids=text_input_ids) - prompt_embeds = prompt_embeds[0] - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not 
type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = self.text_encoder(uncond_input.input_ids) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) +logger = logging.getLogger(__name__) - return prompt_embeds +class NeuronStableDiffusionPipelineMixin(StableDiffusionPipelineMixin, StableDiffusionPipeline): # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) @@ -154,18 +48,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - def check_num_images_per_prompt(self, prompt_batch_size: int, neuron_batch_size: int, num_images_per_prompt: int): - if self.dynamic_batch_size: - return prompt_batch_size, num_images_per_prompt - if neuron_batch_size != prompt_batch_size * num_images_per_prompt: - raise ValueError( - f"Models in the pipeline were compiled with `batch_size` {neuron_batch_size} which does not equal the number of" - f" prompt({prompt_batch_size}) multiplied by `num_images_per_prompt`({num_images_per_prompt}). You need to enable" - " `dynamic_batch_size` or precisely configure `num_images_per_prompt` during the compilation." - ) - else: - return prompt_batch_size, num_images_per_prompt - # Adapted from https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L566 def __call__( self, @@ -186,32 +68,107 @@ def __call__( cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): + The prompt or prompts to guide image generation. 
If not defined, you need to pass `prompt_embeds`. + num_inference_steps (`int`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. If it is different from the batch size used for the compiltaion, + it will be overriden by the static batch size of neuron (except for dynamic batching). + eta (`float`, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`diffusers.schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`Optional[str]`, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Optional[Callable]`, defaults to `None`): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, defaults to `None`): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
Guidance rescale factor should fix overexposure when + using zero terminal SNR. + + Examples: + + ```py + >>> from optimum.neuron import NeuronStableDiffusionPipeline + + >>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} + >>> input_shapes = {"batch_size": 1, "height": 512, "width": 512} + + >>> stable_diffusion = NeuronStableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", export=True, **compiler_args, **input_shapes + ... ) + >>> stable_diffusion.save_pretrained("sd_neuron/") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = stable_diffusion(prompt).images[0] + ``` + + Returns: + [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ # 0. Height and width to unet (static shapes) height = self.unet.config.neuron["static_height"] * self.vae_scale_factor width = self.unet.config.neuron["static_width"] * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct if self.num_images_per_prompt != num_images_per_prompt and not self.dynamic_batch_size: logger.warning( f"Overriding `num_images_per_prompt({num_images_per_prompt})` to {self.num_images_per_prompt} used for the compilation. Please recompile the models with your " f"custom `num_images_per_prompt` or turn on `dynamic_batch_size`, if you wish generating {num_images_per_prompt} per prompt." ) num_images_per_prompt = self.num_images_per_prompt - - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 + batch_size = 1 elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) + batch_size = len(prompt) else: - prompt_batch_size = prompt_embeds.shape[0] + batch_size = prompt_embeds.shape[0] neuron_batch_size = self.unet.config.neuron["static_batch_size"] - batch_size, num_images_per_prompt = self.check_num_images_per_prompt( - prompt_batch_size, neuron_batch_size, num_images_per_prompt - ) + self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -222,7 +179,7 @@ def __call__( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - prompt_embeds = self.encode_prompt( + prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, num_images_per_prompt, do_classifier_free_guidance, @@ -231,6 +188,11 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps) @@ -260,7 +222,7 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - # [Modified] Remove not traced inputs + # [modified for neuron] Remove not traced inputs: cross_attention_kwargs, return_dict noise_pred = self.unet( latent_model_input, t, @@ -308,17 +270,3 @@ def __call__( return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - def run_safety_checker(self, image, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt") - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_img2img.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_img2img.py new file mode 100644 index 000000000..f1fbe92a8 --- /dev/null +++ b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_img2img.py @@ -0,0 +1,309 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Override some diffusers API for NeuroStableDiffusionImg2ImgPipeline""" + +import logging +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from diffusers import StableDiffusionImg2ImgPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.utils import deprecate +from diffusers.utils.torch_utils import randn_tensor + +from .pipeline_utils import StableDiffusionPipelineMixin + + +logger = logging.getLogger(__name__) + + +class NeuronStableDiffusionImg2ImgPipelineMixin(StableDiffusionPipelineMixin, StableDiffusionImg2ImgPipeline): + # Adapted from diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + else: + # encode the init image into latents and scale the latents + init_latents = self.vae_encoder(sample=image)[0] + scaling_factor = self.vae_encoder.config.scaling_factor or 0.18215 + init_latents = scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + # Adapted from diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.__call__ + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`Optional[Union[str, List[str]`, defaults to `None`): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. 
If it is different from the batch size used for the compiltaion, + it will be overriden by the static batch size of neuron (except for dynamic batching). + eta (`float`, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`diffusers.schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`Optional[str]`, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Optional[Callable]`, defaults to `None`): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, defaults to `None`): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + ```py + >>> import PIL + >>> import requests + >>> from io import BytesIO + + >>> from optimum.neuron import NeuronStableDiffusionImg2ImgPipeline + + >>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} + >>> input_shapes = {"batch_size": 1, "height": 512, "width": 512} + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + >>> response = requests.get(url) + + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((512, 512)) + + >>> pipeline = NeuronStableDiffusionImg2ImgPipeline.from_pretrained( + ... "nitrosocke/Ghibli-Diffusion", export=True, **input_shapes, device_ids=[0, 1] + ... ) + >>> pipeline.save_pretrained("sd_img2img/") + + >>> prompt = "ghibli style, a fantasy landscape with snowcapped mountains, trees, lake with detailed reflection." 
+ >>> image = pipeline(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0] + ``` + + Returns: + [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 1. Check inputs. Raise error if not correct + if self.num_images_per_prompt != num_images_per_prompt and not self.dynamic_batch_size: + logger.warning( + f"Overriding `num_images_per_prompt({num_images_per_prompt})` to {self.num_images_per_prompt} used for the compilation. Please recompile the models with your " + f"custom `num_images_per_prompt` or turn on `dynamic_batch_size`, if you wish generating {num_images_per_prompt} image per prompt." + ) + num_images_per_prompt = self.num_images_per_prompt + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + neuron_batch_size = self.unet.config.neuron["static_batch_size"] + self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 and (self.dynamic_batch_size or len(self.device_ids) == 2) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. Preprocess image + height = self.vae_encoder.config.neuron["static_height"] + width = self.vae_encoder.config.neuron["static_width"] + image = self.image_processor.preprocess(image, height=height, width=width) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device=None) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, generator + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + # [modified for neuron] Remove not traced inputs: cross_attention_kwargs, return_dict + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae_decoder(latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215))[0] + image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py new file mode 100644 index 000000000..445ec9e3b --- /dev/null +++ b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py @@ -0,0 +1,399 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Override some diffusers API for NeuroStableDiffusionInpaintPipeline""" + +import logging +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from diffusers import StableDiffusionInpaintPipeline +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from .pipeline_utils import StableDiffusionPipelineMixin + + +logger = logging.getLogger(__name__) + + +class NeuronStableDiffusionInpaintPipelineMixin(StableDiffusionPipelineMixin, StableDiffusionInpaintPipeline): + # Adapted from https://github.com/huggingface/diffusers/blob/v0.21.2/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py#L629 + def _encode_vae_image( + self, image: torch.Tensor, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None + ): + image_latents = self.vae_encoder(sample=image)[0] + image_latents = self.vae_encoder.config.scaling_factor * image_latents + + return image_latents + + # Adapted from https://github.com/huggingface/diffusers/blob/v0.21.2/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py#L699 + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: torch.FloatTensor = None, + strength: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + clip_skip: int = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image to + be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and pytorch + tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the + expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the + expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but + if passing latents directly it is not encoded again. + mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. 
If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + strength (`float`, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`Optional[Union[str, List[str]`, defaults to `None`): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. If it is different from the batch size used for the compiltaion, + it will be overriden by the static batch size of neuron (except for dynamic batching). + eta (`float`, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`diffusers.schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`Optional[str]`, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Optional[Callable]`, defaults to `None`): + A function that calls every `callback_steps` steps during inference. 
The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, defaults to `None`): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, defaults to `None`): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + + Examples: + + ```py + >>> import PIL + >>> import requests + >>> from io import BytesIO + + >>> from optimum.neuron import NeuronStableDiffusionInpaintPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + + >>> pipeline = NeuronStableDiffusionInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-inpainting", export=True, **input_shapes, device_ids=[0, 1]) + ... ) + >>> pipeline.save_pretrained("sd_inpaint/") + + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0] + ``` + + Returns: + [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 0. Height and width to unet (static shapes) + height = self.unet.config.neuron["static_height"] * self.vae_scale_factor + width = self.unet.config.neuron["static_width"] * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + if self.num_images_per_prompt != num_images_per_prompt and not self.dynamic_batch_size: + logger.warning( + f"Overriding `num_images_per_prompt({num_images_per_prompt})` to {self.num_images_per_prompt} used for the compilation. Please recompile the models with your " + f"custom `num_images_per_prompt` or turn on `dynamic_batch_size`, if you wish generating {num_images_per_prompt} per prompt." + ) + num_images_per_prompt = self.num_images_per_prompt + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + neuron_batch_size = self.unet.config.neuron["static_batch_size"] + self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 and (self.dynamic_batch_size or len(self.device_ids) == 2) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=None + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + init_image = self.image_processor.preprocess(image, height=height, width=width) + init_image = init_image.to(dtype=torch.float32) + + # 6. Prepare latent variables + num_channels_latents = self.vae_encoder.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + None, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. 
Prepare mask latent variables + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + None, + generator, + do_classifier_free_guidance, + ) + + # 8. Check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 10. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + # [modified for neuron] Remove not traced inputs: cross_attention_kwargs, return_dict + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + condition_kwargs = {} + if "AsymmetricAutoencoderKL" in self.vae_decoder.config._class_name: + init_image = init_image.to(dtype=masked_image_latents.dtype) + init_image_condition = init_image.clone() + # [modified for neuron] Remove generator which is not an input for the compilation + init_image = self._encode_vae_image(init_image) + mask_condition = mask_condition.to(dtype=masked_image_latents.dtype) + condition_kwargs = {"image": init_image_condition, "mask": mask_condition} + image = self.vae_decoder( + latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215), **condition_kwargs + )[0] + image, has_nsfw_concept = self.run_safety_checker(image, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py index 0d438276c..7620a56f5 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py @@ -17,201 +17,17 @@ import torch from diffusers import StableDiffusionXLPipeline -from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin from 
diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg from diffusers.utils.torch_utils import randn_tensor +from .pipeline_utils import StableDiffusionXLPipelineMixin -logger = logging.getLogger(__name__) - - -class StableDiffusionXLPipelineMixin(StableDiffusionXLPipeline): - # Adapted from https://github.com/huggingface/diffusers/blob/v0.20.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L219 - def encode_prompt( - self, - prompt: str, - prompt_2: Optional[str] = None, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - negative_prompt_2: Optional[str] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`Optional[str]`, defaults to `None`): - prompt to be encoded - prompt_2 (`Optional[str]`, defaults to `None`): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in both text-encoders - num_images_per_prompt (`int`, defaults to 1): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`, defaults to `True`): - whether to use classifier free guidance or not - negative_prompt (`Optional[str]`, defaults to `None`): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - negative_prompt_2 (`Optional[str]`, defaults to `None`): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - lora_scale (`Optional[float]`, defaults to `None`): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
- """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] - text_encoders = ( - [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] - ) - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - # textual inversion: procecss multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt_2] - for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, tokenizer) - - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - prompt_embeds = text_encoder(input_ids=text_input_ids) - - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds[-1][-2] # hidden_states - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = negative_prompt is None and getattr( - self.config, "force_zeros_for_empty_prompt", False - ) - if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - elif do_classifier_free_guidance and negative_prompt_embeds is None: - if negative_prompt is None: - negative_prompt = "" if isinstance(prompt, str) else [""] * batch_size - else: - negative_prompt = negative_prompt - # negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt, negative_prompt_2] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): - if isinstance(self, TextualInversionLoaderMixin): - negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - negative_prompt_embeds = text_encoder(input_ids=uncond_input.input_ids) - # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds[-1][-2] # hidden_states - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - - prompt_embeds = prompt_embeds - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - if do_classifier_free_guidance: - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) +logger = logging.getLogger(__name__) - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds +class NeuronStableDiffusionXLPipelineMixin(StableDiffusionXLPipelineMixin, StableDiffusionXLPipeline): # Adapted from https://github.com/huggingface/diffusers/blob/v0.20.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L502 def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) @@ -233,7 +49,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype # Adapted from https://github.com/huggingface/diffusers/blob/v0.20.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L557 def __call__( self, - prompt: Union[str, List[str]] = None, + prompt: Optional[Union[str, List[str]]] = None, prompt_2: Optional[Union[str, List[str]]] = None, num_inference_steps: int = 50, denoising_end: Optional[float] = None, @@ -262,102 +78,118 @@ def __call__( Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `List[str]`, *optional*): + prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. 
- prompt_2 (`str` or `List[str]`, *optional*): + prompt_2 (`Optional[Union[str, List[str]]]`, defaults to `None`): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders - num_inference_steps (`int`, *optional*, defaults to 50): + num_inference_steps (`int`, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - denoising_end (`float`, *optional*): + denoising_end (`Optional[float]`, defaults to `None`): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will still retain a substantial amount of noise as determined by the discrete timesteps selected by the scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) - guidance_scale (`float`, *optional*, defaults to 5.0): + guidance_scale (`float`, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): + negative_prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): + negative_prompt_2 (`Optional[Union[str, List[str]]]`, defaults to `None`): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. If it differs from the value used for the compilation, + it will be overridden by the static value used to compile the Neuron models (unless dynamic batching is enabled). + eta (`float`, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): + latents (`Optional[torch.FloatTensor]`, defaults to `None`): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts.
If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): + prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): + negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + negative_pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): + output_type (`Optional[str]`, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`diffusers.pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): + callback (`Optional[Callable]`, defaults to `None`): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): + callback_steps (`int`, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. - cross_attention_kwargs (`dict`, *optional*): + cross_attention_kwargs (`dict`, defaults to `None`): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - guidance_rescale (`float`, *optional*, defaults to 0.7): + guidance_rescale (`float`, *optional*, defaults to 0.0): Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when using zero terminal SNR. 
- original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + original_size (`Optional[Tuple[int, int]]`, defaults to (1024, 1024)): If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + crops_coords_top_left (`Tuple[int]`, defaults to (0, 0)): `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + target_size (`Tuple[int]`, defaults to (1024, 1024)): For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Examples: + ```py + >>> from optimum.neuron import NeuronStableDiffusionXLPipeline + + >>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} + >>> input_shapes = {"batch_size": 1, "height": 1024, "width": 1024} + + >>> stable_diffusion_xl = NeuronStableDiffusionXLPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", export=True, **compiler_args, **input_shapes + ... ) + >>> stable_diffusion_xl.save_pretrained("sd_neuron_xl/") + + >>> prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" + >>> image = stable_diffusion_xl(prompt).images[0] + ``` + Returns: - [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + [`diffusers.pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`diffusers.pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ # 0. Default height and width to unet (static shapes) @@ -395,6 +227,8 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] + neuron_batch_size = self.unet.config.neuron["static_batch_size"] + self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` diff --git a/optimum/neuron/pipelines/diffusers/pipeline_utils.py b/optimum/neuron/pipelines/diffusers/pipeline_utils.py new file mode 100644 index 000000000..e37183221 --- /dev/null +++ b/optimum/neuron/pipelines/diffusers/pipeline_utils.py @@ -0,0 +1,361 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import List, Optional, Union + +import torch +from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin + + +logger = logging.getLogger(__name__) + + +class DiffusionBasePipelineMixin: + def check_num_images_per_prompt(self, prompt_batch_size: int, neuron_batch_size: int, num_images_per_prompt: int): + if not self.dynamic_batch_size and neuron_batch_size != prompt_batch_size * num_images_per_prompt: + raise ValueError( + f"Models in the pipeline were compiled with `batch_size` {neuron_batch_size}, which does not equal the number of" + f" prompts ({prompt_batch_size}) multiplied by `num_images_per_prompt` ({num_images_per_prompt}). You need to enable" + " `dynamic_batch_size` or configure `num_images_per_prompt` accordingly during the compilation." + ) + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt") + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + +class StableDiffusionPipelineMixin(DiffusionBasePipelineMixin): + # Adapted from https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L302 + def encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, list]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`Union[str, List[str]]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`Optional[Union[str, list]]`, defaults to `None`): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting.
If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`Optional[float]`, defaults to `None`): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + # [Modified] Input and its dtype constraints + prompt_embeds = self.text_encoder(input_ids=text_input_ids) + prompt_embeds = prompt_embeds[0] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = self.text_encoder(uncond_input.input_ids) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + +class StableDiffusionXLPipelineMixin(DiffusionBasePipelineMixin): + # Adapted from https://github.com/huggingface/diffusers/blob/v0.20.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L219 + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str`): + prompt to be encoded + prompt_2 (`Optional[str]`, defaults to `None`): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + num_images_per_prompt (`int`, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`Optional[str]`, defaults to `None`): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`Optional[str]`, defaults to `None`): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. 
+ negative_pooled_prompt_embeds (`Optional[torch.FloatTensor]`, defaults to `None`): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`Optional[float]`, defaults to `None`): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(input_ids=text_input_ids) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds[-1][-2] # hidden_states + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and getattr( + self.config, "force_zeros_for_empty_prompt", False + ) + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = "" if isinstance(prompt, str) else [""] * batch_size + else: + negative_prompt = negative_prompt + # negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." 
+ ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt, negative_prompt_2] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds = text_encoder(input_ids=uncond_input.input_ids) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds[-1][-2] # hidden_states + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds = prompt_embeds + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds diff --git a/tests/inference/inference_utils.py b/tests/inference/inference_utils.py index 607eae1af..c1e63ff10 100644 --- a/tests/inference/inference_utils.py +++ b/tests/inference/inference_utils.py @@ -17,9 +17,12 @@ import shutil import tempfile import unittest +from io import BytesIO from typing import Dict +import requests from huggingface_hub import HfFolder +from PIL import Image from transformers import set_seed @@ -123,3 +126,8 @@ def _setup(self, model_args: Dict): def tearDownClass(cls): for _, dir_path in cls.neuron_model_dirs.items(): shutil.rmtree(dir_path) + + +def download_image(url): + response = requests.get(url) + return Image.open(BytesIO(response.content)).convert("RGB") diff --git a/tests/inference/test_stable_diffusion_pipeline.py b/tests/inference/test_stable_diffusion_pipeline.py index c4d1fb3b5..3f56f6bd1 100644 --- a/tests/inference/test_stable_diffusion_pipeline.py +++ b/tests/inference/test_stable_diffusion_pipeline.py @@ -19,7 +19,12 @@ import PIL from parameterized import parameterized -from optimum.neuron import 
NeuronStableDiffusionPipeline, NeuronStableDiffusionXLPipeline +from optimum.neuron import ( + NeuronStableDiffusionImg2ImgPipeline, + NeuronStableDiffusionInpaintPipeline, + NeuronStableDiffusionPipeline, + NeuronStableDiffusionXLPipeline, +) from optimum.neuron.modeling_diffusion import ( NeuronModelTextEncoder, NeuronModelUnet, @@ -30,7 +35,7 @@ from optimum.utils import logging from optimum.utils.testing_utils import require_diffusers -from .inference_utils import MODEL_NAMES +from .inference_utils import MODEL_NAMES, download_image logger = logging.get_logger() @@ -85,6 +90,42 @@ def test_export_and_inference_dyn(self, model_arch): image = neuron_pipeline(prompts, num_images_per_prompt=2).images[0] self.assertIsInstance(image, PIL.Image.Image) + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + def test_img2img_export_and_inference(self, model_arch): + neuron_pipeline = NeuronStableDiffusionImg2ImgPipeline.from_pretrained( + MODEL_NAMES[model_arch], + export=True, + dynamic_batch_size=False, + **self.STATIC_INPUTS_SHAPES, + **self.COMPILER_ARGS, + device_ids=[0, 1], + ) + + url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + init_image = download_image(url) + prompt = "ghibli style, a fantasy landscape with mountain, trees and lake, reflection" + image = neuron_pipeline(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0] + self.assertIsInstance(image, PIL.Image.Image) + + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + def test_inpaint_export_and_inference(self, model_arch): + neuron_pipeline = NeuronStableDiffusionInpaintPipeline.from_pretrained( + MODEL_NAMES[model_arch], + export=True, + dynamic_batch_size=False, + **self.STATIC_INPUTS_SHAPES, + **self.COMPILER_ARGS, + device_ids=[0, 1], + ) + + img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + init_image = download_image(img_url).resize((512, 512)) + mask_image = download_image(mask_url).resize((512, 512)) + prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + image = neuron_pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0] + self.assertIsInstance(image, PIL.Image.Image) + @is_inferentia_test @requires_neuronx