diff --git a/.github/workflows/build_main_documentation.yml b/.github/workflows/build_main_documentation.yml index b6a752698f..e82043be98 100644 --- a/.github/workflows/build_main_documentation.yml +++ b/.github/workflows/build_main_documentation.yml @@ -44,6 +44,11 @@ jobs: repository: 'huggingface/optimum-intel' path: optimum-intel + - uses: actions/checkout@v2 + with: + repository: 'huggingface/optimum-furiosa' + path: optimum-furiosa + - name: Set environment variables run: | cd optimum @@ -76,6 +81,7 @@ jobs: - name: Make Habana documentation run: | + sudo docker system prune -a -f cd optimum-habana make doc BUILD_DIR=habana-doc-build VERSION=${{ env.VERSION }} sudo mv habana-doc-build ../optimum @@ -83,11 +89,33 @@ jobs: - name: Make Intel documentation run: | + sudo docker system prune -a -f cd optimum-intel make doc BUILD_DIR=intel-doc-build VERSION=${{ env.VERSION }} sudo mv intel-doc-build ../optimum cd .. + - name: Make Furiosa documentation + run: | + cd optimum-furiosa + pip install . + sudo apt update + sudo apt install -y ca-certificates apt-transport-https gnupg + sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-key 5F03AFA423A751913F249259814F888B20B09A7E + sudo tee -a /etc/apt/auth.conf.d/furiosa.conf > /dev/null < /dev/null < Dict[str, Dict[int, str]]: + common_outputs = super().outputs + + if self._behavior is ConfigBehavior.ENCODER: + for i in range(self._config.encoder_layers): + common_outputs[f"encoder_attentions.{i}"] = {0: "batch_size"} + elif self._behavior is ConfigBehavior.DECODER: + for i in range(self._config.decoder_layers): + common_outputs[f"decoder_attentions.{i}"] = { + 0: "batch_size", + 2: "decoder_sequence_length", + 3: "past_decoder_sequence_length + 1" + } + for i in range(self._config.decoder_layers): + common_outputs[f"cross_attentions.{i}"] = { + 0: "batch_size", + 2: "decoder_sequence_length", + 3: "encoder_sequence_length_out" + } + + return common_outputs + + @property + def torch_to_onnx_output_map(self): + if self._behavior is ConfigBehavior.ENCODER: + # The encoder export uses WhisperEncoder that returns the key "attentions" + return {"attentions": "encoder_attentions"} + else: + return {} + +model_id = "openai/whisper-tiny.en" +config = AutoConfig.from_pretrained(model_id) + +custom_whisper_onnx_config = CustomWhisperOnnxConfig( + config=config, + task="automatic-speech-recognition", +) + +encoder_config = custom_whisper_onnx_config.with_behavior("encoder") +decoder_config = custom_whisper_onnx_config.with_behavior("decoder", use_past=False) +decoder_with_past_config = custom_whisper_onnx_config.with_behavior("decoder", use_past=True) + +custom_onnx_configs={ + "encoder_model": encoder_config, + "decoder_model": decoder_config, + "decoder_with_past_model": decoder_with_past_config, +} + +main_export( + model_id, + output="custom_whisper_onnx", + no_post_process=True, + model_kwargs={"output_attentions": True}, + custom_onnx_configs=custom_onnx_configs +) +``` + +### Customize the export of Transformers models with custom modeling + +Optimum supports the export of Transformers models with custom modeling that use [`trust_remote_code=True`](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoModel.from_pretrained.trust_remote_code), not officially supported in the Transormers library but usable with its functionality as [pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) and [generation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.generate). + +Examples of such models are [THUDM/chatglm2-6b](https://huggingface.co/THUDM/chatglm2-6b) and [mosaicml/mpt-30b](https://huggingface.co/mosaicml/mpt-30b). + +To export custom models, a dictionary `custom_onnx_configs` needs to be passed to [`~optimum.exporters.onnx.main_export`], with the ONNX config definition for all the subparts of the model to export (for example, encoder and decoder subparts). The example below allows to export `mosaicml/mpt-7b` model: + +```python +from optimum.exporters.onnx import main_export + +from transformers import AutoConfig + +from optimum.exporters.onnx.config import TextDecoderOnnxConfig +from optimum.utils import NormalizedTextConfig, DummyPastKeyValuesGenerator +from typing import Dict + + +class MPTDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + """ + MPT swaps the two last dimensions for the key cache compared to usual transformers + decoder models, thus the redefinition here. + """ + def generate(self, input_name: str, framework: str = "pt"): + past_key_shape = ( + self.batch_size, + self.num_attention_heads, + self.hidden_size // self.num_attention_heads, + self.sequence_length, + ) + past_value_shape = ( + self.batch_size, + self.num_attention_heads, + self.sequence_length, + self.hidden_size // self.num_attention_heads, + ) + return [ + ( + self.random_float_tensor(past_key_shape, framework=framework), + self.random_float_tensor(past_value_shape, framework=framework), + ) + for _ in range(self.num_layers) + ] + +class CustomMPTOnnxConfig(TextDecoderOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = (MPTDummyPastKeyValuesGenerator,) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES + DUMMY_PKV_GENERATOR_CLASS = MPTDummyPastKeyValuesGenerator + + DEFAULT_ONNX_OPSET = 14 # aten::tril operator requires opset>=14 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( + hidden_size="d_model", + num_layers="n_layers", + num_attention_heads="n_heads" + ) + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + """ + Adapted from https://github.com/huggingface/optimum/blob/v1.9.0/optimum/exporters/onnx/base.py#L625 + """ + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_sequence_length + 1" + name = "present" + + for i in range(self._normalized_config.num_layers): + inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 3: decoder_sequence_name} + inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name} + + +model_id = "/home/fxmarty/hf_internship/optimum/tiny-mpt-random-remote-code" +config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) + +onnx_config = CustomMPTOnnxConfig( + config=config, + task="text-generation", + use_past_in_inputs=False, + use_present_in_outputs=True, +) +onnx_config_with_past = CustomMPTOnnxConfig(config, task="text-generation", use_past=True) + +custom_onnx_configs = { + "decoder_model": onnx_config, + "decoder_with_past_model": onnx_config_with_past, +} + +main_export( + model_id, + output="mpt_onnx", + task="text-generation-with-past", + trust_remote_code=True, + custom_onnx_configs=custom_onnx_configs, + no_post_process=True, +) +``` + +Moreover, the advanced argument `fn_get_submodels` to `main_export` allows to customize how the submodels are extracted in case the model needs to be exported in several submodels. Examples of such functions can be [consulted here](link to utils.py relevant code once merged). diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 0d4c9b1697..04f2af3aa9 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -22,25 +22,25 @@ As such, Optimum enables developers to efficiently use any of these platforms wi
Optimum Habana
+ >
Habana

Maximize training throughput and efficiency with Habana's Gaudi processor

Optimum Intel
+ >
Intel

Optimize your model to speedup inference with OpenVINO and Neural Compressor

-
Optimum Neuron
-

Coming soon!

+
AWS Trainium/Inferentia
+

Accelerate your training and inference workflows with AWS Trainium and AWS Inferentia

+ +
FuriosaAI
+

Fast and efficient inference on FuriosaAI WARBOY

ONNX Runtime

Apply quantization and graph optimization to accelerate Transformers models training and inference with ONNX Runtime

-
Exporters
-

Export your PyTorch or TensorFlow model to different formats such as ONNX and TFLite

-
BetterTransformer

A one-liner integration to use PyTorch's BetterTransformer with Transformers models

diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 895ead2566..1ed7fa609a 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -25,7 +25,6 @@ If you'd like to use the accelerator-specific features of πŸ€— Optimum, you can | [ONNX runtime](https://onnxruntime.ai/docs/) | `python -m pip install optimum[onnxruntime]` | | [Intel Neural Compressor (INC)](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) | `python -m pip install optimum[neural-compressor]`| | [Intel OpenVINO](https://docs.openvino.ai/latest/index.html) | `python -m pip install optimum[openvino,nncf]` | -| [Graphcore IPU](https://www.graphcore.ai/products/ipu) | `python -m pip install optimum[graphcore]` | | [Habana Gaudi Processor (HPU)](https://habana.ai/training/) | `python -m pip install optimum[habana]` | diff --git a/docs/source/onnxruntime/package_reference/modeling_ort.mdx b/docs/source/onnxruntime/package_reference/modeling_ort.mdx index 3a7c6e409f..ebbfa1736e 100644 --- a/docs/source/onnxruntime/package_reference/modeling_ort.mdx +++ b/docs/source/onnxruntime/package_reference/modeling_ort.mdx @@ -113,3 +113,20 @@ The following ORT classes are available for the following custom tasks. #### ORTStableDiffusionPipeline [[autodoc]] onnxruntime.ORTStableDiffusionPipeline + +#### ORTStableDiffusionImg2ImgPipeline + +[[autodoc]] onnxruntime.ORTStableDiffusionImg2ImgPipeline + +#### ORTStableDiffusionInpaintPipeline + +[[autodoc]] onnxruntime.ORTStableDiffusionInpaintPipeline + + +#### ORTStableDiffusionXLPipeline + +[[autodoc]] onnxruntime.ORTStableDiffusionXLPipeline + +#### ORTStableDiffusionXLImg2ImgPipeline + +[[autodoc]] onnxruntime.ORTStableDiffusionXLImg2ImgPipeline \ No newline at end of file diff --git a/docs/source/onnxruntime/usage_guides/models.mdx b/docs/source/onnxruntime/usage_guides/models.mdx index c06c06c7ac..634c88fc0b 100644 --- a/docs/source/onnxruntime/usage_guides/models.mdx +++ b/docs/source/onnxruntime/usage_guides/models.mdx @@ -64,7 +64,7 @@ It is also possible, just as with regular [`~transformers.PreTrainedModel`]s, to ... ) ``` -## Export and inference of sequence-to-sequence models +## Sequence-to-sequence models Sequence-to-sequence (Seq2Seq) models can also be used when running inference with ONNX Runtime. When Seq2Seq models are exported to the ONNX format, they are decomposed into three parts that are later combined during inference: @@ -92,40 +92,139 @@ Here is an example of how you can load a T5 model to the ONNX format and run inf >>> # [{'translation_text': "Il n'est jamais sorti sans un livre sous son bras, et il est souvent revenu avec deux."}] ``` -## Export and inference of Stable Diffusion models +## Stable Diffusion Stable Diffusion models can also be used when running inference with ONNX Runtime. When Stable Diffusion models -are exported to the ONNX format, they are split into three components that are later combined during inference: +are exported to the ONNX format, they are split into four components that are later combined during inference: - The text encoder - The U-NET +- The VAE encoder - The VAE decoder Make sure you have πŸ€— Diffusers installed. To install `diffusers`: -``` +```bash pip install diffusers ``` +### Text-to-Image + Here is an example of how you can load an ONNX Stable Diffusion model and run inference using ONNX Runtime: ```python from optimum.onnxruntime import ORTStableDiffusionPipeline model_id = "runwayml/stable-diffusion-v1-5" -stable_diffusion = ORTStableDiffusionPipeline.from_pretrained(model_id, revision="onnx") +pipeline = ORTStableDiffusionPipeline.from_pretrained(model_id, revision="onnx") prompt = "sailing ship in storm by Leonardo da Vinci" -image = stable_diffusion(prompt).images[0] +image = pipeline(prompt).images[0] ``` To load your PyTorch model and convert it to ONNX on-the-fly, you can set `export=True`. ```python -stable_diffusion = ORTStableDiffusionPipeline.from_pretrained(model_id, export=True) +pipeline = ORTStableDiffusionPipeline.from_pretrained(model_id, export=True) # Don't forget to save the ONNX model save_directory = "a_local_path" -stable_diffusion.save_pretrained(save_directory) +pipeline.save_pretrained(save_directory) ``` ![img](https://huggingface.co/datasets/optimum/documentation-images/resolve/main/onnxruntime/stable_diffusion_v1_5_ort_sail_boat.png) + + +### Image-to-Image + +```python +import requests +import torch +from PIL import Image +from io import BytesIO +from optimum.onnxruntime import ORTStableDiffusionImg2ImgPipeline + +model_id = "runwayml/stable-diffusion-v1-5" +pipeline = ORTStableDiffusionImg2ImgPipeline.from_pretrained(model_id, revision="onnx") + +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((768, 512)) + +prompt = "A fantasy landscape, trending on artstation" + +image = pipeline(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0] +image.save("fantasy_landscape.png") +``` + +### Inpaint + +```python +import PIL +import requests +import torch +from io import BytesIO +from optimum.onnxruntime import ORTStableDiffusionInpaintPipeline + +model_id = "runwayml/stable-diffusion-inpainting" +pipeline = ORTStableDiffusionInpaintPipeline.from_pretrained(model_id, revision="onnx") + +def download_image(url): + response = requests.get(url) + return PIL.Image.open(BytesIO(response.content)).convert("RGB") + +img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + +init_image = download_image(img_url).resize((512, 512)) +mask_image = download_image(mask_url).resize((512, 512)) + +prompt = "Face of a yellow cat, high resolution, sitting on a park bench" +image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0] +``` + + +## Stable Diffusion XL + +Before using `ORTStableDiffusionXLPipeline` make sure to have `diffusers` and `invisible_watermark` installed. You can install the libraries as follows: + +```bash +pip install diffusers +pip install invisible-watermark>=2.0 +``` + +### Text-to-Image + +Here is an example of how you can load a PyTorch SD XL model, convert it to ONNX on-the-fly and run inference using ONNX Runtime: + +```python +from optimum.onnxruntime import ORTStableDiffusionXLPipeline + +model_id = "stabilityai/stable-diffusion-xl-base-0.9" +pipeline = ORTStableDiffusionXLPipeline.from_pretrained(model_id, export=True) +prompt = "sailing ship in storm by Leonardo da Vinci" +image = pipeline(prompt).images[0] + +# Don't forget to save the ONNX model +save_directory = "a_local_path" +pipeline.save_pretrained(save_directory) + +``` + +### Image-to-Image + +The image can be refined by making use of a model like [stabilityai/stable-diffusion-xl-refiner-0.9](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9). In this case, you only have to output the latents from the base model. + + +```python +from optimum.onnxruntime import ORTStableDiffusionXLImg2ImgPipeline + +use_refiner = True +model_id = "stabilityai/stable-diffusion-xl-refiner-0.9" +refiner = ORTStableDiffusionXLImg2ImgPipeline.from_pretrained(model_id, export=True) + +image = pipeline(prompt=prompt, output_type="latent" if use_refiner else "pil").images[0] +image = refiner(prompt=prompt, image=image[None, :]).images[0] +image.save("sailing_ship.png") +``` diff --git a/docs/source/onnxruntime/usage_guides/quantization.mdx b/docs/source/onnxruntime/usage_guides/quantization.mdx index 27ba00184e..8ffe16f3d6 100644 --- a/docs/source/onnxruntime/usage_guides/quantization.mdx +++ b/docs/source/onnxruntime/usage_guides/quantization.mdx @@ -22,7 +22,7 @@ while the latter effectively handles quantization. -You can read the [conceptual guide on quantization](/concept_guides/quantization) to learn about quantization. It +You can read the [conceptual guide on quantization](../../concept_guides/quantization) to learn about quantization. It explains the main concepts that you will be using when performing quantization with the [`~optimum.onnxruntime.ORTQuantizer`]. @@ -63,7 +63,7 @@ Quantizing an ONNX model can be done as follows: optimum-cli onnxruntime quantize --onnx_model onnx_model_location/ --avx512 -o quantized_model/ ``` -This quantize all the ONNX files in `onnx_model_location` with the AVX-512 instructions. +This quantize all the ONNX files in `onnx_model_location` with the AVX-512 instructions. ## Creating an `ORTQuantizer` diff --git a/docs/source/onnxruntime/usage_guides/trainer.mdx b/docs/source/onnxruntime/usage_guides/trainer.mdx index 6b466b7257..50c6b4d77a 100644 --- a/docs/source/onnxruntime/usage_guides/trainer.mdx +++ b/docs/source/onnxruntime/usage_guides/trainer.mdx @@ -236,6 +236,47 @@ in the Optimum repository. +## ORTModule+StableDiffusion + +Optimum supports accelerating Hugging Face Diffusers with ONNX Runtime in [this example](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime/training/stable-diffusion/text-to-image). +The core changes required to enable ONNX Runtime Training are summarized below: + +```diff +import torch +from diffusers import AutoencoderKL, UNet2DConditionModel +from transformers import CLIPTextModel + ++from onnxruntime.training.ortmodule import ORTModule ++from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer + +unet = UNet2DConditionModel.from_pretrained( + "CompVis/stable-diffusion-v1-4", + subfolder="unet", + ... +) +text_encoder = CLIPTextModel.from_pretrained( + "CompVis/stable-diffusion-v1-4", + subfolder="text_encoder", + ... +) +vae = AutoencoderKL.from_pretrained( + "CompVis/stable-diffusion-v1-4", + subfolder="vae", + ... +) + +optimizer = torch.optim.AdamW( + unet.parameters(), + ... +) + ++vae = ORTModule(vae) ++text_encoder = ORTModule(text_encoder) ++unet = ORTModule(unet) + ++optimizer = ORT_FP16_Optimizer(optimizer) +``` + ## Other Resources * Blog posts diff --git a/docs/source/quicktour.mdx b/docs/source/quicktour.mdx index 507aee155e..11b0bb1c18 100644 --- a/docs/source/quicktour.mdx +++ b/docs/source/quicktour.mdx @@ -16,6 +16,30 @@ This quick tour is intended for developers who are ready to dive into the code a ## Accelerated inference +#### OpenVINO + +To load a model and run inference with OpenVINO Runtime, you can just replace your `AutoModelForXxx` class with the corresponding `OVModelForXxx` class. +If you want to load a PyTorch checkpoint, set `export=True` to convert your model to the OpenVINO IR (Intermediate Representation). + +```diff +- from transformers import AutoModelForSequenceClassification ++ from optimum.intel.openvino import OVModelForSequenceClassification + from transformers import AutoTokenizer, pipeline + + # Download a tokenizer and model from the Hub and convert to OpenVINO format + tokenizer = AutoTokenizer.from_pretrained(model_id) + model_id = "distilbert-base-uncased-finetuned-sst-2-english" +- model = AutoModelForSequenceClassification.from_pretrained(model_id) ++ model = OVModelForSequenceClassification.from_pretrained(model_id, export=True) + + # Run inference! + classifier = pipeline("text-classification", model=model, tokenizer=tokenizer) + results = classifier("He's a dreadful magician.") +``` + +You can find more examples in the [documentation](https://huggingface.co/docs/optimum/intel/inference) and in the [examples](https://github.com/huggingface/optimum-intel/tree/main/examples/openvino). + + #### ONNX Runtime To accelerate inference with ONNX Runtime, πŸ€— Optimum uses _configuration objects_ to define parameters for graph optimization and quantization. These objects are then used to instantiate dedicated _optimizers_ and _quantizers_. @@ -67,30 +91,6 @@ In this example, we've quantized a model from the Hugging Face Hub, in the same You can find more examples in the [documentation](https://huggingface.co/docs/optimum/onnxruntime/quickstart) and in the [examples](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime). -#### Intel - -To load a model and run inference with OpenVINO Runtime, you can just replace your `AutoModelForXxx` class with the corresponding `OVModelForXxx` class. -If you want to load a PyTorch checkpoint, set `export=True` to convert your model to the OpenVINO IR (Intermediate Representation). - -```diff -- from transformers import AutoModelForSequenceClassification -+ from optimum.intel.openvino import OVModelForSequenceClassification - from transformers import AutoTokenizer, pipeline - - # Download a tokenizer and model from the Hub and convert to OpenVINO format - tokenizer = AutoTokenizer.from_pretrained(model_id) - model_id = "distilbert-base-uncased-finetuned-sst-2-english" -- model = AutoModelForSequenceClassification.from_pretrained(model_id) -+ model = OVModelForSequenceClassification.from_pretrained(model_id, export=True) - - # Run inference! - classifier = pipeline("text-classification", model=model, tokenizer=tokenizer) - results = classifier("He's a dreadful magician.") -``` - -You can find more examples in the [documentation](https://huggingface.co/docs/optimum/intel/inference) and in the [examples](https://github.com/huggingface/optimum-intel/tree/main/examples/openvino). - - ## Accelerated training #### Habana @@ -130,45 +130,6 @@ To train transformers on Habana's Gaudi processors, πŸ€— Optimum provides a `Gau You can find more examples in the [documentation](https://huggingface.co/docs/optimum/habana/quickstart) and in the [examples](https://github.com/huggingface/optimum-habana/tree/main/examples). -#### Graphcore - -To train transformers on Graphcore's IPUs, πŸ€— Optimum provides a `IPUTrainer` that is very similar to the πŸ€— Transformers [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example: - -```diff -- from transformers import Trainer, TrainingArguments -+ from optimum.graphcore import IPUConfig, IPUTrainer, IPUTrainingArguments - - # Download a pretrained model from the Hub - model = AutoModelForXxx.from_pretrained("bert-base-uncased") - - # Define the training arguments -- training_args = TrainingArguments( -+ training_args = IPUTrainingArguments( - output_dir="path/to/save/folder/", -+ ipu_config_name="Graphcore/bert-base-ipu", # Any IPUConfig on the Hub or stored locally - ... - ) - - # Define the configuration to compile and put the model on the IPU -+ ipu_config = IPUConfig.from_pretrained(training_args.ipu_config_name) - - # Initialize the trainer -- trainer = Trainer( -+ trainer = IPUTrainer( - model=model, -+ ipu_config=ipu_config - args=training_args, - train_dataset=train_dataset - ... - ) - - # Use Graphcore IPU for training! - trainer.train() -``` - -You can find more examples in the [documentation](https://huggingface.co/docs/optimum/graphcore/quickstart) and in the [examples](https://github.com/huggingface/optimum-graphcore/tree/main/examples). - - #### ONNX Runtime To train transformers with ONNX Runtime's acceleration features, πŸ€— Optimum provides a `ORTTrainer` that is very similar to the πŸ€— Transformers [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer). Here is a simple example: diff --git a/examples/onnxruntime/training/language-modeling/run_clm.py b/examples/onnxruntime/training/language-modeling/run_clm.py index 219d163010..2807d3f721 100644 --- a/examples/onnxruntime/training/language-modeling/run_clm.py +++ b/examples/onnxruntime/training/language-modeling/run_clm.py @@ -435,10 +435,12 @@ def tokenize_function(examples): if data_args.block_size is None: block_size = tokenizer.model_max_length - logger.warning( - f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " - "You can change that default value by passing --block_size xxx." - ) + if block_size > 1024: + logger.warning( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --block_size xxx." + ) + block_size = 1024 else: if data_args.block_size > tokenizer.model_max_length: logger.warning( diff --git a/examples/onnxruntime/training/stable-diffusion/text-to-image/README.md b/examples/onnxruntime/training/stable-diffusion/text-to-image/README.md new file mode 100644 index 0000000000..283fc5ff11 --- /dev/null +++ b/examples/onnxruntime/training/stable-diffusion/text-to-image/README.md @@ -0,0 +1,134 @@ +# Stable Diffusion Text-to-Image Fine-Tuning + +This example shows how to leverage ONNX Runtime Training to fine-tune stable diffusion model on your own dataset. + +Our team has tested finetuning `CompVis/stable-diffusion-v1-4` model on the `lambdalabs/pokemon-blip-captions` dataset and achieved the following speedup: +![image](https://github.com/microsoft/onnxruntime-training-examples/assets/31260940/00f199b1-3a84-4369-924d-fd6c613bd3b4) + +___Note___: + +___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparamters to get the best result on your dataset.___ + + +## Running locally with PyTorch +### Installing the dependencies + +___Note___: This example requires PyTorch nightly and [ONNX Runtime](https://github.com/Microsoft/onnxruntime) nightly +```bash +pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu118 +pip install onnxruntime-training --pre -f https://download.onnxruntime.ai/onnxruntime_nightly_cu118.html +python -m onnxruntime.training.ortmodule.torch_cpp_extensions.install +``` +Or get your environment ready via Docker: [examples/onnxruntime/training/docker/Dockerfile-ort-nightly-cu118](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/training/docker/Dockerfile-ort-nightly-cu118) + +Then, cd in the example folder and run +```bash +pip install -r requirements.txt +``` + +And initialize an [πŸ€—Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +### Pokemon example + +You have to be a registered user in πŸ€— Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). + +Run the following command to authenticate your token + +```bash +huggingface-cli login +``` + +If you have already cloned the repo, then you won't need to go through these steps. + +
+ +#### Hardware +Cited performance metrics used an NIVIDIA V100, 8-GPU cluster. + +**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export dataset_name="lambdalabs/pokemon-blip-captions" + +accelerate launch --mixed_precision="fp16" train_text_to_image.py --ort \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$dataset_name \ + --use_ema \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir="sd-pokemon-model" +``` + + + +To run on your own training files prepare the dataset according to the format required by `datasets`, you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata). +If you wish to use custom loading logic, you should modify the script, we have left pointers for that in the training script. + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export TRAIN_DIR="path_to_your_dataset" + +accelerate launch --mixed_precision="fp16" train_text_to_image.py --ort \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$TRAIN_DIR \ + --use_ema \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir="sd-pokemon-model" +``` + + +Once the training is finished the model will be saved in the `output_dir` specified in the command. In this example it's `sd-pokemon-model`. To load the fine-tuned model for inference just pass that path to `ORTStableDiffusionPipeline` + + +```python +from optimum.onnxruntime import ORTStableDiffusionPipeline + +model_path = "path_to_saved_model" +pipe = ORTStableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16) +pipe.to("cuda") + +image = pipe(prompt="yoda").images[0] +image.save("yoda-pokemon.png") +``` + +#### Training with multiple GPUs + +`accelerate` allows for seamless multi-GPU training. Follow the instructions [here](https://huggingface.co/docs/accelerate/basic_tutorials/launch) +for running distributed training with `accelerate`. Here is an example command: + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export dataset_name="lambdalabs/pokemon-blip-captions" + +accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py --ort \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$dataset_name \ + --use_ema \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir="sd-pokemon-model" +``` \ No newline at end of file diff --git a/examples/onnxruntime/training/stable-diffusion/text-to-image/requirements.txt b/examples/onnxruntime/training/stable-diffusion/text-to-image/requirements.txt new file mode 100644 index 0000000000..58e576d8b6 --- /dev/null +++ b/examples/onnxruntime/training/stable-diffusion/text-to-image/requirements.txt @@ -0,0 +1,7 @@ +accelerate>=0.16.0 +transformers>=4.25.1 +datasets +git+https://github.com/huggingface/diffusers +ftfy +tensorboard +Jinja2 diff --git a/examples/onnxruntime/training/stable-diffusion/text-to-image/train_text_to_image.py b/examples/onnxruntime/training/stable-diffusion/text-to-image/train_text_to_image.py new file mode 100644 index 0000000000..dda30cd1b3 --- /dev/null +++ b/examples/onnxruntime/training/stable-diffusion/text-to-image/train_text_to_image.py @@ -0,0 +1,1009 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import logging +import math +import os +import random +from pathlib import Path + +import accelerate +import datasets +import diffusers +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available +from huggingface_hub import create_repo, upload_folder +from packaging import version +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +from transformers.utils import ContextManagers + + +if is_wandb_available(): + import wandb + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.17.1") + +logger = get_logger(__name__, log_level="INFO") + +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch): + logger.info("Running validation... ") + + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=accelerator.unwrap_model(vae), + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + unet=accelerator.unwrap_model(unet), + safety_checker=None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + images = [] + for i in range(len(args.validation_prompts)): + with torch.autocast("cuda"): + image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] + + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + elif tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}") + for i, image in enumerate(images) + ] + } + ) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + torch.cuda.empty_cache() + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that πŸ€— Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" + " for more docs" + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + parser.add_argument( + "--ort", + action="store_true", + default=False, + help=("Leverages ONNX Runtime Training to accelerate fine-tuning"), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def main(): + args = parse_args() + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." + ), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration( + total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir + ) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + ) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision + ) + + # Freeze vae and text_encoder + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + # Create EMA for the unet. + if args.use_ema: + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.ort: + from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer + + optimizer = ORT_FP16_Optimizer(optimizer) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + # Preprocessing the datasets. + train_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [train_transforms(image) for image in images] + examples["input_ids"] = tokenize_captions(examples) + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + input_ids = torch.stack([example["input_ids"] for example in examples]) + return {"pixel_values": pixel_values, "input_ids": input_ids} + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + if args.use_ema: + ema_unet.to(accelerator.device) + + if args.ort: + from onnxruntime.training.ortmodule import ORTModule + + vae = ORTModule(vae) + text_encoder = ORTModule(text_encoder) + unet = ORTModule(unet) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu and cast to weight_dtype + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + tracker_config.pop("validation_prompts") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + ) + if args.input_perturbation: + new_noise = noise + args.input_perturbation * torch.randn_like(noise) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.input_perturbation: + noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # Predict the noise residual and compute loss + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps) + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema_unet.step(unet.parameters()) + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + log_validation( + vae, + text_encoder, + tokenizer, + unet, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + if args.ort: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + checkpoint_dir = dirs[-1] if len(dirs) > 0 else None + + # reload PyTorch model with weights trained via ORT Training + # this is required since ORTModule object cannot be passed into StableDiffusionPipeline + root_dir = Path(__file__).resolve().parent + checkpoint_dir = root_dir / args.output_dir / checkpoint_dir + unet = UNet2DConditionModel.from_pretrained(str(checkpoint_dir), subfolder="unet") + + # reload pre-trained text_encoder and vae + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + else: + unet = accelerator.unwrap_model(unet) + if args.use_ema: + ema_unet.copy_to(unet.parameters()) + + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=text_encoder, + vae=vae, + unet=unet, + revision=args.revision, + torch_dtype=torch.float16, + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py index 4340f906e6..574636bd25 100644 --- a/optimum/bettertransformer/models/attention.py +++ b/optimum/bettertransformer/models/attention.py @@ -74,7 +74,8 @@ def gpt2_wrapped_scaled_dot_product( # torch.Tensor.expand does no memory copy causal_mask = causal_mask.expand(batch_size, -1, -1, -1) - attention_mask = causal_mask + attention_mask + if attention_mask is not None: + attention_mask = causal_mask + attention_mask sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False @@ -127,7 +128,8 @@ def gpt_neo_wrapped_scaled_dot_product( # torch.Tensor.expand does no memory copy causal_mask = causal_mask.expand(batch_size, -1, -1, -1) - attention_mask = causal_mask + attention_mask + if attention_mask is not None: + attention_mask = causal_mask + attention_mask sdpa_result = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False @@ -587,9 +589,6 @@ def llama_forward( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) - # This line is necessary for numerical equivalence, although I'm not sure it is useful in any way. - attention_mask = torch.max(attention_mask, torch.tensor(torch.finfo(attention_mask.dtype).min)) - attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) diff --git a/optimum/commands/__init__.py b/optimum/commands/__init__.py index 004d1a7263..540ea4dd86 100644 --- a/optimum/commands/__init__.py +++ b/optimum/commands/__init__.py @@ -15,5 +15,5 @@ from .base import BaseOptimumCLICommand, CommandInfo, RootOptimumCLICommand from .env import EnvironmentCommand from .export import ExportCommand, ONNXExportCommand, TFLiteExportCommand -from .onnxruntime import ONNXRuntimeCommand, ONNXRuntimmeOptimizeCommand, ONNXRuntimmeQuantizeCommand +from .onnxruntime import ONNXRuntimeCommand, ONNXRuntimeOptimizeCommand, ONNXRuntimeQuantizeCommand from .optimum_cli import register_optimum_cli_subcommand diff --git a/optimum/commands/export/onnx.py b/optimum/commands/export/onnx.py index f607d7c6c7..25922fa6d4 100644 --- a/optimum/commands/export/onnx.py +++ b/optimum/commands/export/onnx.py @@ -232,5 +232,6 @@ def run(self): trust_remote_code=self.args.trust_remote_code, pad_token_id=self.args.pad_token_id, for_ort=self.args.for_ort, + use_subprocess=True, **input_shapes, ) diff --git a/optimum/commands/onnxruntime/__init__.py b/optimum/commands/onnxruntime/__init__.py index b1112987b4..1b9c24c3b2 100644 --- a/optimum/commands/onnxruntime/__init__.py +++ b/optimum/commands/onnxruntime/__init__.py @@ -14,5 +14,5 @@ # limitations under the License. from .base import ONNXRuntimeCommand -from .optimize import ONNXRuntimmeOptimizeCommand -from .quantize import ONNXRuntimmeQuantizeCommand +from .optimize import ONNXRuntimeOptimizeCommand +from .quantize import ONNXRuntimeQuantizeCommand diff --git a/optimum/commands/onnxruntime/base.py b/optimum/commands/onnxruntime/base.py index 24c64300cc..53e3245ea4 100644 --- a/optimum/commands/onnxruntime/base.py +++ b/optimum/commands/onnxruntime/base.py @@ -15,8 +15,8 @@ """optimum.onnxruntime command-line interface base classes.""" from .. import BaseOptimumCLICommand, CommandInfo -from .optimize import ONNXRuntimmeOptimizeCommand -from .quantize import ONNXRuntimmeQuantizeCommand +from .optimize import ONNXRuntimeOptimizeCommand +from .quantize import ONNXRuntimeQuantizeCommand class ONNXRuntimeCommand(BaseOptimumCLICommand): @@ -28,11 +28,11 @@ class ONNXRuntimeCommand(BaseOptimumCLICommand): CommandInfo( name="optimize", help="Optimize ONNX models.", - subcommand_class=ONNXRuntimmeOptimizeCommand, + subcommand_class=ONNXRuntimeOptimizeCommand, ), CommandInfo( name="quantize", help="Dynammic quantization for ONNX models.", - subcommand_class=ONNXRuntimmeQuantizeCommand, + subcommand_class=ONNXRuntimeQuantizeCommand, ), ) diff --git a/optimum/commands/onnxruntime/optimize.py b/optimum/commands/onnxruntime/optimize.py index 68dac23198..5890e0a07c 100644 --- a/optimum/commands/onnxruntime/optimize.py +++ b/optimum/commands/onnxruntime/optimize.py @@ -69,7 +69,7 @@ def parse_args_onnxruntime_optimize(parser: "ArgumentParser"): ) -class ONNXRuntimmeOptimizeCommand(BaseOptimumCLICommand): +class ONNXRuntimeOptimizeCommand(BaseOptimumCLICommand): @staticmethod def parse_args(parser: "ArgumentParser"): return parse_args_onnxruntime_optimize(parser) diff --git a/optimum/commands/onnxruntime/quantize.py b/optimum/commands/onnxruntime/quantize.py index b82059d4b2..0ce7e6c3dc 100644 --- a/optimum/commands/onnxruntime/quantize.py +++ b/optimum/commands/onnxruntime/quantize.py @@ -63,7 +63,7 @@ def parse_args_onnxruntime_quantize(parser: "ArgumentParser"): ) -class ONNXRuntimmeQuantizeCommand(BaseOptimumCLICommand): +class ONNXRuntimeQuantizeCommand(BaseOptimumCLICommand): @staticmethod def parse_args(parser: "ArgumentParser"): return parse_args_onnxruntime_quantize(parser) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 69986a091c..6cefc7c571 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -15,6 +15,7 @@ """Entry point to the optimum.exporters.onnx command line.""" import argparse +import os from pathlib import Path from requests.exceptions import ConnectionError as RequestsConnectionError @@ -22,7 +23,7 @@ from transformers.utils import is_torch_available from ...commands.export.onnx import parse_args_onnx -from ...utils import DEFAULT_DUMMY_SHAPES, logging +from ...utils import DEFAULT_DUMMY_SHAPES, ONNX_WEIGHTS_NAME, logging from ...utils.save_utils import maybe_save_preprocessors from ..error_utils import AtolError, OutputMatchError, ShapeError from ..tasks import TasksManager @@ -30,6 +31,9 @@ from .constants import UNPICKABLE_ARCHS from .convert import export_models, validate_models_outputs from .utils import ( + _get_submodels_for_export_decoder, + _get_submodels_for_export_encoder_decoder, + _get_submodels_for_export_stable_diffusion, get_decoder_models_for_export, get_encoder_decoder_models_for_export, get_stable_diffusion_models_for_export, @@ -39,13 +43,92 @@ if is_torch_available(): import torch -from typing import Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union +if TYPE_CHECKING: + from transformers import PreTrainedModel, TFPreTrainedModel + + from .base import OnnxConfig + logger = logging.get_logger() logger.setLevel(logging.INFO) +def _get_submodels_and_onnx_configs( + model: Union["PreTrainedModel", "TFPreTrainedModel"], + task: str, + monolith: bool, + custom_onnx_configs: Dict, + custom_architecture: bool, + fn_get_submodels: Optional[Callable] = None, +): + is_stable_diffusion = "stable-diffusion" in task + if not custom_architecture: + if is_stable_diffusion: + onnx_config = None + models_and_onnx_configs = get_stable_diffusion_models_for_export(model) + else: + onnx_config_constructor = TasksManager.get_exporter_config_constructor( + model=model, exporter="onnx", task=task + ) + onnx_config = onnx_config_constructor(model.config) + + if ( + model.config.is_encoder_decoder + and task.startswith(TasksManager._ENCODER_DECODER_TASKS) + and not monolith + ): + models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config) + elif task.startswith("text-generation") and not monolith: + models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config) + else: + models_and_onnx_configs = {"model": (model, onnx_config)} + + # When specifying custom ONNX configs for supported transformers architectures, we do + # not force to specify a custom ONNX config for each submodel. + for key, custom_onnx_config in custom_onnx_configs.items(): + models_and_onnx_configs[key] = (models_and_onnx_configs[key][0], custom_onnx_config) + else: + onnx_config = None + submodels_for_export = None + models_and_onnx_configs = {} + + if fn_get_submodels is not None: + submodels_for_export = fn_get_submodels(model) + else: + if is_stable_diffusion: + submodels_for_export = _get_submodels_for_export_stable_diffusion(model) + elif ( + model.config.is_encoder_decoder + and task.startswith(TasksManager._ENCODER_DECODER_TASKS) + and not monolith + ): + submodels_for_export = _get_submodels_for_export_encoder_decoder( + model, use_past=task.endswith("-with-past") + ) + elif task.startswith("text-generation") and not monolith: + submodels_for_export = _get_submodels_for_export_decoder(model, use_past=task.endswith("-with-past")) + else: + submodels_for_export = {"model": model} + + if submodels_for_export.keys() != custom_onnx_configs.keys(): + logger.error(f"ONNX custom configs for: {', '.join(custom_onnx_configs.keys())}") + logger.error(f"Submodels to export: {', '.join(submodels_for_export.keys())}") + raise ValueError( + "Trying to export a custom model, but could not find as many custom ONNX configs as the number of submodels to export. Please specifiy the fn_get_submodels argument, that should return a dictionary of submodules with as many items as the provided custom_onnx_configs dictionary." + ) + + for key, custom_onnx_config in custom_onnx_configs.items(): + models_and_onnx_configs[key] = (submodels_for_export[key], custom_onnx_config) + + # Default to the first ONNX config for stable-diffusion and custom architecture case. + if onnx_config is None: + onnx_config = next(iter(models_and_onnx_configs.values()))[1] + + return onnx_config, models_and_onnx_configs + + def main_export( model_name_or_path: str, output: Union[str, Path], @@ -68,6 +151,10 @@ def main_export( use_auth_token: Optional[Union[bool, str]] = None, for_ort: bool = False, do_validation: bool = True, + model_kwargs: Optional[Dict[str, Any]] = None, + custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None, + fn_get_submodels: Optional[Callable] = None, + use_subprocess: bool = False, **kwargs_shapes, ): """ @@ -127,6 +214,21 @@ def main_export( use_auth_token (`Optional[str]`, defaults to `None`): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated when running `transformers-cli login` (stored in `~/.huggingface`). + model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): + Experimental usage: keyword arguments to pass to the model during + the export. This argument should be used along the `custom_onnx_configs` argument + in case, for example, the model inputs/outputs are changed (for example, if + `model_kwargs={"output_attentions": True}` is passed). + custom_onnx_configs (`Optional[Dict[str, OnnxConfig]]`, defaults to `None`): + Experimental usage: override the default ONNX config used for the given model. This argument may be useful for advanced users that desire a finer-grained control on the export. An example is available [here](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model). + fn_get_submodels (`Optional[Callable]`, defaults to `None`): + Experimental usage: Override the default submodels that are used at the export. This is + especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success. + use_subprocess (`bool`): + Do the ONNX exported model validation in subprocesses. This is especially useful when + exporting on CUDA device, where ORT does not release memory at inference session + destruction. When set to `True`, the `main_export` call should be guarded in + `if __name__ == "__main__":` block. **kwargs_shapes (`Dict`): Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export. @@ -137,6 +239,20 @@ def main_export( >>> main_export("gpt2", output="gpt2_onnx/") ``` """ + if optimize == "O4" and device != "cuda": + raise ValueError( + "Requested O4 optimization, but this optimization requires to do the export on GPU." + " Please pass the argument `--device cuda`." + ) + + if (framework == "tf" and fp16 is True) or not is_torch_available(): + raise ValueError("The --fp16 option is supported only for PyTorch.") + + if fp16 is True and device == "cpu": + raise ValueError( + "The --fp16 option is supported only when exporting on GPU. Please pass the option `--device cuda`." + ) + output = Path(output) if not output.exists(): output.mkdir(parents=True) @@ -152,14 +268,6 @@ def main_export( framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework) - if (framework == "tf" and fp16 is True) or not is_torch_available(): - raise ValueError("The --fp16 option is supported only for PyTorch.") - - if fp16 is True and device == "cpu": - raise ValueError( - "The --fp16 option is supported only when exporting on GPU. Please pass the option `--device cuda`." - ) - # get the shapes to be used to generate dummy inputs input_shapes = {} for input_name in DEFAULT_DUMMY_SHAPES.keys(): @@ -196,8 +304,36 @@ def main_export( device=device, ) - if task != "stable-diffusion" and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type( - model.config.model_type.replace("_", "-"), "onnx" + custom_architecture = False + is_stable_diffusion = "stable-diffusion" in task + model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-") + + if not is_stable_diffusion: + if model_type in TasksManager._UNSUPPORTED_CLI_MODEL_TYPE: + raise ValueError( + f"{model_type} is not supported yet. Only {TasksManager._SUPPORTED_CLI_MODEL_TYPE} are supported. " + f"If you want to support {model_type} please propose a PR or open up an issue." + ) + if model.config.model_type.replace("-", "_") not in TasksManager.get_supported_model_type_for_task( + task, exporter="onnx" + ): + custom_architecture = True + + # TODO: support onnx_config.py in the model repo + if custom_architecture and custom_onnx_configs is None: + raise ValueError( + f"Trying to export a {model.config.model_type.replace('-', '_')} model, that is a custom or unsupported architecture for the task {task}, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. For the task {task}, the Optimum ONNX exporter supports natively the architectures: {TasksManager.get_supported_model_type_for_task(task, exporter='onnx')}." + ) + + if custom_architecture and original_task == "auto": + raise ValueError( + f'Automatic task detection is not supported with custom architectures. Please specify the `task` argument. Suggestion: task="{task}" (or task="{task}-with-past" if the model is decoder-based and supports KV cache)' + ) + + if ( + not custom_architecture + and not is_stable_diffusion + and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx") ): if original_task == "auto": # Make -with-past the default if --task was not explicitely specified task = task + "-with-past" @@ -223,10 +359,16 @@ def main_export( possible_synonyms = "" logger.info(f"Automatic task detection to {task}{possible_synonyms}.") - if task != "stable-diffusion": - onnx_config_constructor = TasksManager.get_exporter_config_constructor(model=model, exporter="onnx", task=task) - onnx_config = onnx_config_constructor(model.config) + onnx_config, models_and_onnx_configs = _get_submodels_and_onnx_configs( + model=model, + task=task, + monolith=monolith, + custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {}, + custom_architecture=custom_architecture, + fn_get_submodels=fn_get_submodels, + ) + if not is_stable_diffusion: needs_pad_token_id = ( isinstance(onnx_config, OnnxConfigWithPast) and getattr(model.config, "pad_token_id", None) is None @@ -250,8 +392,8 @@ def main_export( if opset < onnx_config.DEFAULT_ONNX_OPSET: raise ValueError( - f"Opset {opset} is not sufficient to export {model.config.model_type}. " - f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required." + f"Opset {opset} is not sufficient to export {model_type}. " + f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required." ) if atol is None: atol = onnx_config.ATOL_FOR_VALIDATION @@ -265,21 +407,6 @@ def main_export( generation_config.save_pretrained(output) maybe_save_preprocessors(model_name_or_path, output) - if task == "stable-diffusion": - onnx_files_subpaths = [ - "text_encoder/model.onnx", - "unet/model.onnx", - "vae_encoder/model.onnx", - "vae_decoder/model.onnx", - ] - models_and_onnx_configs = get_stable_diffusion_models_for_export(model) - # Saving the additional components needed to perform inference. - model.tokenizer.save_pretrained(output.joinpath("tokenizer")) - model.scheduler.save_pretrained(output.joinpath("scheduler")) - if model.feature_extractor is not None: - model.feature_extractor.save_pretrained(output.joinpath("feature_extractor")) - model.save_config(output) - else: if model.config.is_encoder_decoder and task.startswith("text-generation"): raise ValueError( f"model.config.is_encoder_decoder is True and task is `{task}`, which are incompatible. If the task was auto-inferred, please fill a bug report" @@ -288,25 +415,33 @@ def main_export( ) onnx_files_subpaths = None - if ( - model.config.is_encoder_decoder - and task.startswith( - ( - "text2text-generation", - "automatic-speech-recognition", - "image-to-text", - "feature-extraction-with-past", - "visual-question-answering", - "document-question-answering", - ) - ) - and not monolith - ): - models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config) - elif task.startswith("text-generation") and not monolith: - models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config) - else: - models_and_onnx_configs = {"model": (model, onnx_config)} + else: + # save the subcomponent configuration + for model_name in models_and_onnx_configs: + subcomponent = models_and_onnx_configs[model_name][0] + if hasattr(subcomponent, "save_config"): + subcomponent.save_config(output / model_name) + elif hasattr(subcomponent, "config") and hasattr(subcomponent.config, "save_pretrained"): + subcomponent.config.save_pretrained(output / model_name) + + onnx_files_subpaths = [os.path.join(name_dir, ONNX_WEIGHTS_NAME) for name_dir in models_and_onnx_configs] + + # Saving the additional components needed to perform inference. + model.scheduler.save_pretrained(output.joinpath("scheduler")) + + feature_extractor = getattr(model, "feature_extractor", None) + if feature_extractor is not None: + feature_extractor.save_pretrained(output.joinpath("feature_extractor")) + + tokenizer = getattr(model, "tokenizer", None) + if tokenizer is not None: + tokenizer.save_pretrained(output.joinpath("tokenizer")) + + tokenizer_2 = getattr(model, "tokenizer_2", None) + if tokenizer_2 is not None: + tokenizer_2.save_pretrained(output.joinpath("tokenizer_2")) + + model.save_config(output) _, onnx_outputs = export_models( models_and_onnx_configs=models_and_onnx_configs, @@ -316,14 +451,9 @@ def main_export( input_shapes=input_shapes, device=device, dtype="fp16" if fp16 is True else None, + model_kwargs=model_kwargs, ) - if optimize == "O4" and device != "cuda": - raise ValueError( - "Requested O4 optimization, but this optimization requires to do the export on GPU." - " Please pass the argument `--device cuda`." - ) - if optimize is not None: from ...onnxruntime import AutoOptimizationConfig, ORTOptimizer @@ -338,7 +468,7 @@ def main_export( # Optionally post process the obtained ONNX file(s), for example to merge the decoder / decoder with past if any # TODO: treating stable diffusion separately is quite ugly - if not no_post_process and task != "stable-diffusion": + if not no_post_process and not is_stable_diffusion: try: logger.info("Post-processing the exported models...") models_and_onnx_configs, onnx_files_subpaths = onnx_config.post_process_exported_models( @@ -349,8 +479,7 @@ def main_export( f"The post-processing of the ONNX export failed. The export can still be performed by passing the option --no-post-process. Detailed error: {e}" ) - use_subprocess = True - if task == "stable-diffusion": + if is_stable_diffusion: use_subprocess = ( False # TODO: fix Can't pickle local object 'get_stable_diffusion_models_for_export..' ) @@ -371,6 +500,7 @@ def main_export( device=device, dtype=torch_dtype, use_subprocess=use_subprocess, + model_kwargs=model_kwargs, ) logger.info(f"The ONNX export succeeded and the exported model was saved at: {output.as_posix()}") except ShapeError as e: diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 6b2b8bdb2a..03bdd20f2e 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -133,14 +133,7 @@ class OnnxConfig(ExportConfig, ABC): "feature-extraction": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}), "fill-mask": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "image-classification": OrderedDict({"logits": {0: "batch_size"}}), - # TODO: Is this the same thing as semantic-segmentation? - "image-segmentation": OrderedDict( - { - "logits": {0: "batch_size", 1: "num_queries"}, - "pred_boxes": {0: "batch_size", 1: "num_queries"}, - "pred_masks": {0: "batch_size", 1: "num_queries"}, - } - ), + "image-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}), "image-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "mask-generation": OrderedDict({"logits": {0: "batch_size"}}), "masked-im": OrderedDict( @@ -305,8 +298,10 @@ def fix_dynamic_axes( del onnx_model gc.collect() - def patch_model_for_export(self, model: Union["PreTrainedModel", "TFPreTrainedModel"]) -> ModelPatcher: - return ModelPatcher(self, model) + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> ModelPatcher: + return ModelPatcher(self, model, model_kwargs=model_kwargs) @property def values_override(self) -> Optional[Dict[str, Any]]: @@ -435,7 +430,10 @@ def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> `Dict[str, Any]`: Outputs with flattened structure and key mapping this new structure. """ - return {f"{name}.{idx}": item for idx, item in enumerate(itertools.chain.from_iterable(field))} + if isinstance(field[0], (list, tuple)): + return {f"{name}.{idx}": item for idx, item in enumerate(itertools.chain.from_iterable(field))} + else: + return {f"{name}.{idx}": item for idx, item in enumerate(field)} def generate_dummy_inputs_for_validation( self, reference_model_inputs: Dict[str, Any], onnx_input_names: Optional[List[str]] = None @@ -807,8 +805,10 @@ def flatten_past_key_values(self, flattened_output, name, idx, t): flattened_output[f"{name}.{idx}.encoder.key"] = t[2] flattened_output[f"{name}.{idx}.encoder.value"] = t[3] - def patch_model_for_export(self, model: Union["PreTrainedModel", "TFPreTrainedModel"]) -> ModelPatcher: - return Seq2SeqModelPatcher(self, model) + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> ModelPatcher: + return Seq2SeqModelPatcher(self, model, model_kwargs=model_kwargs) def post_process_exported_models( self, diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index a0be629cdf..cad2fdcb0f 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -14,13 +14,15 @@ # limitations under the License. """ONNX model check and export functions.""" +import copy +import gc import multiprocessing as mp import os import traceback from inspect import signature from itertools import chain from pathlib import Path -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import onnx @@ -98,6 +100,7 @@ def validate_models_outputs( device: str = "cpu", dtype: Optional["torch.dtype"] = None, use_subprocess: Optional[bool] = True, + model_kwargs: Optional[Dict[str, Any]] = None, ): """ Validates the export of several models, by checking that the outputs from both the reference and the exported model match. @@ -123,7 +126,9 @@ def validate_models_outputs( Data type of the inputs to perform validation on. Validation on float16 is supported only for PyTorch. use_subprocess (`Optional[bool]`, defaults to `True`): Launch validation of each exported model in a subprocess. - + model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): + Experimental usage: keyword arguments to pass to the model during + the export and validation. Raises: ValueError: If the outputs shapes or values do not match between the reference and the exported model. """ @@ -163,6 +168,7 @@ def validate_models_outputs( device=device, dtype=dtype, use_subprocess=use_subprocess, + model_kwargs=model_kwargs, ) except Exception as e: exceptions.append(e) @@ -183,6 +189,7 @@ def validate_model_outputs( device: str = "cpu", dtype: Optional["torch.dtype"] = None, use_subprocess: Optional[bool] = True, + model_kwargs: Optional[Dict[str, Any]] = None, ): """ Validates the export by checking that the outputs from both the reference and the exported model match. @@ -203,11 +210,24 @@ def validate_model_outputs( The device on which the ONNX model will be validated. Either `cpu` or `cuda`. Validation on a CUDA device is supported only for PyTorch. use_subprocess (`Optional[bool]`, defaults to `True`): Launch validation of each exported model in a subprocess. + model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): + Experimental usage: keyword arguments to pass to the model during + the export and validation. Raises: ValueError: If the outputs shapes or values do not match between the reference and the exported model. """ if use_subprocess: io_process = ValidationProcess( + config, reference_model, onnx_model, onnx_named_outputs, atol, input_shapes, device, dtype, model_kwargs + ) + io_process.start() + io_process.join() + + if io_process.exception: + error, traceback = io_process.exception + raise error + else: + _run_validation( config, reference_model, onnx_model, @@ -216,15 +236,8 @@ def validate_model_outputs( input_shapes, device, dtype, + model_kwargs=model_kwargs, ) - io_process.start() - io_process.join() - - if io_process.exception: - error, traceback = io_process.exception - raise error - else: - _run_validation(config, reference_model, onnx_model, onnx_named_outputs, atol, input_shapes, device, dtype) def _run_validation( @@ -236,9 +249,12 @@ def _run_validation( input_shapes: Optional[Dict] = None, device: str = "cpu", dtype: Optional["torch.dtype"] = None, + model_kwargs: Optional[Dict[str, Any]] = None, ): from onnxruntime import GraphOptimizationLevel, SessionOptions + model_kwargs = model_kwargs if model_kwargs is not None else {} + logger.info(f"Validating ONNX model {onnx_model.as_posix()}...") if atol is None: @@ -307,11 +323,14 @@ def _run_validation( value=reference_model_inputs[key], dtype=dtype, start_dtype=torch.float32 ) + # Some models may modify in place the inputs, hence the copy. + copy_reference_model_inputs = copy.deepcopy(reference_model_inputs) + if is_torch_available() and isinstance(reference_model, nn.Module): with torch.inference_mode(): - ref_outputs = reference_model(**reference_model_inputs) + ref_outputs = reference_model(**copy_reference_model_inputs, **model_kwargs) else: - ref_outputs = reference_model(**reference_model_inputs) + ref_outputs = reference_model(**copy_reference_model_inputs, **model_kwargs) ref_outputs_dict = {} # We flatten potential collection of outputs (i.e. past_keys) to a flat structure @@ -321,7 +340,8 @@ def _run_validation( if name == "past_key_values": name = "present" if isinstance(value, (list, tuple)): - value = config.flatten_output_collection_property(name, value) + onnx_output_name = config.torch_to_onnx_output_map.get(name, name) + value = config.flatten_output_collection_property(onnx_output_name, value) ref_outputs_dict.update(value) else: ref_outputs_dict[name] = value @@ -349,6 +369,8 @@ def _run_validation( if isinstance(value, (list, tuple)): value = config.flatten_output_collection_property(name, value) onnx_inputs.update({tensor_name: pt_tensor.cpu().numpy() for tensor_name, pt_tensor in value.items()}) + elif isinstance(value, dict): + onnx_inputs.update({tensor_name: pt_tensor.cpu().numpy() for tensor_name, pt_tensor in value.items()}) else: onnx_inputs[name] = value.cpu().numpy() @@ -365,7 +387,7 @@ def _run_validation( raise OutputMatchError( "ONNX model output names do not match reference model output names.\n" f"Reference model output names: {ref_outputs_set}\n" - f"ONNX model output names: {onnx_outputs_set}" + f"ONNX model output names: {onnx_outputs_set}\n" f"Difference: {onnx_outputs_set.difference(ref_outputs_set)}" ) else: @@ -427,6 +449,7 @@ def __init__( input_shapes: Optional[Dict] = None, device: str = "cpu", dtype: Optional["torch.dtype"] = None, + model_kwargs: Optional[Dict[str, Any]] = None, ): super().__init__() self._pconn, self._cconn = mp.Pipe() @@ -439,6 +462,7 @@ def __init__( self.input_shapes = input_shapes self.device = device self.dtype = dtype + self.model_kwargs = model_kwargs def run(self): try: @@ -451,6 +475,7 @@ def run(self): input_shapes=self.input_shapes, device=self.device, dtype=self.dtype, + model_kwargs=self.model_kwargs, ) except Exception as e: tb = traceback.format_exc() @@ -472,6 +497,7 @@ def export_pytorch( device: str = "cpu", dtype: Optional["torch.dtype"] = None, input_shapes: Optional[Dict] = None, + model_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[List[str], List[str]]: """ Exports a PyTorch model to an ONNX Intermediate Representation. @@ -492,6 +518,11 @@ def export_pytorch( Data type to remap the model inputs to. PyTorch-only. Only `torch.float16` is supported. input_shapes (`Optional[Dict]`, defaults to `None`): If specified, allows to use specific shapes for the example input provided to the ONNX exporter. + model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): + Experimental usage: keyword arguments to pass to the model during + the export. This argument should be used along the `custom_onnx_config` argument + in case, for example, the model inputs/outputs are changed (for example, if + `model_kwargs={"output_attentions": True}` is passed). Returns: `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from @@ -543,7 +574,7 @@ def remap(value): if is_torch_less_than_1_11: raise RuntimeError("The ONNX export using the PyTorch framework is only supported for v1.11+") else: - with config.patch_model_for_export(model): + with config.patch_model_for_export(model, model_kwargs=model_kwargs): # Export can work with named args but the dict containing named args has to be the last element of the args # tuple. onnx_export( @@ -569,6 +600,9 @@ def remap(value): # try free model memory del model del onnx_model + gc.collect() + if device.type == "cuda" and torch.cuda.is_available(): + torch.cuda.empty_cache() onnx_model = onnx.load( str(output), load_external_data=True @@ -673,6 +707,7 @@ def export_models( input_shapes: Optional[Dict] = None, disable_dynamic_axes_fix: Optional[bool] = False, dtype: Optional[str] = None, + model_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[List[List[str]], List[List[str]]]: """ Exports a Pytorch or TensorFlow encoder decoder model to an ONNX Intermediate Representation. @@ -698,6 +733,11 @@ def export_models( Whether to disable the default dynamic axes fixing. dtype (`Optional[str]`, defaults to `None`): Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported. + model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): + Experimental usage: keyword arguments to pass to the model during + the export. This argument should be used along the `custom_onnx_config` argument + in case, for example, the model inputs/outputs are changed (for example, if + `model_kwargs={"output_attentions": True}` is passed). Returns: `Tuple[List[List[str]], List[List[str]]]`: A tuple with an ordered list of the model's inputs, and the named outputs from the ONNX configuration. @@ -726,6 +766,7 @@ def export_models( input_shapes=input_shapes, disable_dynamic_axes_fix=disable_dynamic_axes_fix, dtype=dtype, + model_kwargs=model_kwargs, ) ) @@ -742,6 +783,7 @@ def export( input_shapes: Optional[Dict] = None, disable_dynamic_axes_fix: Optional[bool] = False, dtype: Optional[str] = None, + model_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[List[str], List[str]]: """ Exports a Pytorch or TensorFlow model to an ONNX Intermediate Representation. @@ -764,6 +806,11 @@ def export( Whether to disable the default dynamic axes fixing. dtype (`Optional[str]`, defaults to `None`): Data type to remap the model inputs to. PyTorch-only. Only `fp16` is supported. + model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): + Experimental usage: keyword arguments to pass to the model during + the export. This argument should be used along the `custom_onnx_config` argument + in case, for example, the model inputs/outputs are changed (for example, if + `model_kwargs={"output_attentions": True}` is passed). Returns: `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named outputs from @@ -814,10 +861,21 @@ def export( raise ValueError("Unsupported dtype, supported dtypes are: `torch.float16`.") export_output = export_pytorch( - model, config, opset, output, device=device, input_shapes=input_shapes, dtype=torch_dtype + model, + config, + opset, + output, + device=device, + input_shapes=input_shapes, + dtype=torch_dtype, + model_kwargs=model_kwargs, ) elif is_tf_available() and issubclass(type(model), TFPreTrainedModel): + if model_kwargs is not None: + raise NotImplementedError( + "The argument `model_kwargs` is used only for PyTorch ONNX export, and unavailable for the Tensorflow export." + ) if device == "cuda": raise RuntimeError("`tf2onnx` does not support export on CUDA device.") if input_shapes is not None: diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 2b9b49a039..775c2a58d8 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -627,7 +627,7 @@ class ConvNextOnnxConfig(ViTOnnxConfig): class MobileViTOnnxConfig(ViTOnnxConfig): - pass + ATOL_FOR_VALIDATION = 1e-4 class RegNetOnnxConfig(ViTOnnxConfig): @@ -643,9 +643,14 @@ class DetrOnnxConfig(ViTOnnxConfig): DEFAULT_ONNX_OPSET = 12 @property - def inputs(self) -> Dict[str, Dict[int, str]]: - # TODO: is pixel mask needed? - return {**super().inputs, "pixel_mask": {0: "batch_size"}} + def outputs(self) -> Dict[str, Dict[int, str]]: + if self.task == "image-segmentation": + return { + "logits": {0: "batch_size", 1: "num_queries"}, + "pred_masks": {0: "batch_size", 1: "num_queries"}, + } + else: + return super().outputs class YolosOnnxConfig(ViTOnnxConfig): @@ -708,7 +713,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]: } -class CLIPTextOnnxConfig(TextEncoderOnnxConfig): +class CLIPTextWithProjectionOnnxConfig(TextEncoderOnnxConfig): ATOL_FOR_VALIDATION = 1e-3 # The ONNX export of this architecture needs the Trilu operator support, available since opset 14 DEFAULT_ONNX_OPSET = 14 @@ -716,6 +721,7 @@ class CLIPTextOnnxConfig(TextEncoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( vocab_size="vocab_size", sequence_length="max_position_embeddings", + num_layers="num_hidden_layers", allow_new=True, ) @@ -727,13 +733,33 @@ def inputs(self) -> Dict[str, Dict[int, str]]: @property def outputs(self) -> Dict[str, Dict[int, str]]: - return { + common_outputs = { + "text_embeds": {0: "batch_size", 1: "sequence_length"}, + "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, + } + if self._normalized_config.output_hidden_states: + for i in range(self._normalized_config.num_layers + 1): + common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"} + + return common_outputs + + +class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig): + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + common_outputs = { "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, "pooler_output": {0: "batch_size"}, } + if self._normalized_config.output_hidden_states: + for i in range(self._normalized_config.num_layers + 1): + common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"} + + return common_outputs def generate_dummy_inputs(self, framework: str = "pt", **kwargs): dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) + if framework == "pt": import torch @@ -763,12 +789,19 @@ class UNetOnnxConfig(VisionOnnxConfig): @property def inputs(self) -> Dict[str, Dict[int, str]]: - return { + common_inputs = { "sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}, "timestep": {0: "steps"}, "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, } + # TODO : add text_image, image and image_embeds + if getattr(self._normalized_config, "addition_embed_type", None) == "text_time": + common_inputs["text_embeds"] = {0: "batch_size"} + common_inputs["time_ids"] = {0: "batch_size"} + + return common_inputs + @property def outputs(self) -> Dict[str, Dict[int, str]]: return { @@ -784,8 +817,25 @@ def torch_to_onnx_output_map(self) -> Dict[str, str]: def generate_dummy_inputs(self, framework: str = "pt", **kwargs): dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0] + + if getattr(self._normalized_config, "addition_embed_type", None) == "text_time": + dummy_inputs["added_cond_kwargs"] = { + "text_embeds": dummy_inputs.pop("text_embeds"), + "time_ids": dummy_inputs.pop("time_ids"), + } + return dummy_inputs + def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]: + inputs = super().ordered_inputs(model=model) + # to fix mismatch between model forward signature and expected inputs + # a dictionnary of additional embeddings `added_cond_kwargs` is expected depending on config.addition_embed_type + if getattr(self._normalized_config, "addition_embed_type", None) == "text_time": + inputs["text_embeds"] = self.inputs["text_embeds"] + inputs["time_ids"] = self.inputs["time_ids"] + + return inputs + class VaeEncoderOnnxConfig(VisionOnnxConfig): ATOL_FOR_VALIDATION = 1e-2 @@ -1011,8 +1061,10 @@ class WavLMOnnxConfig(HubertOnnxConfig): # we need to set output_attentions=True in the model input to avoid calling # torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export # due to the op torch.nn.functional.multi_head_attention_forward used for WavLM - def patch_model_for_export(self, model: Union["PreTrainedModel", "TFPreTrainedModel"]) -> "ModelPatcher": - return WavLMModelPatcher(self, model) + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return WavLMModelPatcher(self, model, model_kwargs=model_kwargs) class ASTDummyAudioInputGenerator(DummyAudioInputGenerator): diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index d13a910f49..7b8f1a238b 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -15,7 +15,7 @@ import dataclasses import functools import inspect -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union from ...utils import logging @@ -28,6 +28,20 @@ logger = logging.get_logger(__name__) +def overwride_arguments(args, kwargs, forward_signature, model_kwargs): + args = list(args) + + for argument in model_kwargs: + if argument in forward_signature.parameters: + argument_index = list(forward_signature.parameters.keys()).index(argument) + + args[argument_index] = model_kwargs[argument] + else: + kwargs[argument] = model_kwargs[argument] + + return args, kwargs + + @dataclasses.dataclass class PatchingSpec: """ @@ -50,7 +64,12 @@ class PatchingSpec: class ModelPatcher: - def __init__(self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"]): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): self._model = model patching_specs = config.PATCHING_SPECS @@ -64,6 +83,8 @@ def __init__(self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreT self.orig_forward_name = "forward" if hasattr(self._model, "forward") else "call" self.orig_forward = getattr(self._model, self.orig_forward_name) + self.model_kwargs = model_kwargs if model_kwargs is not None else {} + # TODO: remove that once we got rid of OnnxConfigWithLoss or we implemented it better. if config.__class__.__name__ == "OnnxConfigWithLoss": self.real_config = config._onnx_config @@ -75,14 +96,20 @@ def __init__(self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreT @functools.wraps(self.orig_forward) def patched_forward(*args, **kwargs): + signature = inspect.signature(self.orig_forward) + args, kwargs = overwride_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs) + outputs = self.orig_forward(*args, **kwargs) filterd_outputs = {} - for k, v in outputs.items(): - if config.torch_to_onnx_output_map.get(k, k) in config.outputs or ( - allow_past_in_outputs and k.startswith("past_key_values") + for name, value in outputs.items(): + onnx_output_name = config.torch_to_onnx_output_map.get(name, name) + if ( + onnx_output_name in config.outputs + or (allow_past_in_outputs and name.startswith("past_key_values")) + or any(key.startswith(onnx_output_name) for key in config.outputs.keys()) ): - filterd_outputs[k] = v + filterd_outputs[name] = value return filterd_outputs self.patched_forward = patched_forward @@ -112,8 +139,13 @@ def __call__(self, *args, **kwargs): class Seq2SeqModelPatcher(ModelPatcher): - def __init__(self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"]): - super().__init__(config, model) + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) allow_past_in_outputs = ( hasattr(self.real_config, "use_present_in_outputs") and self.real_config.use_present_in_outputs @@ -126,13 +158,19 @@ def __init__(self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreT @functools.wraps(self.orig_forward) def patched_forward(*args, **kwargs): + signature = inspect.signature(self.orig_forward) + args, kwargs = overwride_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs) + outputs = self.orig_forward(*args, **kwargs) # Filter out cross attention past key values filterd_outputs = {} for name, value in outputs.items(): - if config.torch_to_onnx_output_map.get(name, name) in config.outputs or ( - allow_past_in_outputs and name.startswith("past_key_values") + onnx_output_name = config.torch_to_onnx_output_map.get(name, name) + if ( + onnx_output_name in config.outputs + or (allow_past_in_outputs and name.startswith("past_key_values")) + or any(key.startswith(onnx_output_name) for key in config.outputs.keys()) ): if name != "past_key_values": if self.real_config._behavior == "decoder" and name == "encoder_last_hidden_state": @@ -147,14 +185,20 @@ def patched_forward(*args, **kwargs): filterd_outputs[name] = value elif self.real_config._behavior == "decoder" and self.real_config.use_past is True: filterd_outputs[name] = tuple([v[:2] for v in value]) + return filterd_outputs self.patched_forward = patched_forward class WavLMModelPatcher(ModelPatcher): - def __init__(self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"]): - super().__init__(config, model) + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) allow_past_in_outputs = ( hasattr(self.real_config, "use_present_in_outputs") and self.real_config.use_present_in_outputs @@ -162,23 +206,25 @@ def __init__(self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreT @functools.wraps(self.orig_forward) def patched_forward(*args, **kwargs): - args = list(args) - - signature = inspect.signature(self.orig_forward) - output_attentions_index = list(signature.parameters.keys()).index("output_attentions") - + model_kwargs = self.model_kwargs # setting output_attentions=True in the model input to avoid calling torch.nn.functional.scaled_dot_product_attention # in https://github.com/huggingface/transformers/blob/v4.27.1/src/transformers/models/wavlm/modeling_wavlm.py#L496 # that calls https://github.com/pytorch/pytorch/blob/v2.0.0/torch/nn/functional.py#L5334 - args[output_attentions_index] = True + model_kwargs["output_attentions"] = True + signature = inspect.signature(self.orig_forward) + args, kwargs = overwride_arguments(args, kwargs, signature, model_kwargs=model_kwargs) + outputs = self.orig_forward(*args, **kwargs) filterd_outputs = {} - for k, v in outputs.items(): - if config.torch_to_onnx_output_map.get(k, k) in config.outputs or ( - allow_past_in_outputs and k.startswith("past_key_values") + for name, value in outputs.items(): + onnx_output_name = config.torch_to_onnx_output_map.get(name, name) + if ( + onnx_output_name in config.outputs + or (allow_past_in_outputs and name.startswith("past_key_values")) + or any(key.startswith(onnx_output_name) for key in config.outputs.keys()) ): - filterd_outputs[k] = v + filterd_outputs[name] = value return filterd_outputs self.patched_forward = patched_forward diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 85d7cb1d03..24a809a977 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -17,8 +17,8 @@ import copy from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union -import packaging import torch +from packaging import version from transformers.utils import is_tf_available, is_torch_available from ...utils import ( @@ -65,7 +65,7 @@ from diffusers import ModelMixin, StableDiffusionPipeline -def check_onnxruntime_requirements(minimum_version: packaging.version.Version): +def check_onnxruntime_requirements(minimum_version: version.Version): """ Checks that ONNX Runtime is installed and if version is recent enough. @@ -85,7 +85,7 @@ def check_onnxruntime_requirements(minimum_version: packaging.version.Version): " and relaunch the conversion." ) - ort_version = packaging.version.parse(onnxruntime.__version__) + ort_version = version.parse(onnxruntime.__version__) if ort_version < ORT_QUANTIZE_MINIMUM_VERSION: raise ImportError( f"We found an older version of ONNX Runtime ({onnxruntime.__version__}) " @@ -94,6 +94,89 @@ def check_onnxruntime_requirements(minimum_version: packaging.version.Version): ) +def _get_submodels_for_export_stable_diffusion( + pipeline: "StableDiffusionPipeline", +) -> Dict[str, Union["PreTrainedModel", "ModelMixin"]]: + """ + Returns the components of a Stable Diffusion model. + """ + from diffusers import StableDiffusionXLImg2ImgPipeline + + models_for_export = {} + if isinstance(pipeline, StableDiffusionXLImg2ImgPipeline): + projection_dim = pipeline.text_encoder_2.config.projection_dim + else: + projection_dim = pipeline.text_encoder.config.projection_dim + + # Text encoder + if pipeline.text_encoder is not None: + if isinstance(pipeline, StableDiffusionXLImg2ImgPipeline): + pipeline.text_encoder.config.output_hidden_states = True + models_for_export["text_encoder"] = pipeline.text_encoder + + # U-NET + # PyTorch does not support the ONNX export of torch.nn.functional.scaled_dot_product_attention + pipeline.unet.set_attn_processor(AttnProcessor()) + pipeline.unet.config.text_encoder_projection_dim = projection_dim + # The U-NET time_ids inputs shapes depends on the value of `requires_aesthetics_score` + # https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571 + pipeline.unet.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False) + models_for_export["unet"] = pipeline.unet + + # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565 + vae_encoder = copy.deepcopy(pipeline.vae) + if not version.parse(torch.__version__) >= version.parse("2.1.0"): + vae_encoder = override_diffusers_2_0_attn_processors(vae_encoder) + vae_encoder.forward = lambda sample: {"latent_sample": vae_encoder.encode(x=sample)["latent_dist"].sample()} + models_for_export["vae_encoder"] = vae_encoder + + # VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600 + vae_decoder = copy.deepcopy(pipeline.vae) + if not version.parse(torch.__version__) >= version.parse("2.1.0"): + vae_decoder = override_diffusers_2_0_attn_processors(vae_decoder) + vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample) + models_for_export["vae_decoder"] = vae_decoder + + text_encoder_2 = getattr(pipeline, "text_encoder_2", None) + if text_encoder_2 is not None: + text_encoder_2.config.output_hidden_states = True + models_for_export["text_encoder_2"] = text_encoder_2 + + return models_for_export + + +def _get_submodels_for_export_decoder( + model: Union["PreTrainedModel", "TFPreTrainedModel"], use_past: bool +) -> Dict[str, Union["PreTrainedModel", "TFPreTrainedModel"]]: + """ + Returns the decoder part of the model. + """ + models_for_export = {} + + models_for_export[ONNX_DECODER_NAME] = model + if use_past: + models_for_export[ONNX_DECODER_WITH_PAST_NAME] = model + + return models_for_export + + +def _get_submodels_for_export_encoder_decoder( + model: Union["PreTrainedModel", "TFPreTrainedModel"], use_past: bool +) -> Dict[str, Union["PreTrainedModel", "TFPreTrainedModel"]]: + """ + Returns the encoder and decoder parts of the model. + """ + models_for_export = {} + + encoder_model = model.get_encoder() + models_for_export[ONNX_ENCODER_NAME] = encoder_model + models_for_export[ONNX_DECODER_NAME] = model + if use_past: + models_for_export[ONNX_DECODER_WITH_PAST_NAME] = model + + return models_for_export + + def get_encoder_decoder_models_for_export( model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "OnnxConfig" ) -> Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], "OnnxConfig"]]: @@ -110,18 +193,20 @@ def get_encoder_decoder_models_for_export( `Dict[str, Tuple[Union[`PreTrainedModel`, `TFPreTrainedModel`], `OnnxConfig`]: A Dict containing the model and onnx configs for the encoder and decoder parts of the model. """ - models_for_export = {} + models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=config.use_past) - encoder_model = model.get_encoder() encoder_onnx_config = config.with_behavior("encoder") - models_for_export[ONNX_ENCODER_NAME] = (encoder_model, encoder_onnx_config) + models_for_export[ONNX_ENCODER_NAME] = (models_for_export[ONNX_ENCODER_NAME], encoder_onnx_config) decoder_onnx_config = config.with_behavior("decoder", use_past=False) - models_for_export[ONNX_DECODER_NAME] = (model, decoder_onnx_config) + models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], decoder_onnx_config) if config.use_past: decoder_onnx_config_with_past = config.with_behavior("decoder", use_past=True) - models_for_export[ONNX_DECODER_WITH_PAST_NAME] = (model, decoder_onnx_config_with_past) + models_for_export[ONNX_DECODER_WITH_PAST_NAME] = ( + models_for_export[ONNX_DECODER_WITH_PAST_NAME], + decoder_onnx_config_with_past, + ) return models_for_export @@ -148,16 +233,19 @@ def get_decoder_models_for_export( `Dict[str, Tuple[Union[PreTrainedModel, TFPreTrainedModel], OnnxConfig]]: A Dict containing the model and onnx configs for the encoder and decoder parts of the model. """ - models_for_export = {} + models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past) onnx_config = config.__class__( model.config, task=config.task, use_past_in_inputs=False, use_present_in_outputs=True ) - models_for_export[ONNX_DECODER_NAME] = (model, onnx_config) + models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], onnx_config) if config.use_past: onnx_config_with_past = config.__class__(model.config, task=config.task, use_past=True) - models_for_export[ONNX_DECODER_WITH_PAST_NAME] = (model, onnx_config_with_past) + models_for_export[ONNX_DECODER_WITH_PAST_NAME] = ( + models_for_export[ONNX_DECODER_WITH_PAST_NAME], + onnx_config_with_past, + ) return models_for_export @@ -176,30 +264,25 @@ def get_stable_diffusion_models_for_export( `Dict[str, Tuple[Union[`PreTrainedModel`, `TFPreTrainedModel`], `OnnxConfig`]: A Dict containing the model and onnx configs for the different components of the model. """ - models_for_export = {} + models_for_export = _get_submodels_for_export_stable_diffusion(pipeline) # Text encoder - text_encoder_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.text_encoder, exporter="onnx", task="feature-extraction" - ) - text_encoder_onnx_config = text_encoder_config_constructor(pipeline.text_encoder.config) - models_for_export["text_encoder"] = (pipeline.text_encoder, text_encoder_onnx_config) + if "text_encoder" in models_for_export: + text_encoder_config_constructor = TasksManager.get_exporter_config_constructor( + model=pipeline.text_encoder, exporter="onnx", task="feature-extraction" + ) + text_encoder_onnx_config = text_encoder_config_constructor(pipeline.text_encoder.config) + models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_onnx_config) # U-NET onnx_config_constructor = TasksManager.get_exporter_config_constructor( model=pipeline.unet, exporter="onnx", task="semantic-segmentation", model_type="unet" ) unet_onnx_config = onnx_config_constructor(pipeline.unet.config) - - # PyTorch does not support the ONNX export of torch.nn.functional.scaled_dot_product_attention - pipeline.unet.set_attn_processor(AttnProcessor()) - models_for_export["unet"] = (pipeline.unet, unet_onnx_config) + models_for_export["unet"] = (models_for_export["unet"], unet_onnx_config) # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565 - vae_encoder = copy.deepcopy(pipeline.vae) - if not packaging.version.parse(torch.__version__) >= packaging.version.parse("2.1.0"): - vae_encoder = override_diffusers_2_0_attn_processors(vae_encoder) - vae_encoder.forward = lambda sample: {"latent_sample": vae_encoder.encode(x=sample)["latent_dist"].sample()} + vae_encoder = models_for_export["vae_encoder"] vae_config_constructor = TasksManager.get_exporter_config_constructor( model=vae_encoder, exporter="onnx", task="semantic-segmentation", model_type="vae-encoder" ) @@ -207,16 +290,23 @@ def get_stable_diffusion_models_for_export( models_for_export["vae_encoder"] = (vae_encoder, vae_onnx_config) # VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600 - vae_decoder = copy.deepcopy(pipeline.vae) - if not packaging.version.parse(torch.__version__) >= packaging.version.parse("2.1.0"): - vae_decoder = override_diffusers_2_0_attn_processors(vae_decoder) - vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample) + vae_decoder = models_for_export["vae_decoder"] vae_config_constructor = TasksManager.get_exporter_config_constructor( model=vae_decoder, exporter="onnx", task="semantic-segmentation", model_type="vae-decoder" ) vae_onnx_config = vae_config_constructor(vae_decoder.config) models_for_export["vae_decoder"] = (vae_decoder, vae_onnx_config) + if "text_encoder_2" in models_for_export: + onnx_config_constructor = TasksManager.get_exporter_config_constructor( + model=pipeline.text_encoder_2, + exporter="onnx", + task="feature-extraction", + model_type="clip-text-with-projection", + ) + onnx_config = onnx_config_constructor(pipeline.text_encoder_2.config) + models_for_export["text_encoder_2"] = (models_for_export["text_encoder_2"], onnx_config) + return models_for_export diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 2a04c1552f..3ea23755cc 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -161,7 +161,7 @@ class TasksManager: "object-detection": "AutoModelForObjectDetection", "question-answering": "AutoModelForQuestionAnswering", "image-classification": "AutoModelForImageClassification", - "image-segmentation": "AutoModelForImageSegmentation", + "image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"), "mask-generation": "AutoModel", "masked-im": "AutoModelForMaskedImageModeling", "semantic-segmentation": "AutoModelForSemanticSegmentation", @@ -171,6 +171,7 @@ class TasksManager: "audio-xvector": "AutoModelForAudioXVector", "image-to-text": "AutoModelForVision2Seq", "stable-diffusion": "StableDiffusionPipeline", + "stable-diffusion-xl": "StableDiffusionXLImg2ImgPipeline", "zero-shot-image-classification": "AutoModelForZeroShotImageClassification", "zero-shot-object-detection": "AutoModelForZeroShotObjectDetection", } @@ -232,6 +233,16 @@ class TasksManager: ("pt", "vision-encoder-decoder", "document-question-answering"): ("transformers", "VisionEncoderDecoderModel"), } + # TODO: why feature-extraction-with-past is here? + _ENCODER_DECODER_TASKS = ( + "text2text-generation", + "automatic-speech-recognition", + "image-to-text", + "feature-extraction-with-past", + "visual-question-answering", + "document-question-answering", + ) + _TASKS_TO_LIBRARY = { "conversational": "transformers", "document-question-answering": "transformers", @@ -257,6 +268,7 @@ class TasksManager: "image-to-text": "transformers", "sentence-similarity": "transformers", "stable-diffusion": "diffusers", + "stable-diffusion-xl": "diffusers", "summarization": "transformers", "visual-question-answering": "transformers", "zero-shot-classification": "transformers", @@ -380,6 +392,10 @@ class TasksManager: "feature-extraction", onnx="CLIPTextOnnxConfig", ), + "clip-text-with-projection": supported_tasks_mapping( + "feature-extraction", + onnx="CLIPTextWithProjectionOnnxConfig", + ), "codegen": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", @@ -632,6 +648,7 @@ class TasksManager: "mobilevit": supported_tasks_mapping( "feature-extraction", "image-classification", + "image-segmentation", onnx="MobileViTOnnxConfig", ), "mobilenet-v1": supported_tasks_mapping( @@ -767,12 +784,13 @@ class TasksManager: tflite="RoFormerTFLiteConfig", ), "sam": supported_tasks_mapping( - "mask-generation", + "feature-extraction", onnx="SamOnnxConfig", ), "segformer": supported_tasks_mapping( "feature-extraction", "image-classification", + "image-segmentation", "semantic-segmentation", onnx="SegformerOnnxConfig", ), @@ -926,7 +944,14 @@ class TasksManager: onnx="YolosOnnxConfig", ), } - _UNSUPPORTED_CLI_MODEL_TYPE = {"unet", "vae-encoder", "vae-decoder", "clip-text-model", "trocr"} + _UNSUPPORTED_CLI_MODEL_TYPE = { + "unet", + "vae-encoder", + "vae-decoder", + "clip-text-model", + "clip-text-with-projection", + "trocr", + } _SUPPORTED_CLI_MODEL_TYPE = set(_SUPPORTED_MODEL_TYPE.keys()) - _UNSUPPORTED_CLI_MODEL_TYPE @classmethod @@ -1001,7 +1026,7 @@ def get_supported_tasks_for_model_type( if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: raise KeyError( f"{model_type_and_model_name} is not supported yet. " - f"Only {TasksManager._SUPPORTED_CLI_MODEL_TYPE} are supported. " + f"Only {TasksManager._SUPPORTED_MODEL_TYPE} are supported. " f"If you want to support {model_type} please propose a PR or open up an issue." ) elif exporter not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]: @@ -1266,7 +1291,7 @@ def _infer_task_from_model_or_model_class( ( target_name.startswith("Auto"), target_name.startswith("TFAuto"), - target_name == "StableDiffusionPipeline", + "StableDiffusion" in target_name, ) ): if target_name == auto_cls_name: @@ -1309,8 +1334,10 @@ def _infer_task_from_model_name_or_path( model_info = huggingface_hub.model_info(model_name_or_path, revision=revision) if model_info.library_name == "diffusers": # TODO : getattr(model_info, "model_index") defining auto_model_class_name currently set to None - if "stable-diffusion" in model_info.tags: - inferred_task_name = "stable-diffusion" + for task in ("stable-diffusion-xl", "stable-diffusion"): + if task in model_info.tags: + inferred_task_name = task + break else: pipeline_tag = getattr(model_info, "pipeline_tag", None) # conversational is not a supported task per se, just an alias that may map to @@ -1471,7 +1498,11 @@ def get_model_from_task( elif device is None: device = torch.device("cpu") - if version.parse(torch.__version__) >= version.parse("2.0"): + # TODO : fix EulerDiscreteScheduler loading to enable for SD models + if ( + version.parse(torch.__version__) >= version.parse("2.0") + and TasksManager._TASKS_TO_LIBRARY[task.replace("-with-past", "")] != "diffusers" + ): with device: # Initialize directly in the requested device, to save allocation time. Especially useful for large # models to initialize on cuda device. diff --git a/optimum/modeling_base.py b/optimum/modeling_base.py index 50e8ee7152..ed2c4b466b 100644 --- a/optimum/modeling_base.py +++ b/optimum/modeling_base.py @@ -328,9 +328,13 @@ def from_pretrained( if config is None: if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME: if CONFIG_NAME in os.listdir(os.path.join(model_id, subfolder)): - config = AutoConfig.from_pretrained(os.path.join(model_id, subfolder, CONFIG_NAME)) + config = AutoConfig.from_pretrained( + os.path.join(model_id, subfolder, CONFIG_NAME), trust_remote_code=trust_remote_code + ) elif CONFIG_NAME in os.listdir(model_id): - config = AutoConfig.from_pretrained(os.path.join(model_id, CONFIG_NAME)) + config = AutoConfig.from_pretrained( + os.path.join(model_id, CONFIG_NAME), trust_remote_code=trust_remote_code + ) logger.info( f"config.json not found in the specified subfolder {subfolder}. Using the top level config.json." ) @@ -344,6 +348,7 @@ def from_pretrained( use_auth_token=use_auth_token, force_download=force_download, subfolder=subfolder, + trust_remote_code=trust_remote_code, ) elif isinstance(config, (str, os.PathLike)): config = cls._load_config( @@ -353,6 +358,7 @@ def from_pretrained( use_auth_token=use_auth_token, force_download=force_download, subfolder=subfolder, + trust_remote_code=trust_remote_code, ) if not export and trust_remote_code: diff --git a/optimum/onnxruntime/__init__.py b/optimum/onnxruntime/__init__.py index 728b241c9b..62e32cfe71 100644 --- a/optimum/onnxruntime/__init__.py +++ b/optimum/onnxruntime/__init__.py @@ -67,9 +67,21 @@ if not is_diffusers_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - _import_structure[".utils.dummy_diffusers_objects"] = ["ORTStableDiffusionPipeline"] + _import_structure[".utils.dummy_diffusers_objects"] = [ + "ORTStableDiffusionPipeline", + "ORTStableDiffusionImg2ImgPipeline", + "ORTStableDiffusionInpaintPipeline", + "ORTStableDiffusionXLPipeline", + "ORTStableDiffusionXLImg2ImgPipeline", + ] else: - _import_structure["modeling_diffusion"] = ["ORTStableDiffusionPipeline"] + _import_structure["modeling_diffusion"] = [ + "ORTStableDiffusionPipeline", + "ORTStableDiffusionImg2ImgPipeline", + "ORTStableDiffusionInpaintPipeline", + "ORTStableDiffusionXLPipeline", + "ORTStableDiffusionXLImg2ImgPipeline", + ] # Direct imports for type-checking @@ -112,9 +124,21 @@ if not is_diffusers_available(): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ..utils.dummy_diffusers_objects import ORTStableDiffusionPipeline + from ..utils.dummy_diffusers_objects import ( + ORTStableDiffusionImg2ImgPipeline, + ORTStableDiffusionInpaintPipeline, + ORTStableDiffusionPipeline, + ORTStableDiffusionXLImg2ImgPipeline, + ORTStableDiffusionXLPipeline, + ) else: - from .modeling_diffusion import ORTStableDiffusionPipeline + from .modeling_diffusion import ( + ORTStableDiffusionImg2ImgPipeline, + ORTStableDiffusionInpaintPipeline, + ORTStableDiffusionPipeline, + ORTStableDiffusionXLImg2ImgPipeline, + ORTStableDiffusionXLPipeline, + ) else: import sys diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index bb20bcb8fb..59a21f944d 100644 --- a/optimum/onnxruntime/base.py +++ b/optimum/onnxruntime/base.py @@ -319,7 +319,7 @@ def forward( ) # TODO: fix transformers generate to have contiguous input_ids here already - # For an unknown reason, calling `contigous()` here is necessary to not have errors + # For an unknown reason, calling `contiguous()` here is necessary to not have errors # on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding. # I suspect the reason is the contiguous python list that messes something up? model_inputs = [input_ids.contiguous()] @@ -433,7 +433,13 @@ def __init__( ): super().__init__(session, parent_model) - if self.parent_model.use_merged is False and self.use_past is True: + # We may use ORTDecoderForSeq2Seq for vision-encoder-decoder models, where models as gpt2 + # can be used but do not support KV caching for the cross-attention key/values, see: + # https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L302-L311 + # This attribute is used to avoid returning cross-attention KV-cache in this case. + self.no_cross_attention_cache = getattr(self.parent_model, "no_cross_attention_cache", False) + + if (not self.parent_model.use_merged and self.use_past) or self.no_cross_attention_cache: self.num_pkv = 2 else: # When using a merged model, we always have the same number of output whether we use past key values or not, @@ -533,12 +539,12 @@ def forward( model_inputs = [input_ids] - if "encoder_attention_mask" in self.input_names: - model_inputs.append(encoder_attention_mask) - if "encoder_hidden_states" in self.input_names: model_inputs.append(encoder_hidden_states) + if "encoder_attention_mask" in self.input_names: + model_inputs.append(encoder_attention_mask) + if past_key_values is not None: model_inputs += past_key_values @@ -688,7 +694,7 @@ def forward( # Tuple of tuple of length `n_layers`, with each tuple of length equal to: # * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention) # * 2 for the decoder with cache (k/v of self-attention as cross-attention cache is constant) - if self.use_past is False or use_merged_no_cache: + if not self.use_past or use_merged_no_cache or self.no_cross_attention_cache: out_past_key_values = tuple( out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv) ) diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 99e76efb0b..1ffc81d883 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -17,7 +17,7 @@ import shutil from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import torch from huggingface_hub import hf_hub_download @@ -36,6 +36,7 @@ from .base import ORTDecoder from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN from .modeling_ort import ORTModel +from .models.bloom import bloom_convert_to_bloom_cache, bloom_convert_to_standard_cache from .utils import ( ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, @@ -315,6 +316,7 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: "PretrainedConfig", + init_cls: Type["ORTModelDecoder"], use_auth_token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, @@ -514,7 +516,7 @@ def _from_pretrained( else: onnx_paths.append(decoder_merged_path) - return cls( + return init_cls( ort_inference_sessions[0], config, decoder_with_past_session=ort_inference_sessions[1], @@ -695,3 +697,53 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> def can_generate(self): """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" return True + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + config: "PretrainedConfig", + **kwargs, + ): + if config.model_type == "bloom": + return super()._from_pretrained(model_id, config, init_cls=ORTBloomForCausalLM, **kwargs) + return super()._from_pretrained(model_id, config, init_cls=ORTModelForCausalLM, **kwargs) + + +class ORTBloomForCausalLM(ORTModelForCausalLM): + # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + attention_mask = kwargs.get("attention_mask", None) + use_cache = kwargs.get("use_cache", None) + + # only last token for input_ids if past is not None + if past_key_values: + # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed + if past_key_values[0][0].shape[0] == input_ids.shape[0]: + past_key_values = bloom_convert_to_bloom_cache(past_key_values) + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "position_ids": None, + "attention_mask": attention_mask, + } + + # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + standardized_past = bloom_convert_to_standard_cache(past, batch_size=len(beam_idx)) + + # Get a copy of `beam_idx` on all the devices where we need those indices. + device_to_beam_idx = { + past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past + } + reordered_past = tuple( + ( + layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), + layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), + ) + for layer_past in standardized_past + ) + return bloom_convert_to_bloom_cache(reordered_past) diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 3def21f7af..8a7b686f53 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -23,7 +23,13 @@ import numpy as np import torch -from diffusers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, StableDiffusionPipeline +from diffusers import ( + DDIMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + StableDiffusionPipeline, + StableDiffusionXLImg2ImgPipeline, +) from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils import CONFIG_NAME from huggingface_hub import snapshot_download @@ -34,10 +40,16 @@ from ..exporters.onnx import main_export from ..onnx.utils import _get_external_data_paths from ..pipelines.diffusers.pipeline_stable_diffusion import StableDiffusionPipelineMixin +from ..pipelines.diffusers.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipelineMixin +from ..pipelines.diffusers.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipelineMixin +from ..pipelines.diffusers.pipeline_stable_diffusion_xl import StableDiffusionXLPipelineMixin +from ..pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipelineMixin from ..utils import ( + DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, DIFFUSION_MODEL_UNET_SUBFOLDER, DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, + DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, ) from .modeling_ort import ORTModel from .utils import ( @@ -52,11 +64,12 @@ logger = logging.getLogger(__name__) -class ORTStableDiffusionPipeline(ORTModel, StableDiffusionPipelineMixin): +class ORTStableDiffusionPipelineBase(ORTModel): auto_model_class = StableDiffusionPipeline main_input_name = "input_ids" base_model_prefix = "onnx_model" config_name = "model_index.json" + sub_component_config_name = "config.json" def __init__( self, @@ -67,6 +80,9 @@ def __init__( tokenizer: CLIPTokenizer, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], feature_extractor: Optional[CLIPFeatureExtractor] = None, + vae_encoder_session: Optional[ort.InferenceSession] = None, + text_encoder_2_session: Optional[ort.InferenceSession] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, use_io_binding: Optional[bool] = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, ): @@ -88,6 +104,8 @@ def __init__( A scheduler to be used in combination with the U-NET component to denoise the encoded image latents. feature_extractor (`Optional[CLIPFeatureExtractor]`, defaults to `None`): A model extracting features from generated images to be used as inputs for the `safety_checker` + vae_encoder_session (`Optional[ort.InferenceSession]`, defaults to `None`): + The ONNX Runtime inference session associated to the VAE encoder. use_io_binding (`Optional[bool]`, defaults to `None`): Whether to use IOBinding during inference to avoid memory copy between the host and devices. Defaults to `True` if the device is CUDA, otherwise defaults to `False`. @@ -102,28 +120,63 @@ def __init__( self._internal_dict = config self.vae_decoder = ORTModelVaeDecoder(vae_decoder_session, self) self.vae_decoder_model_path = Path(vae_decoder_session._model_path) - self.text_encoder = ORTModelTextEncoder(text_encoder_session, self) - self.text_encoder_model_path = Path(text_encoder_session._model_path) self.unet = ORTModelUnet(unet_session, self) self.unet_model_path = Path(unet_session._model_path) + + if text_encoder_session is not None: + self.text_encoder_model_path = Path(text_encoder_session._model_path) + self.text_encoder = ORTModelTextEncoder(text_encoder_session, self) + else: + self.text_encoder_model_path = None + self.text_encoder = None + + if vae_encoder_session is not None: + self.vae_encoder_model_path = Path(vae_encoder_session._model_path) + self.vae_encoder = ORTModelVaeEncoder(vae_encoder_session, self) + else: + self.vae_encoder_model_path = None + self.vae_encoder = None + + if text_encoder_2_session is not None: + self.text_encoder_2_model_path = Path(text_encoder_2_session._model_path) + self.text_encoder_2 = ORTModelTextEncoder(text_encoder_2_session, self) + else: + self.text_encoder_2_model_path = None + self.text_encoder_2 = None + self.tokenizer = tokenizer + self.tokenizer_2 = tokenizer_2 self.scheduler = scheduler self.feature_extractor = feature_extractor self.safety_checker = None + sub_models = { DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER: self.text_encoder, DIFFUSION_MODEL_UNET_SUBFOLDER: self.unet, DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER: self.vae_decoder, + DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER: self.vae_encoder, + DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER: self.text_encoder_2, } + + # Modify config to keep the resulting model compatible with diffusers pipelines for name in sub_models.keys(): - self._internal_dict[name] = ("optimum", sub_models[name].__class__.__name__) + self._internal_dict[name] = ( + ("diffusers", "OnnxRuntimeModel") if sub_models[name] is not None else (None, None) + ) self._internal_dict.pop("vae", None) + if "block_out_channels" in self.vae_decoder.config: + self.vae_scale_factor = 2 ** (len(self.vae_decoder.config["block_out_channels"]) - 1) + else: + self.vae_scale_factor = 8 + @staticmethod def load_model( vae_decoder_path: Union[str, Path], text_encoder_path: Union[str, Path], unet_path: Union[str, Path], + vae_encoder_path: Optional[Union[str, Path]] = None, + text_encoder_2_path: Optional[Union[str, Path]] = None, provider: str = "CPUExecutionProvider", session_options: Optional[ort.SessionOptions] = None, provider_options: Optional[Dict] = None, @@ -139,6 +192,10 @@ def load_model( The path to the text encoder ONNX model. unet_path (`Union[str, Path]`): The path to the U-NET ONNX model. + vae_encoder_path (`Union[str, Path]`, defaults to `None`): + The path to the VAE encoder ONNX model. + text_encoder_2_path (`Union[str, Path]`, defaults to `None`): + The path to the second text decoder ONNX model. provider (`str`, defaults to `"CPUExecutionProvider"`): ONNX Runtime provider to use for loading the model. See https://onnxruntime.ai/docs/execution-providers/ for possible providers. @@ -148,30 +205,39 @@ def load_model( Provider option dictionary corresponding to the provider used. See available options for each provider: https://onnxruntime.ai/docs/api/c/group___global.html . Defaults to `None`. """ - vae_decoder_session = ORTModel.load_model(vae_decoder_path, provider, session_options, provider_options) - text_encoder_session = ORTModel.load_model(text_encoder_path, provider, session_options, provider_options) - unet_session = ORTModel.load_model(unet_path, provider, session_options, provider_options) + vae_decoder = ORTModel.load_model(vae_decoder_path, provider, session_options, provider_options) + unet = ORTModel.load_model(unet_path, provider, session_options, provider_options) - return vae_decoder_session, text_encoder_session, unet_session + sessions = { + "vae_encoder": vae_encoder_path, + "text_encoder": text_encoder_path, + "text_encoder_2": text_encoder_2_path, + } - def _save_pretrained( - self, - save_directory: Union[str, Path], - vae_decoder_file_name: str = ONNX_WEIGHTS_NAME, - text_encoder_file_name: str = ONNX_WEIGHTS_NAME, - unet_file_name: str = ONNX_WEIGHTS_NAME, - ): + for key, value in sessions.items(): + if value is not None and value.is_file(): + sessions[key] = ORTModel.load_model(value, provider, session_options, provider_options) + else: + sessions[key] = None + + return vae_decoder, sessions["text_encoder"], unet, sessions["vae_encoder"], sessions["text_encoder_2"] + + def _save_pretrained(self, save_directory: Union[str, Path]): save_directory = Path(save_directory) src_to_dst_path = { - self.vae_decoder_model_path: save_directory - / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER - / vae_decoder_file_name, - self.text_encoder_model_path: save_directory - / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER - / text_encoder_file_name, - self.unet_model_path: save_directory / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name, + self.vae_decoder_model_path: save_directory / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / ONNX_WEIGHTS_NAME, + self.text_encoder_model_path: save_directory / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME, + self.unet_model_path: save_directory / DIFFUSION_MODEL_UNET_SUBFOLDER / ONNX_WEIGHTS_NAME, } + sub_models_to_save = { + self.vae_encoder_model_path: DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, + self.text_encoder_2_model_path: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, + } + for path, subfolder in sub_models_to_save.items(): + if path is not None: + src_to_dst_path[path] = save_directory / subfolder / ONNX_WEIGHTS_NAME + # TODO: Modify _get_external_data_paths to give dictionnary src_paths = list(src_to_dst_path.keys()) dst_paths = list(src_to_dst_path.values()) @@ -181,11 +247,18 @@ def _save_pretrained( for src_path, dst_path in zip(src_paths, dst_paths): dst_path.parent.mkdir(parents=True, exist_ok=True) shutil.copyfile(src_path, dst_path) + config_path = src_path.parent / self.sub_component_config_name + if config_path.is_file(): + shutil.copyfile(config_path, dst_path.parent / self.sub_component_config_name) + + self.scheduler.save_pretrained(save_directory / "scheduler") - self.tokenizer.save_pretrained(save_directory.joinpath("tokenizer")) - self.scheduler.save_pretrained(save_directory.joinpath("scheduler")) if self.feature_extractor is not None: - self.feature_extractor.save_pretrained(save_directory.joinpath("feature_extractor")) + self.feature_extractor.save_pretrained(save_directory / "feature_extractor") + if self.tokenizer is not None: + self.tokenizer.save_pretrained(save_directory / "tokenizer") + if self.tokenizer_2 is not None: + self.tokenizer_2.save_pretrained(save_directory / "tokenizer_2") @classmethod def _from_pretrained( @@ -198,6 +271,8 @@ def _from_pretrained( vae_decoder_file_name: str = ONNX_WEIGHTS_NAME, text_encoder_file_name: str = ONNX_WEIGHTS_NAME, unet_file_name: str = ONNX_WEIGHTS_NAME, + vae_encoder_file_name: str = ONNX_WEIGHTS_NAME, + text_encoder_2_file_name: str = ONNX_WEIGHTS_NAME, local_files_only: bool = False, provider: str = "CPUExecutionProvider", session_options: Optional[ort.SessionOptions] = None, @@ -210,12 +285,10 @@ def _from_pretrained( raise ValueError("The provider `'TensorrtExecutionProvider'` is not supported") model_id = str(model_id) - sub_models_to_load, _, _ = cls.extract_init_dict(config) - sub_models_names = set(sub_models_to_load.keys()).intersection({"feature_extractor", "tokenizer", "scheduler"}) - sub_models = {} + patterns = set(config.keys()) + sub_models_to_load = patterns.intersection({"feature_extractor", "tokenizer", "tokenizer_2", "scheduler"}) if not os.path.isdir(model_id): - patterns = set(config.keys()) patterns.update({"vae_encoder", "vae_decoder"}) allow_patterns = {os.path.join(k, "*") for k in patterns if not k.startswith("_")} allow_patterns.update( @@ -223,6 +296,8 @@ def _from_pretrained( vae_decoder_file_name, text_encoder_file_name, unet_file_name, + vae_encoder_file_name, + text_encoder_2_file_name, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name, @@ -239,8 +314,10 @@ def _from_pretrained( ignore_patterns=["*.msgpack", "*.safetensors", "*.bin"], ) new_model_save_dir = Path(model_id) - for name in sub_models_names: - library_name, library_classes = sub_models_to_load[name] + + sub_models = {} + for name in sub_models_to_load: + library_name, library_classes = config[name] if library_classes is not None: library = importlib.import_module(library_name) class_obj = getattr(library, library_classes) @@ -251,10 +328,14 @@ def _from_pretrained( else: sub_models[name] = load_method(new_model_save_dir) - inference_sessions = cls.load_model( + vae_decoder, text_encoder, unet, vae_encoder, text_encoder_2 = cls.load_model( vae_decoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name, text_encoder_path=new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name, unet_path=new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name, + vae_encoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name, + text_encoder_2_path=new_model_save_dir + / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER + / text_encoder_2_file_name, provider=provider, session_options=session_options, provider_options=provider_options, @@ -269,11 +350,16 @@ def _from_pretrained( ) return cls( - *inference_sessions, + vae_decoder_session=vae_decoder, + text_encoder_session=text_encoder, + unet_session=unet, config=config, - tokenizer=sub_models["tokenizer"], - scheduler=sub_models["scheduler"], - feature_extractor=sub_models.pop("feature_extractor", None), + tokenizer=sub_models.get("tokenizer", None), + scheduler=sub_models.get("scheduler"), + feature_extractor=sub_models.get("feature_extractor", None), + tokenizer_2=sub_models.get("tokenizer_2", None), + vae_encoder_session=vae_encoder, + text_encoder_2_session=text_encoder_2, use_io_binding=use_io_binding, model_save_dir=model_save_dir, ) @@ -346,12 +432,13 @@ def to(self, device: Union[torch.device, str, int]): self.vae_decoder.session.set_providers([provider], provider_options=[provider_options]) self.text_encoder.session.set_providers([provider], provider_options=[provider_options]) self.unet.session.set_providers([provider], provider_options=[provider_options]) + + if self.vae_encoder is not None: + self.vae_encoder.session.set_providers([provider], provider_options=[provider_options]) + self.providers = self.vae_decoder.session.get_providers() return self - def __call__(self, *args, **kwargs): - return StableDiffusionPipelineMixin.__call__(self, *args, **kwargs) - @classmethod def _load_config(cls, config_name_or_path: Union[str, os.PathLike], **kwargs): return cls.load_config(config_name_or_path, **kwargs) @@ -367,11 +454,16 @@ class _ORTDiffusionModelPart: It has its own `onnxruntime.InferenceSession`, and can perform a forward pass. """ + CONFIG_NAME = "config.json" + def __init__(self, session: ort.InferenceSession, parent_model: ORTModel): self.session = session self.parent_model = parent_model self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} + config_path = Path(session._model_path).parent / self.CONFIG_NAME + self.config = self.parent_model._dict_from_json_file(config_path) if config_path.is_file() else {} + self.input_dtype = {inputs.name: _ORT_TO_NP_TYPE[inputs.type] for inputs in self.session.get_inputs()} @property def device(self): @@ -397,14 +489,26 @@ def forward(self, input_ids: np.ndarray): class ORTModelUnet(_ORTDiffusionModelPart): def __init__(self, session: ort.InferenceSession, parent_model: ORTModel): super().__init__(session, parent_model) - self.input_dtype = {inputs.name: _ORT_TO_NP_TYPE[inputs.type] for inputs in self.session.get_inputs()} - def forward(self, sample: np.ndarray, timestep: np.ndarray, encoder_hidden_states: np.ndarray): + def forward( + self, + sample: np.ndarray, + timestep: np.ndarray, + encoder_hidden_states: np.ndarray, + text_embeds: Optional[np.ndarray] = None, + time_ids: Optional[np.ndarray] = None, + ): onnx_inputs = { "sample": sample, "timestep": timestep, "encoder_hidden_states": encoder_hidden_states, } + + if text_embeds is not None: + onnx_inputs["text_embeds"] = text_embeds + if time_ids is not None: + onnx_inputs["time_ids"] = time_ids + outputs = self.session.run(None, onnx_inputs) return outputs @@ -416,3 +520,76 @@ def forward(self, latent_sample: np.ndarray): } outputs = self.session.run(None, onnx_inputs) return outputs + + +class ORTModelVaeEncoder(_ORTDiffusionModelPart): + def forward(self, sample: np.ndarray): + onnx_inputs = { + "sample": sample, + } + outputs = self.session.run(None, onnx_inputs) + return outputs + + +class ORTStableDiffusionPipeline(ORTStableDiffusionPipelineBase, StableDiffusionPipelineMixin): + def __call__(self, *args, **kwargs): + return StableDiffusionPipelineMixin.__call__(self, *args, **kwargs) + + +class ORTStableDiffusionImg2ImgPipeline(ORTStableDiffusionPipelineBase, StableDiffusionImg2ImgPipelineMixin): + def __call__(self, *args, **kwargs): + return StableDiffusionImg2ImgPipelineMixin.__call__(self, *args, **kwargs) + + +class ORTStableDiffusionInpaintPipeline(ORTStableDiffusionPipelineBase, StableDiffusionInpaintPipelineMixin): + def __call__(self, *args, **kwargs): + return StableDiffusionInpaintPipelineMixin.__call__(self, *args, **kwargs) + + +class ORTStableDiffusionXLPipelineBase(ORTStableDiffusionPipelineBase): + auto_model_class = StableDiffusionXLImg2ImgPipeline + + def __init__( + self, + vae_decoder_session: ort.InferenceSession, + text_encoder_session: ort.InferenceSession, + unet_session: ort.InferenceSession, + config: Dict[str, Any], + tokenizer: CLIPTokenizer, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + feature_extractor: Optional[CLIPFeatureExtractor] = None, + vae_encoder_session: Optional[ort.InferenceSession] = None, + text_encoder_2_session: Optional[ort.InferenceSession] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, + use_io_binding: Optional[bool] = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + ): + super().__init__( + vae_decoder_session=vae_decoder_session, + text_encoder_session=text_encoder_session, + unet_session=unet_session, + config=config, + tokenizer=tokenizer, + scheduler=scheduler, + feature_extractor=feature_extractor, + vae_encoder_session=vae_encoder_session, + text_encoder_2_session=text_encoder_2_session, + tokenizer_2=tokenizer_2, + use_io_binding=use_io_binding, + model_save_dir=model_save_dir, + ) + + # additional invisible-watermark dependency for SD XL + from ..pipelines.diffusers.watermark import StableDiffusionXLWatermarker + + self.watermark = StableDiffusionXLWatermarker() + + +class ORTStableDiffusionXLPipeline(ORTStableDiffusionXLPipelineBase, StableDiffusionXLPipelineMixin): + def __call__(self, *args, **kwargs): + return StableDiffusionXLPipelineMixin.__call__(self, *args, **kwargs) + + +class ORTStableDiffusionXLImg2ImgPipeline(ORTStableDiffusionXLPipelineBase, StableDiffusionXLImg2ImgPipelineMixin): + def __call__(self, *args, **kwargs): + return StableDiffusionXLImg2ImgPipelineMixin.__call__(self, *args, **kwargs) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 10d6beb118..1784766c6a 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -170,7 +170,6 @@ class ORTModel(OptimizedModel): @classproperty def export_feature(cls): logger.warning(f"{cls.__name__}.export_feature is deprecated, and will be removed in optimum 2.0.") - try: feature = TasksManager.infer_task_from_model(cls.auto_model_class) except ValueError: @@ -207,7 +206,9 @@ def shared_attributes_init( ) self.providers = model.get_providers() - self._device = get_device_for_provider(self.providers[0]) + self._device = get_device_for_provider( + self.providers[0], provider_options=model.get_provider_options()[self.providers[0]] + ) # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it # would end-up removing the directory containing the underlying ONNX model. diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index cdd8d1b6cd..ee09713390 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -1244,6 +1244,11 @@ def __init__( generation_config: Optional[GenerationConfig] = None, **kwargs, ): + # There are probably other archs that do not support cross attention KV cache, but only + # this one seem popular on the Hub. + if config.decoder.model_type == "gpt2": + self.no_cross_attention_cache = True + super().__init__( encoder_session, decoder_session, diff --git a/optimum/onnxruntime/models/__init__.py b/optimum/onnxruntime/models/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/optimum/onnxruntime/models/__init__.py @@ -0,0 +1 @@ + diff --git a/optimum/onnxruntime/models/bloom.py b/optimum/onnxruntime/models/bloom.py new file mode 100644 index 0000000000..4608f153a0 --- /dev/null +++ b/optimum/onnxruntime/models/bloom.py @@ -0,0 +1,44 @@ +from typing import TYPE_CHECKING, Tuple + + +if TYPE_CHECKING: + import torch + + +def bloom_convert_to_standard_cache( + past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]], batch_size: int +) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]: + """ + Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, + num_heads, ...])) + """ + batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape + num_heads = batch_size_times_num_heads // batch_size + # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] + # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size, num_heads, head_dim, seq_length), + layer_past[1].view(batch_size, num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) + + +def bloom_convert_to_bloom_cache( + past_key_value: Tuple[Tuple["torch.Tensor", "torch.Tensor"]] +) -> Tuple[Tuple["torch.Tensor", "torch.Tensor"]]: + """ + Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...])) + """ + batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape + batch_size_times_num_heads = batch_size * num_heads + # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] + # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length), + layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 02fbf161f0..f8f6acbbdb 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -173,15 +173,14 @@ def wrap_onnx_config_for_loss(onnx_config: OnnxConfig) -> OnnxConfig: return OnnxConfigWithLoss(onnx_config) -def get_device_for_provider(provider: str) -> torch.device: +def get_device_for_provider(provider: str, provider_options: Dict) -> torch.device: """ Gets the PyTorch device (CPU/CUDA) associated with an ONNX Runtime provider. """ - return ( - torch.device("cuda:0") - if provider in ["CUDAExecutionProvider", "TensorrtExecutionProvider"] - else torch.device("cpu") - ) + if provider in ["CUDAExecutionProvider", "TensorrtExecutionProvider"]: + return torch.device(f"cuda:{provider_options['device_id']}") + else: + return torch.device("cpu") def get_provider_for_device(device: torch.device) -> str: diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion.py index 77c8ab81e3..0f5b3c3b33 100644 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion.py +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion.py @@ -20,60 +20,78 @@ import torch from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from .pipeline_utils import DiffusionPipelineMixin +from .pipeline_utils import DiffusionPipelineMixin, rescale_noise_cfg logger = logging.getLogger(__name__) class StableDiffusionPipelineMixin(DiffusionPipelineMixin): - # Copied from https://github.com/huggingface/diffusers/blob/v0.12.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L115 + # Copied from https://github.com/huggingface/diffusers/blob/v0.17.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L114 def _encode_prompt( self, prompt: Union[str, List[str]], num_images_per_prompt: int, do_classifier_free_guidance: bool, - negative_prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, list]], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`): + prompt (`Union[str, List[str]]`): prompt to be encoded num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`): + negative_prompt (`Optional[Union[str, list]]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] - # get prompt text embeddings - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids - - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + if prompt_embeds is None: + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size @@ -93,7 +111,7 @@ def _encode_prompt( else: uncond_tokens = negative_prompt - max_length = text_input_ids.shape[-1] + max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", @@ -102,6 +120,8 @@ def _encode_prompt( return_tensors="np", ) negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] + + if do_classifier_free_guidance: negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. @@ -111,31 +131,17 @@ def _encode_prompt( return prompt_embeds - # Adapted from https://github.com/huggingface/diffusers/blob/v0.12.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L192 - def __call__( + # Copied from https://github.com/huggingface/diffusers/blob/v0.17.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L217 + def check_inputs( self, prompt: Union[str, List[str]], - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[np.random.RandomState] = None, - latents: Optional[np.ndarray] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, np.ndarray], None]] = None, - callback_steps: Optional[int] = 1, + height: Optional[int], + width: Optional[int], + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, ): - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -147,6 +153,155 @@ def __call__( f" {type(callback_steps)}." ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = generator.randn(*shape).astype(dtype) + elif latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * np.float64(self.scheduler.init_noise_sigma) + + return latents + + # Adapted from https://github.com/huggingface/diffusers/blob/v0.17.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L264 + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: str = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + guidance_rescale: float = 0.0, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`Optional[Union[str, List[str]]]`, defaults to None): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`Optional[int]`, defaults to None): + The height in pixels of the generated image. + width (`Optional[int]`, defaults to None): + The width in pixels of the generated image. + num_inference_steps (`int`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`Optional[Union[str, list]]`): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. + eta (`float`, defaults to 0.0): + Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`Optional[np.random.RandomState]`, defaults to `None`):: + A np.random.RandomState to make generation deterministic. + latents (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (Optional[Callable], defaults to `None`): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + guidance_rescale (`float`, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `Ο†` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + height = height or self.unet.config.get("sample_size", 64) * self.vae_scale_factor + width = width or self.unet.config.get("sample_size", 64) * self.vae_scale_factor + + # check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if generator is None: generator = np.random @@ -156,21 +311,27 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 prompt_embeds = self._encode_prompt( - prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) - # get the initial random noise unless the user supplied it - latents_dtype = prompt_embeds.dtype - latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) - if latents is None: - latents = generator.randn(*latents_shape).astype(latents_dtype) - elif latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - # set timesteps self.scheduler.set_timesteps(num_inference_steps) - - latents = latents * np.float64(self.scheduler.init_noise_sigma) + timesteps = self.scheduler.timesteps + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + self.unet.config.get("in_channels", 4), + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -184,7 +345,8 @@ def __call__( # Adapted from diffusers to extend it for other runtimes than ORT timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) @@ -199,6 +361,9 @@ def __call__( if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 scheduler_output = self.scheduler.step( @@ -207,28 +372,40 @@ def __call__( latents = scheduler_output.prev_sample.numpy() # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - latents = 1 / 0.18215 * latents - # image = self.vae_decoder(latent_sample=latents)[0] - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate( - [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] - ) + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) + if output_type == "latent": + image = latents + has_nsfw_concept = None + else: + latents = 1 / self.vae_decoder.config.get("scaling_factor", 0.18215) * latents + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + # TODO: add image_processor + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + image, has_nsfw_concept = self.run_safety_checker(image) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - if self.safety_checker is not None: + def run_safety_checker(self, image): + if self.safety_checker is None: + has_nsfw_concept = None + else: safety_checker_input = self.feature_extractor( self.numpy_to_pil(image), return_tensors="np" ).pixel_values.astype(image.dtype) - # Adapted from diffusers (removed) - # image, has_nsfw_concepts = self.safety_checker(clip_input=safety_checker_input, images=image) - - # safety_checker does not support batched inputs yet images, has_nsfw_concept = [], [] for i in range(image.shape[0]): image_i, has_nsfw_concept_i = self.safety_checker( @@ -237,13 +414,5 @@ def __call__( images.append(image_i) has_nsfw_concept.append(has_nsfw_concept_i[0]) image = np.concatenate(images) - else: - has_nsfw_concept = None - - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return image, has_nsfw_concept diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py new file mode 100644 index 0000000000..d2c23b2b04 --- /dev/null +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py @@ -0,0 +1,305 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import logging +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.utils import deprecate + +from .pipeline_stable_diffusion import StableDiffusionPipelineMixin +from .pipeline_utils import preprocess + + +logger = logging.getLogger(__name__) + + +class StableDiffusionImg2ImgPipelineMixin(StableDiffusionPipelineMixin): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionImg2ImgPipeline.check_inputs + def check_inputs( + self, + prompt: Union[str, List[str]], + strength: float, + callback_steps: int, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionImg2ImgPipeline.__call__ + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Union[np.ndarray, PIL.Image.Image] = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[np.random.RandomState] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: str = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + + Args: + prompt (`Optional[Union[str, List[str]]]`, defaults to None): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`Union[np.ndarray, PIL.Image.Image]`): + `Image`, or tensor representing an image batch which will be upscaled. + strength (`float`, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`Optional[Union[str, list]]`): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. + eta (`float`, defaults to 0.0): + Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`Optional[np.random.RandomState]`, defaults to `None`):: + A np.random.RandomState to make generation deterministic. + prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (Optional[Callable], defaults to `None`): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # check inputs. Raise error if not correct + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + + image = preprocess(image) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + latents_dtype = prompt_embeds.dtype + image = image.astype(latents_dtype) + # encode the init image into latents and scale the latents + init_latents = self.vae_encoder(sample=image)[0] + + scaling_factor = self.vae_decoder.config.get("scaling_factor", 0.18215) + init_latents = scaling_factor * init_latents + + if isinstance(prompt, str): + prompt = [prompt] + if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = len(prompt) // init_latents.shape[0] + init_latents = np.concatenate([init_latents] * additional_image_per_prompt * num_images_per_prompt, axis=0) + elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts." + ) + else: + init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps.numpy()[-init_timestep] + timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(latents_dtype) + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) + ) + init_latents = init_latents.numpy() + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + latents = init_latents + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].numpy() + + # Adapted from diffusers to extend it for other runtimes than ORT + timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ + 0 + ] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + has_nsfw_concept = None + else: + latents = 1 / scaling_factor * latents + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + image, has_nsfw_concept = self.run_safety_checker(image) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py new file mode 100644 index 0000000000..e2a7ac7c9e --- /dev/null +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py @@ -0,0 +1,347 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import logging +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.utils import PIL_INTERPOLATION + +from .pipeline_stable_diffusion import StableDiffusionPipelineMixin + + +logger = logging.getLogger(__name__) + + +def prepare_mask_and_masked_image(image, mask, latents_shape, vae_scale_factor): + image = np.array( + image.convert("RGB").resize((latents_shape[1] * vae_scale_factor, latents_shape[0] * vae_scale_factor)) + ) + image = image[None].transpose(0, 3, 1, 2) + image = image.astype(np.float32) / 127.5 - 1.0 + + image_mask = np.array( + mask.convert("L").resize((latents_shape[1] * vae_scale_factor, latents_shape[0] * vae_scale_factor)) + ) + masked_image = image * (image_mask < 127.5) + + mask = mask.resize((latents_shape[1], latents_shape[0]), PIL_INTERPOLATION["nearest"]) + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + return mask, masked_image + + +class StableDiffusionInpaintPipelineMixin(StableDiffusionPipelineMixin): + # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt: Union[str, List[str]], + height: Optional[int], + width: Optional[int], + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: PIL.Image.Image, + mask_image: PIL.Image.Image, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + output_type: str = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`Union[str, List[str]]`): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be upscaled. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing a masked image batch which will be upscaled. + height (`Optional[int]`, defaults to None): + The height in pixels of the generated image. + width (`Optional[int]`, defaults to None): + The width in pixels of the generated image. + num_inference_steps (`int`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`Optional[Union[str, list]]`): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. + eta (`float`, defaults to 0.0): + Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`Optional[np.random.RandomState]`, defaults to `None`):: + A np.random.RandomState to make generation deterministic. + latents (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (Optional[Callable], defaults to `None`): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + height = height or self.unet.config.get("sample_size", 64) * self.vae_scale_factor + width = width or self.unet.config.get("sample_size", 64) * self.vae_scale_factor + + # check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + num_channels_latents = self.vae_decoder.config.get("latent_channels", 4) + num_channels_unet = self.unet.config.get("in_channels", 9) + latents_shape = ( + batch_size * num_images_per_prompt, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + latents_dtype = prompt_embeds.dtype + if latents is None: + latents = generator.randn(*latents_shape).astype(latents_dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + # prepare mask and masked_image + mask, masked_image = prepare_mask_and_masked_image( + image, mask_image, latents_shape[-2:], self.vae_scale_factor + ) + mask = mask.astype(latents.dtype) + masked_image = masked_image.astype(latents.dtype) + + masked_image_latents = self.vae_encoder(sample=masked_image)[0] + + scaling_factor = self.vae_decoder.config.get("scaling_factor", 0.18215) + masked_image_latents = scaling_factor * masked_image_latents + + # duplicate mask and masked_image_latents for each generation per prompt + mask = mask.repeat(batch_size * num_images_per_prompt, 0) + masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 0) + + mask = np.concatenate([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + np.concatenate([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # check that sizes of mask, masked image and latents match + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: expects" + f" {num_channels_unet} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: + raise ValueError( + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {num_channels_unet}." + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * np.float64(self.scheduler.init_noise_sigma) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # Adapted from diffusers to extend it for other runtimes than ORT + timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + # concat latents, mask, masked_image_latnets in the channel dimension + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + if num_channels_unet == 9: + latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1) + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)[ + 0 + ] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + has_nsfw_concept = None + else: + latents = 1 / scaling_factor * latents + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + image, has_nsfw_concept = self.run_safety_checker(image) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py new file mode 100644 index 0000000000..4c8c015fed --- /dev/null +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py @@ -0,0 +1,499 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput + +from .pipeline_utils import DiffusionPipelineMixin, rescale_noise_cfg + + +logger = logging.getLogger(__name__) + + +class StableDiffusionXLPipelineMixin(DiffusionPipelineMixin): + # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, list]], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + pooled_prompt_embeds: Optional[np.ndarray] = None, + negative_pooled_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`Union[str, List[str]]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`Optional[Union[str, list]]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # get prompt text embeddings + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="np").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not np.array_equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + input_ids=text_input_ids.astype(text_encoder.input_dtype.get("input_ids", np.int32)) + ) + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds[-2] + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = np.concatenate(prompt_embeds_list, axis=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config["force_zeros_for_empty_prompt"] + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = np.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = np.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = text_encoder( + input_ids=uncond_input.input_ids.astype(text_encoder.input_dtype.get("input_ids", np.int32)) + ) + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds[-2] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + negative_prompt_embeds_list.append(negative_prompt_embeds) + negative_prompt_embeds = np.concatenate(negative_prompt_embeds, axis=-1) + + pooled_prompt_embeds = np.repeat(pooled_prompt_embeds, num_images_per_prompt, axis=0) + negative_pooled_prompt_embeds = np.repeat(negative_pooled_prompt_embeds, num_images_per_prompt, axis=0) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs + def check_inputs( + self, + prompt: Union[str, List[str]], + height: Optional[int], + width: Optional[int], + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + pooled_prompt_embeds: Optional[np.ndarray] = None, + negative_pooled_prompt_embeds: Optional[np.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = generator.randn(*shape).astype(dtype) + elif latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * np.float64(self.scheduler.init_noise_sigma) + + return latents + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + extra_step_kwargs = {} + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_eta: + extra_step_kwargs["eta"] = eta + + return extra_step_kwargs + + # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__ + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + pooled_prompt_embeds: Optional[np.ndarray] = None, + negative_pooled_prompt_embeds: Optional[np.ndarray] = None, + output_type: str = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`Optional[Union[str, List[str]]]`, defaults to None): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`Optional[int]`, defaults to None): + The height in pixels of the generated image. + width (`Optional[int]`, defaults to None): + The width in pixels of the generated image. + num_inference_steps (`int`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to 5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`Optional[Union[str, list]]`): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. + eta (`float`, defaults to 0.0): + Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`Optional[np.random.RandomState]`, defaults to `None`):: + A np.random.RandomState to make generation deterministic. + latents (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (Optional[Callable], defaults to `None`): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + guidance_rescale (`float`, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `Ο†` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 0. Default height and width to unet + height = height or self.unet.config["sample_size"] * self.vae_scale_factor + width = width or self.unet.config["sample_size"] * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + self.unet.config.get("in_channels", 4), + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = (original_size + crops_coords_top_left + target_size,) + add_time_ids = np.array(add_time_ids, dtype=prompt_embeds.dtype) + + if do_classifier_free_guidance: + prompt_embeds = np.concatenate((negative_prompt_embeds, prompt_embeds), axis=0) + add_text_embeds = np.concatenate((negative_pooled_prompt_embeds, add_text_embeds), axis=0) + add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0) + add_time_ids = np.repeat(add_time_ids, batch_size * num_images_per_prompt, axis=0) + + # Adapted from diffusers to extend it for other runtimes than ORT + timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet( + sample=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_embeds=add_text_embeds, + time_ids=add_time_ids, + ) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + else: + latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215) + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + image = self.watermark.apply_watermark(image) + + # TODO: add image_processor + image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1)) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py new file mode 100644 index 0000000000..4a2b48d38e --- /dev/null +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py @@ -0,0 +1,506 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput + +from .pipeline_utils import DiffusionPipelineMixin, preprocess, rescale_noise_cfg + + +logger = logging.getLogger(__name__) + + +class StableDiffusionXLImg2ImgPipelineMixin(DiffusionPipelineMixin): + # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, list]], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + pooled_prompt_embeds: Optional[np.ndarray] = None, + negative_pooled_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`Union[str, List[str]]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`Optional[Union[str, list]]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # get prompt text embeddings + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="np").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not np.array_equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + input_ids=text_input_ids.astype(text_encoder.input_dtype.get("input_ids", np.int32)) + ) + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds[-2] + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = np.concatenate(prompt_embeds_list, axis=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config["force_zeros_for_empty_prompt"] + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = np.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = np.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + + negative_prompt_embeds = text_encoder( + input_ids=uncond_input.input_ids.astype(text_encoder.input_dtype.get("input_ids", np.int32)) + ) + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds[-2] + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + negative_prompt_embeds_list.append(negative_prompt_embeds) + negative_prompt_embeds = np.concatenate(negative_prompt_embeds_list, axis=-1) + + pooled_prompt_embeds = np.repeat(pooled_prompt_embeds, num_images_per_prompt, axis=0) + negative_pooled_prompt_embeds = np.repeat(negative_pooled_prompt_embeds, num_images_per_prompt, axis=0) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.check_inputs + def check_inputs( + self, + prompt: Union[str, List[str]], + strength: float, + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :].numpy() + + return timesteps, num_inference_steps - t_start + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None): + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + else: + init_latents = self.vae_encoder(sample=image)[0] * self.vae_decoder.config.get("scaling_factor", 0.18215) + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = np.concatenate([init_latents] * additional_image_per_prompt, axis=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = np.concatenate([init_latents], axis=0) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(dtype) + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timestep) + ) + return init_latents.numpy() + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + ): + if self.config.get("requires_aesthetics_score"): + add_time_ids = (original_size + crops_coords_top_left + (aesthetic_score,),) + add_neg_time_ids = (original_size + crops_coords_top_left + (negative_aesthetic_score,),) + else: + add_time_ids = (original_size + crops_coords_top_left + target_size,) + add_neg_time_ids = (original_size + crops_coords_top_left + target_size,) + + add_time_ids = np.array(add_time_ids, dtype=dtype) + add_neg_time_ids = np.array(add_neg_time_ids, dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__ + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Union[np.ndarray, PIL.Image.Image] = None, + strength: float = 0.3, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + pooled_prompt_embeds: Optional[np.ndarray] = None, + negative_pooled_prompt_embeds: Optional[np.ndarray] = None, + output_type: str = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`Optional[Union[str, List[str]]]`, defaults to None): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`Union[np.ndarray, PIL.Image.Image]`): + `Image`, or tensor representing an image batch which will be upscaled. + strength (`float`, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to 5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`Optional[Union[str, list]]`): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. + eta (`float`, defaults to 0.0): + Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`Optional[np.random.RandomState]`, defaults to `None`):: + A np.random.RandomState to make generation deterministic. + latents (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (Optional[Callable], defaults to `None`): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + guidance_rescale (`float`, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `Ο†` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Check inputs. Raise error if not correct + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 1. Define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 2. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 3. Preprocess image + image = preprocess(image) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps) + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + latent_timestep = np.repeat(timesteps[:1], batch_size * num_images_per_prompt, axis=0) + timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) + + latents_dtype = prompt_embeds.dtype + image = image.astype(latents_dtype) + + # 5. Prepare latent variables + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, latents_dtype, generator + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = {} + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_eta: + extra_step_kwargs["eta"] = eta + + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 8. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=prompt_embeds.dtype, + ) + + if do_classifier_free_guidance: + prompt_embeds = np.concatenate((negative_prompt_embeds, prompt_embeds), axis=0) + add_text_embeds = np.concatenate((negative_pooled_prompt_embeds, add_text_embeds), axis=0) + add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0) + add_time_ids = np.repeat(add_time_ids, batch_size * num_images_per_prompt, axis=0) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet( + sample=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_embeds=add_text_embeds, + time_ids=add_time_ids, + ) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + else: + latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215) + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + image = self.watermark.apply_watermark(image) + + # TODO: add image_processor + image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1)) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/optimum/pipelines/diffusers/pipeline_utils.py b/optimum/pipelines/diffusers/pipeline_utils.py index 7092003875..27cc684cb3 100644 --- a/optimum/pipelines/diffusers/pipeline_utils.py +++ b/optimum/pipelines/diffusers/pipeline_utils.py @@ -13,7 +13,13 @@ # limitations under the License. +import warnings + +import numpy as np +import PIL +import torch from diffusers import ConfigMixin +from diffusers.utils import PIL_INTERPOLATION from PIL import Image from tqdm.auto import tqdm @@ -51,3 +57,46 @@ def progress_bar(self, iterable=None, total=None): return tqdm(total=total, **self._progress_bar_config) else: raise ValueError("Either `total` or `iterable` has to be defined.") + + +# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64 +def preprocess(image): + warnings.warn( + ( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead" + ), + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image.cpu().numpy() + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0).cpu().numpy() + return image + + +# Adapted from https://github.com/huggingface/diffusers/blob/v0.18.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L58 +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = np.std(noise_pred_text, axis=tuple(range(1, noise_pred_text.ndim)), keepdims=True) + std_cfg = np.std(noise_cfg, axis=tuple(range(1, noise_cfg.ndim)), keepdims=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg diff --git a/optimum/pipelines/diffusers/watermark.py b/optimum/pipelines/diffusers/watermark.py new file mode 100644 index 0000000000..e07b4829c6 --- /dev/null +++ b/optimum/pipelines/diffusers/watermark.py @@ -0,0 +1,27 @@ +import numpy as np +from imwatermark import WatermarkEncoder + + +WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] + + +# Adapted from https://github.com/huggingface/diffusers/blob/v0.18.1/src/diffusers/pipelines/stable_diffusion_xl/watermark.py#L12 +class StableDiffusionXLWatermarker: + def __init__(self): + self.watermark = WATERMARK_BITS + self.encoder = WatermarkEncoder() + self.encoder.set_watermark("bits", self.watermark) + + def apply_watermark(self, images: np.array): + # can't encode images that are smaller than 256 + if images.shape[-1] < 256: + return images + + images = (255 * (images / 2 + 0.5)).transpose((0, 2, 3, 1)) + + images = np.array([self.encoder.encode(image, "dwtDct") for image in images]).transpose((0, 3, 1, 2)) + + np.clip(2 * (images / 255 - 0.5), -1.0, 1.0, out=images) + + return images diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index 9858b2b9af..df0db3f39a 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -15,10 +15,12 @@ from .constant import ( CONFIG_NAME, + DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, DIFFUSION_MODEL_UNET_SUBFOLDER, DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, + ONNX_WEIGHTS_NAME, ) from .import_utils import ( DIFFUSERS_MINIMUM_VERSION, diff --git a/optimum/utils/constant.py b/optimum/utils/constant.py index b5f64caf07..4497b5246d 100644 --- a/optimum/utils/constant.py +++ b/optimum/utils/constant.py @@ -18,3 +18,5 @@ DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER = "text_encoder" DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER = "vae_decoder" DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER = "vae_encoder" +DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER = "text_encoder_2" +ONNX_WEIGHTS_NAME = "model.onnx" diff --git a/optimum/utils/dummy_diffusers_objects.py b/optimum/utils/dummy_diffusers_objects.py index cd72b1fe10..f85a0987d4 100644 --- a/optimum/utils/dummy_diffusers_objects.py +++ b/optimum/utils/dummy_diffusers_objects.py @@ -24,3 +24,47 @@ def __init__(self, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["diffusers"]) + + +class ORTStableDiffusionImg2ImgPipeline(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + +class ORTStableDiffusionInpaintPipeline(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + +class ORTStableDiffusionXLPipeline(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + +class ORTStableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py index 9d78eccd82..5e6049bd41 100644 --- a/optimum/utils/import_utils.py +++ b/optimum/utils/import_utils.py @@ -34,7 +34,7 @@ TORCH_MINIMUM_VERSION = packaging.version.parse("1.11.0") TRANSFORMERS_MINIMUM_VERSION = packaging.version.parse("4.25.0") -DIFFUSERS_MINIMUM_VERSION = packaging.version.parse("0.17.0") +DIFFUSERS_MINIMUM_VERSION = packaging.version.parse("0.18.0") # This is the minimal required version to support some ONNX Runtime features diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 30c79052e6..d062a29d7e 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -605,7 +605,11 @@ class DummyTimestepInputGenerator(DummyInputGenerator): Generates dummy time step inputs. """ - SUPPORTED_INPUT_NAMES = ("timestep",) + SUPPORTED_INPUT_NAMES = ( + "timestep", + "text_embeds", + "time_ids", + ) def __init__( self, @@ -617,7 +621,8 @@ def __init__( ): self.task = task self.vocab_size = normalized_config.vocab_size - + self.text_encoder_projection_dim = normalized_config.text_encoder_projection_dim + self.time_ids = 5 if normalized_config.requires_aesthetics_score else 6 if random_batch_size_range: low, high = random_batch_size_range self.batch_size = random.randint(low, high) @@ -626,7 +631,12 @@ def __init__( def generate(self, input_name: str, framework: str = "pt"): shape = [self.batch_size] - return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework) + + if input_name == "timestep": + return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework) + + shape.append(self.text_encoder_projection_dim if input_name == "text_embeds" else self.time_ids) + return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework) class DummyLabelsGenerator(DummyInputGenerator): diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index b9838988fe..4454c3348e 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -30,7 +30,7 @@ class NormalizedConfig: """ def __init__(self, config: Union[PretrainedConfig, Dict], allow_new: bool = False, **kwargs): - self.config = config if isinstance(config, PretrainedConfig) else PretrainedConfig.from_dict(config) + self.config = config for key, value in kwargs.items(): if allow_new or hasattr(self, key.upper()): setattr(self, key.upper(), value) @@ -140,6 +140,9 @@ def __getattr__(self, attr_name): MPTNormalizedTextConfig = NormalizedTextConfig.with_args( num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers" ) +GPTBigCodeNormalizedTextConfig = NormalizedTextConfig.with_args( + num_attention_heads="n_head", hidden_size="n_embd", num_layers="n_layer" +) WhisperLikeNormalizedTextConfig = NormalizedTextConfig.with_args( hidden_size="d_model", @@ -242,6 +245,7 @@ class NormalizedConfigManager: "xlm-roberta": NormalizedTextConfig, "yolos": NormalizedVisionConfig, "mpt": MPTNormalizedTextConfig, + "gpt_bigcode": GPTBigCodeNormalizedTextConfig, } @classmethod diff --git a/optimum/utils/save_utils.py b/optimum/utils/save_utils.py index 3d5550a2fd..b461df192b 100644 --- a/optimum/utils/save_utils.py +++ b/optimum/utils/save_utils.py @@ -24,26 +24,41 @@ logger = logging.getLogger(__name__) -def maybe_load_preprocessors(src_name_or_path: Union[str, Path], subfolder: str = "") -> List: +def maybe_load_preprocessors( + src_name_or_path: Union[str, Path], subfolder: str = "", trust_remote_code: bool = False +) -> List: preprocessors = [] try: - preprocessors.append(AutoTokenizer.from_pretrained(src_name_or_path, subfolder=subfolder)) + preprocessors.append( + AutoTokenizer.from_pretrained(src_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code) + ) except Exception: pass try: - preprocessors.append(AutoProcessor.from_pretrained(src_name_or_path, subfolder=subfolder)) + preprocessors.append( + AutoProcessor.from_pretrained(src_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code) + ) except Exception: pass try: - preprocessors.append(AutoFeatureExtractor.from_pretrained(src_name_or_path, subfolder=subfolder)) + preprocessors.append( + AutoFeatureExtractor.from_pretrained( + src_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code + ) + ) except Exception: pass return preprocessors -def maybe_save_preprocessors(src_name_or_path: Union[str, Path], dest_dir: Union[str, Path], src_subfolder: str = ""): +def maybe_save_preprocessors( + src_name_or_path: Union[str, Path], + dest_dir: Union[str, Path], + src_subfolder: str = "", + trust_remote_code: bool = False, +): """ Saves the tokenizer, the processor and the feature extractor when found in `src_dir` in `dest_dir`. @@ -55,10 +70,14 @@ def maybe_save_preprocessors(src_name_or_path: Union[str, Path], dest_dir: Union src_subfolder (`str`, defaults to `""`): In case the preprocessor files are located inside a subfolder of the model directory / repo on the Hugging Face Hub, you can specify the subfolder name here. + trust_remote_code (`bool`, defaults to `False`): + Whether to allow to save preprocessors that is allowed to run arbitrary code. Use this option at your own risk. """ if not isinstance(dest_dir, Path): dest_dir = Path(dest_dir) dest_dir.mkdir(exist_ok=True) - for preprocessor in maybe_load_preprocessors(src_name_or_path, subfolder=src_subfolder): + for preprocessor in maybe_load_preprocessors( + src_name_or_path, subfolder=src_subfolder, trust_remote_code=trust_remote_code + ): preprocessor.save_pretrained(dest_dir) diff --git a/optimum/version.py b/optimum/version.py index bf3cf3c6e7..78887e5a1c 100644 --- a/optimum/version.py +++ b/optimum/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.8.8.dev0" +__version__ = "1.10.1.dev0" diff --git a/setup.py b/setup.py index 5d08ff1d23..90b9ab89f1 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,8 @@ "torchvision", "diffusers>=0.17.0", "torchaudio", + "einops", + "invisible-watermark", ] QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241,<=0.0.259"] @@ -56,7 +58,8 @@ ], "exporters": ["onnx", "onnxruntime", "timm"], "exporters-gpu": ["onnx", "onnxruntime-gpu", "timm"], - "exporters-tf": ["tensorflow>=2.4,<2.11", "tf2onnx", "onnx", "onnxruntime", "timm", "h5py", "numpy<1.24.0"], + "exporters-tf": ["tensorflow>=2.4", "tf2onnx", "onnx", "onnxruntime", "timm", "h5py", "numpy<1.24.0"], + "diffusers": ["diffusers"], "intel": "optimum-intel", "openvino": "optimum-intel[openvino]", "nncf": "optimum-intel[nncf]", @@ -65,6 +68,7 @@ "habana": ["transformers<4.29.0", "optimum-habana"], "neuron": "optimum-neuron[neuron]", "neuronx": "optimum-neuron[neuronx]", + "furiosa": "optimum-furiosa", "dev": TESTS_REQUIRE + QUALITY_REQUIRE, "tests": TESTS_REQUIRE, "quality": QUALITY_REQUIRE, diff --git a/tests/bettertransformer/test_decoder.py b/tests/bettertransformer/test_decoder.py index d446a08385..3fb92ab126 100644 --- a/tests/bettertransformer/test_decoder.py +++ b/tests/bettertransformer/test_decoder.py @@ -49,9 +49,6 @@ def prepare_inputs_for_class(self, model_id: str, model_type: str, batch_size: i texts = ["a dummy input yeah!"] + ["and two"] * (batch_size - 1) inputs = tokenizer(texts, return_tensors="pt", padding=padding, max_length=20, **preprocessor_kwargs) - if model_type == "llama": - del inputs["token_type_ids"] - return inputs @parameterized.expand( @@ -158,9 +155,6 @@ def test_generation(self, test_name: str, model_type: str, batch_size: int, padd text.append("Please continue this my dear me") inp = tokenizer(text, return_tensors="pt", padding=padding, max_length=30) - if model_type == "llama": - del inp["token_type_ids"] - length = 50 result_vanilla = model.generate(**inp, num_beams=1, min_length=length, max_length=length) diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index c28613c793..423875ca28 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -237,5 +237,6 @@ } PYTORCH_STABLE_DIFFUSION_MODEL = { - ("hf-internal-testing/tiny-stable-diffusion-torch"), + "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", + "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", } diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index 39342cb4d5..a92a5d1881 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -27,12 +27,13 @@ from optimum.exporters.error_utils import MinimumVersionError from optimum.exporters.onnx.__main__ import main_export from optimum.onnxruntime import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME +from optimum.utils.testing_utils import require_diffusers if is_torch_available(): from optimum.exporters.tasks import TasksManager -from ..exporters_utils import PYTORCH_EXPORT_MODELS_TINY +from ..exporters_utils import PYTORCH_EXPORT_MODELS_TINY, PYTORCH_STABLE_DIFFUSION_MODEL def _get_models_to_test(export_models_dict: Dict): @@ -134,6 +135,31 @@ def test_all_models_tested(self): if len(missing_models_set) > 0: self.fail(f"Not testing all models. Missing models: {missing_models_set}") + @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) + @require_torch + @require_vision + @require_diffusers + def test_exporters_cli_pytorch_cpu_stable_diffusion(self, model_type: str, model_name: str): + self._onnx_export(model_name, model_type) + + @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) + @require_torch_gpu + @require_vision + @require_diffusers + @slow + @pytest.mark.run_slow + def test_exporters_cli_pytorch_gpu_stable_diffusion(self, model_type: str, model_name: str): + self._onnx_export(model_name, model_type, device="cuda") + + @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) + @require_torch_gpu + @require_vision + @require_diffusers + @slow + @pytest.mark.run_slow + def test_exporters_cli_fp16_stable_diffusion(self, model_type: str, model_name: str): + self._onnx_export(model_name, model_type, device="cuda", fp16=True) + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) @require_torch @require_vision diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index d8a521770d..9a96d13e47 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import gc +import os from pathlib import Path from tempfile import TemporaryDirectory from typing import Dict from unittest import TestCase from unittest.mock import patch +import onnx import pytest from parameterized import parameterized -from transformers import AutoConfig, is_tf_available, is_torch_available, set_seed +from transformers import AutoConfig, is_tf_available, is_torch_available from transformers.testing_utils import require_onnx, require_tf, require_torch, require_torch_gpu, require_vision, slow from optimum.exporters.error_utils import AtolError @@ -34,7 +36,11 @@ get_stable_diffusion_models_for_export, validate_models_outputs, ) -from optimum.utils import is_diffusers_available +from optimum.exporters.onnx.__main__ import main_export +from optimum.exporters.onnx.base import ConfigBehavior +from optimum.exporters.onnx.config import TextDecoderOnnxConfig +from optimum.exporters.onnx.model_configs import WhisperOnnxConfig +from optimum.utils import ONNX_WEIGHTS_NAME, DummyPastKeyValuesGenerator, NormalizedTextConfig from optimum.utils.testing_utils import grid_parameters, require_diffusers from ..exporters_utils import ( @@ -48,9 +54,6 @@ if is_torch_available() or is_tf_available(): from optimum.exporters.tasks import TasksManager -if is_diffusers_available(): - from diffusers import StableDiffusionPipeline - SEED = 42 @@ -308,6 +311,30 @@ def _onnx_export( gc.collect() + def _onnx_export_sd(self, model_type: str, model_name: str, device="cpu"): + pipeline = TasksManager.get_model_from_task(model_type, model_name, device=device) + models_and_onnx_configs = get_stable_diffusion_models_for_export(pipeline) + output_names = [os.path.join(name_dir, ONNX_WEIGHTS_NAME) for name_dir in models_and_onnx_configs] + model, _ = models_and_onnx_configs["vae_encoder"] + model.forward = lambda sample: {"latent_sample": model.encode(x=sample)["latent_dist"].parameters} + + with TemporaryDirectory() as tmpdirname: + _, onnx_outputs = export_models( + models_and_onnx_configs=models_and_onnx_configs, + opset=14, + output_dir=Path(tmpdirname), + output_names=output_names, + device=device, + ) + validate_models_outputs( + models_and_onnx_configs=models_and_onnx_configs, + onnx_named_outputs=onnx_outputs, + output_dir=Path(tmpdirname), + atol=1e-3, + onnx_files_subpaths=output_names, + use_subprocess=False, + ) + def test_all_models_tested(self): # make sure we test all models missing_models_set = TasksManager._SUPPORTED_CLI_MODEL_TYPE - set(PYTORCH_EXPORT_MODELS_TINY.keys()) @@ -377,37 +404,187 @@ def test_tensorflow_export(self, test_name, name, model_name, task, onnx_config_ self._onnx_export(test_name, name, model_name, task, onnx_config_class_constructor, monolith=monolith) - @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL) + @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) + @require_torch + @require_vision + @require_diffusers + def test_pytorch_export_for_stable_diffusion_models(self, model_type, model_name): + self._onnx_export_sd(model_type, model_name) + + @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) @require_torch @require_vision @require_diffusers - def test_pytorch_export_for_stable_diffusion_models(self, model_name): - set_seed(SEED) - - pipeline = StableDiffusionPipeline.from_pretrained(model_name) - output_names = [ - "text_encoder/model.onnx", - "unet/model.onnx", - "vae_encoder/model.onnx", - "vae_decoder/model.onnx", + @require_torch_gpu + @slow + @pytest.mark.run_slow + @pytest.mark.gpu_test + def test_pytorch_export_for_stable_diffusion_models_cuda(self, model_type, model_name): + self._onnx_export_sd(model_type, model_name, device="cuda") + + +class CustomWhisperOnnxConfig(WhisperOnnxConfig): + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + common_outputs = super().outputs + + if self._behavior is ConfigBehavior.ENCODER: + for i in range(self._config.encoder_layers): + common_outputs[f"encoder_attentions.{i}"] = {0: "batch_size"} + elif self._behavior is ConfigBehavior.DECODER: + for i in range(self._config.decoder_layers): + common_outputs[f"decoder_attentions.{i}"] = {0: "batch_size", 3: "decoder_sequence_length"} + for i in range(self._config.decoder_layers): + common_outputs[f"cross_attentions.{i}"] = {0: "batch_size", 3: "cross_attention_length"} + + return common_outputs + + @property + def torch_to_onnx_output_map(self): + if self._behavior is ConfigBehavior.ENCODER: + # The encoder export uses WhisperEncoder that returns the key "attentions" + return {"attentions": "encoder_attentions"} + else: + return {} + + +class MPTDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + """ + MPT swaps the two last dimensions for the key cache compared to usual transformers + decoder models, thus the redefinition here. + """ + + def generate(self, input_name: str, framework: str = "pt"): + past_key_shape = ( + self.batch_size, + self.num_attention_heads, + self.hidden_size // self.num_attention_heads, + self.sequence_length, + ) + past_value_shape = ( + self.batch_size, + self.num_attention_heads, + self.sequence_length, + self.hidden_size // self.num_attention_heads, + ) + return [ + ( + self.random_float_tensor(past_key_shape, framework=framework), + self.random_float_tensor(past_value_shape, framework=framework), + ) + for _ in range(self.num_layers) ] - models_and_onnx_configs = get_stable_diffusion_models_for_export(pipeline) - model, _ = models_and_onnx_configs["vae_encoder"] - model.forward = lambda sample: {"latent_sample": model.encode(x=sample)["latent_dist"].parameters} + + +class CustomMPTOnnxConfig(TextDecoderOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = ( + MPTDummyPastKeyValuesGenerator, + ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES + DUMMY_PKV_GENERATOR_CLASS = MPTDummyPastKeyValuesGenerator + + DEFAULT_ONNX_OPSET = 14 # aten::tril operator requires opset>=14 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( + hidden_size="d_model", num_layers="n_layers", num_attention_heads="n_heads" + ) + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + """ + Adapted from https://github.com/huggingface/optimum/blob/v1.9.0/optimum/exporters/onnx/base.py#L625 + """ + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_sequence_length + 1" + name = "present" + + for i in range(self._normalized_config.num_layers): + inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 3: decoder_sequence_name} + inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name} + + +def fn_get_submodels_custom(model): + return {"decoder_model": model, "decoder_with_past_model": model} + + +class OnnxCustomExport(TestCase): + def test_custom_export_official_model(self): + model_id = "openai/whisper-tiny.en" + config = AutoConfig.from_pretrained(model_id) + + custom_whisper_onnx_config = CustomWhisperOnnxConfig( + config=config, + task="automatic-speech-recognition", + ) + + encoder_config = custom_whisper_onnx_config.with_behavior("encoder") + decoder_config = custom_whisper_onnx_config.with_behavior("decoder", use_past=False) + decoder_with_past_config = custom_whisper_onnx_config.with_behavior("decoder", use_past=True) + + custom_onnx_configs = { + "encoder_model": encoder_config, + "decoder_model": decoder_config, + "decoder_with_past_model": decoder_with_past_config, + } with TemporaryDirectory() as tmpdirname: - _, onnx_outputs = export_models( - models_and_onnx_configs=models_and_onnx_configs, - opset=14, - output_dir=Path(tmpdirname), - output_names=output_names, - device="cpu", # TODO: Add GPU test + main_export( + model_id, + output=tmpdirname, + no_post_process=True, + model_kwargs={"output_attentions": True}, + custom_onnx_configs=custom_onnx_configs, ) - validate_models_outputs( - models_and_onnx_configs=models_and_onnx_configs, - onnx_named_outputs=onnx_outputs, - output_dir=Path(tmpdirname), - atol=1e-3, - onnx_files_subpaths=output_names, - use_subprocess=False, + + model = onnx.load(os.path.join(tmpdirname, "decoder_model.onnx")) + + output_names = [outp.name for outp in model.graph.output] + assert "decoder_attentions.0" in output_names + assert "cross_attentions.0" in output_names + + @parameterized.expand([(None,), (fn_get_submodels_custom,)]) + def test_custom_export_trust_remote(self, fn_get_submodels): + model_id = "fxmarty/tiny-mpt-random-remote-code" + config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) + + onnx_config = CustomMPTOnnxConfig( + config=config, + task="text-generation", + use_past_in_inputs=False, + use_present_in_outputs=True, + ) + onnx_config_with_past = CustomMPTOnnxConfig(config, task="text-generation", use_past=True) + + custom_onnx_configs = { + "decoder_model": onnx_config, + "decoder_with_past_model": onnx_config_with_past, + } + + with TemporaryDirectory() as tmpdirname: + main_export( + model_id, + output=tmpdirname, + task="text-generation-with-past", + trust_remote_code=True, + custom_onnx_configs=custom_onnx_configs, + no_post_process=True, + fn_get_submodels=fn_get_submodels, ) + + def test_custom_export_trust_remote_error(self): + model_id = "fxmarty/tiny-mpt-random-remote-code" + + with self.assertRaises(ValueError) as context: + with TemporaryDirectory() as tmpdirname: + main_export( + model_id, + output=tmpdirname, + task="text-generation-with-past", + trust_remote_code=True, + no_post_process=True, + ) + + self.assertIn("custom or unsupported architecture", str(context.exception)) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 0594f6b0f7..6ffbbb7732 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -59,7 +59,7 @@ from transformers.modeling_utils import no_init_weights from transformers.onnx.utils import get_preprocessor from transformers.testing_utils import get_gpu_count, require_torch_gpu -from utils_onnxruntime_tests import MODEL_NAMES, SEED +from utils_onnxruntime_tests import MODEL_NAMES, SEED, ORTModelTestMixin from optimum.exporters import TasksManager from optimum.exporters.onnx import main_export @@ -90,7 +90,12 @@ ORTStableDiffusionPipeline, ) from optimum.onnxruntime.base import ORTDecoder, ORTDecoderForSeq2Seq, ORTEncoder -from optimum.onnxruntime.modeling_diffusion import ORTModelTextEncoder, ORTModelUnet, ORTModelVaeDecoder +from optimum.onnxruntime.modeling_diffusion import ( + ORTModelTextEncoder, + ORTModelUnet, + ORTModelVaeDecoder, + ORTModelVaeEncoder, +) from optimum.onnxruntime.modeling_ort import ORTModel from optimum.pipelines import pipeline from optimum.utils import ( @@ -98,9 +103,10 @@ DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, DIFFUSION_MODEL_UNET_SUBFOLDER, DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, + DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, logging, ) -from optimum.utils.testing_utils import grid_parameters, require_diffusers, require_hf_token +from optimum.utils.testing_utils import grid_parameters, require_hf_token logger = logging.get_logger() @@ -115,57 +121,6 @@ def __exit__(self, type, value, traceback): self.elapsed = (time.perf_counter() - self.elapsed) * 1e3 -class ORTModelTestMixin(unittest.TestCase): - ARCH_MODEL_MAP = {} - - TENSOR_ALIAS_TO_TYPE = { - "pt": torch.Tensor, - "np": np.ndarray, - } - - @classmethod - def setUpClass(cls): - cls.onnx_model_dirs = {} - - def _setup(self, model_args: Dict): - """ - Exports the PyTorch models to ONNX ahead of time to avoid multiple exports during the tests. - We don't use unittest setUpClass, in order to still be able to run individual tests. - """ - model_arch = model_args["model_arch"] - model_arch_and_params = model_args["test_name"] - - # TODO: this should actually be checked in ORTModel! - task = self.TASK - if "use_cache" in model_args and model_args["use_cache"] is True: - task = task + "-with-past" - - if "use_cache" in model_args and task not in TasksManager.get_supported_tasks_for_model_type( - model_arch.replace("_", "-"), exporter="onnx" - ): - self.skipTest("Unsupported export case") - - if model_arch_and_params not in self.onnx_model_dirs: - # model_args will contain kwargs to pass to ORTModel.from_pretrained() - model_args.pop("test_name") - model_args.pop("model_arch") - - model_id = ( - self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] - ) - set_seed(SEED) - onnx_model = self.ORTMODEL_CLASS.from_pretrained(model_id, **model_args, use_io_binding=False, export=True) - - model_dir = tempfile.mkdtemp(prefix=f"{model_arch_and_params}_{self.TASK}_") - onnx_model.save_pretrained(model_dir) - self.onnx_model_dirs[model_arch_and_params] = model_dir - - @classmethod - def tearDownClass(cls): - for _, dir_path in cls.onnx_model_dirs.items(): - shutil.rmtree(dir_path) - - class ORTModelIntegrationTest(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -259,6 +214,7 @@ def test_load_stable_diffusion_model_from_cache(self): self.assertIsInstance(model.text_encoder, ORTModelTextEncoder) self.assertIsInstance(model.vae_decoder, ORTModelVaeDecoder) + self.assertIsInstance(model.vae_encoder, ORTModelVaeEncoder) self.assertIsInstance(model.unet, ORTModelUnet) self.assertIsInstance(model.config, Dict) @@ -330,6 +286,7 @@ def test_load_stable_diffusion_model_from_hub(self): model = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) self.assertIsInstance(model.text_encoder, ORTModelTextEncoder) self.assertIsInstance(model.vae_decoder, ORTModelVaeDecoder) + self.assertIsInstance(model.vae_encoder, ORTModelVaeEncoder) self.assertIsInstance(model.unet, ORTModelUnet) self.assertIsInstance(model.config, Dict) @@ -343,6 +300,7 @@ def test_load_stable_diffusion_model_cuda_provider(self): self.assertListEqual(model.unet.session.get_providers(), model.providers) self.assertListEqual(model.text_encoder.session.get_providers(), model.providers) self.assertListEqual(model.vae_decoder.session.get_providers(), model.providers) + self.assertListEqual(model.vae_encoder.session.get_providers(), model.providers) self.assertEqual(model.device, torch.device("cuda:0")) def test_load_stable_diffusion_model_cpu_provider(self): @@ -353,6 +311,7 @@ def test_load_stable_diffusion_model_cpu_provider(self): self.assertListEqual(model.unet.session.get_providers(), model.providers) self.assertListEqual(model.text_encoder.session.get_providers(), model.providers) self.assertListEqual(model.vae_decoder.session.get_providers(), model.providers) + self.assertListEqual(model.vae_encoder.session.get_providers(), model.providers) self.assertEqual(model.device, torch.device("cpu")) def test_load_stable_diffusion_model_unknown_provider(self): @@ -480,6 +439,7 @@ def test_passing_session_options_stable_diffusion(self): self.assertEqual(model.unet.session.get_session_options().intra_op_num_threads, 3) self.assertEqual(model.text_encoder.session.get_session_options().intra_op_num_threads, 3) self.assertEqual(model.vae_decoder.session.get_session_options().intra_op_num_threads, 3) + self.assertEqual(model.vae_encoder.session.get_session_options().intra_op_num_threads, 3) @require_torch_gpu @pytest.mark.gpu_test @@ -696,7 +656,9 @@ def test_passing_provider_options_stable_diffusion(self): self.assertEqual( model.vae_decoder.session.get_provider_options()["CUDAExecutionProvider"]["do_copy_in_default_stream"], "1" ) - + self.assertEqual( + model.vae_encoder.session.get_provider_options()["CUDAExecutionProvider"]["do_copy_in_default_stream"], "1" + ) model = ORTStableDiffusionPipeline.from_pretrained( self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID, provider="CUDAExecutionProvider", @@ -712,6 +674,9 @@ def test_passing_provider_options_stable_diffusion(self): self.assertEqual( model.vae_decoder.session.get_provider_options()["CUDAExecutionProvider"]["do_copy_in_default_stream"], "0" ) + self.assertEqual( + model.vae_encoder.session.get_provider_options()["CUDAExecutionProvider"]["do_copy_in_default_stream"], "0" + ) def test_stable_diffusion_model_on_cpu(self): model = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) @@ -721,9 +686,11 @@ def test_stable_diffusion_model_on_cpu(self): self.assertEqual(model.unet.device, cpu) self.assertEqual(model.text_encoder.device, cpu) self.assertEqual(model.vae_decoder.device, cpu) + self.assertEqual(model.vae_encoder.device, cpu) self.assertEqual(model.unet.session.get_providers()[0], "CPUExecutionProvider") self.assertEqual(model.text_encoder.session.get_providers()[0], "CPUExecutionProvider") self.assertEqual(model.vae_decoder.session.get_providers()[0], "CPUExecutionProvider") + self.assertEqual(model.vae_encoder.session.get_providers()[0], "CPUExecutionProvider") self.assertListEqual(model.providers, ["CPUExecutionProvider"]) # test string device input for to() @@ -735,9 +702,11 @@ def test_stable_diffusion_model_on_cpu_str(self): self.assertEqual(model.unet.device, cpu) self.assertEqual(model.text_encoder.device, cpu) self.assertEqual(model.vae_decoder.device, cpu) + self.assertEqual(model.vae_encoder.device, cpu) self.assertEqual(model.unet.session.get_providers()[0], "CPUExecutionProvider") self.assertEqual(model.text_encoder.session.get_providers()[0], "CPUExecutionProvider") self.assertEqual(model.vae_decoder.session.get_providers()[0], "CPUExecutionProvider") + self.assertEqual(model.vae_encoder.session.get_providers()[0], "CPUExecutionProvider") self.assertListEqual(model.providers, ["CPUExecutionProvider"]) @require_torch_gpu @@ -750,9 +719,11 @@ def test_stable_diffusion_model_on_gpu(self): self.assertEqual(model.unet.device, torch.device("cuda:0")) self.assertEqual(model.text_encoder.device, torch.device("cuda:0")) self.assertEqual(model.vae_decoder.device, torch.device("cuda:0")) + self.assertEqual(model.vae_encoder.device, torch.device("cuda:0")) self.assertEqual(model.unet.session.get_providers()[0], "CUDAExecutionProvider") self.assertEqual(model.text_encoder.session.get_providers()[0], "CUDAExecutionProvider") self.assertEqual(model.vae_decoder.session.get_providers()[0], "CUDAExecutionProvider") + self.assertEqual(model.vae_encoder.session.get_providers()[0], "CUDAExecutionProvider") self.assertListEqual(model.providers, ["CUDAExecutionProvider", "CPUExecutionProvider"]) @unittest.skipIf(get_gpu_count() <= 1, "this test requires multi-gpu") @@ -762,18 +733,21 @@ def test_stable_diffusion_model_on_gpu_id(self): self.assertEqual(model.unet.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") self.assertEqual(model.text_encoder.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") self.assertEqual(model.vae_decoder.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") + self.assertEqual(model.vae_encoder.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") model = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) model.to(1) self.assertEqual(model.unet.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") self.assertEqual(model.text_encoder.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") self.assertEqual(model.vae_decoder.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") + self.assertEqual(model.vae_encoder.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") model = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) model.to("cuda:1") self.assertEqual(model.unet.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") self.assertEqual(model.text_encoder.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") self.assertEqual(model.vae_decoder.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") + self.assertEqual(model.vae_encoder.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") # test string device input for to() @require_torch_gpu @@ -785,9 +759,11 @@ def test_stable_diffusion_model_on_gpu_str(self): self.assertEqual(model.unet.device, torch.device("cuda:0")) self.assertEqual(model.text_encoder.device, torch.device("cuda:0")) self.assertEqual(model.vae_decoder.device, torch.device("cuda:0")) + self.assertEqual(model.vae_encoder.device, torch.device("cuda:0")) self.assertEqual(model.unet.session.get_providers()[0], "CUDAExecutionProvider") self.assertEqual(model.text_encoder.session.get_providers()[0], "CUDAExecutionProvider") self.assertEqual(model.vae_decoder.session.get_providers()[0], "CUDAExecutionProvider") + self.assertEqual(model.vae_encoder.session.get_providers()[0], "CUDAExecutionProvider") self.assertListEqual(model.providers, ["CUDAExecutionProvider", "CPUExecutionProvider"]) @require_hf_token @@ -837,6 +813,7 @@ def test_save_stable_diffusion_model(self): DIFFUSION_MODEL_UNET_SUBFOLDER, DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, + DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, }: folder_contents = os.listdir(os.path.join(tmpdirname, subfoler)) self.assertIn(ONNX_WEIGHTS_NAME, folder_contents) @@ -916,6 +893,7 @@ def test_save_load_stable_diffusion_model_with_external_data(self): DIFFUSION_MODEL_UNET_SUBFOLDER, DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, + DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, }: folder_contents = os.listdir(os.path.join(tmpdirname, subfoler)) self.assertIn(ONNX_WEIGHTS_NAME, folder_contents) @@ -1111,7 +1089,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForQuestionAnswering.from_pretrained(MODEL_NAMES["t5"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("custom or unsupported architecture", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): @@ -4045,28 +4023,56 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach feature_extractor, tokenizer = self._get_preprocessors(model_id) data = self._get_sample_image() - features = feature_extractor(data, return_tensors="pt") start_token = "" decoder_start_token_id = tokenizer.encode(start_token)[0] - decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id} - with torch.no_grad(): - transformers_outputs = transformers_model(**features, **decoder_inputs) + extra_inputs = [{}, {}] - for input_type in ["pt", "np"]: - features = feature_extractor(data, return_tensors=input_type) + if use_cache and False: + # TODO: the dims will fail with other models + fake_pkv = tuple((torch.rand(1, 4, 1, 8), torch.rand(1, 4, 1, 8)) for _ in range(5)) + extra_inputs[1]["past_key_values"] = fake_pkv - if input_type == "np": - decoder_inputs = {"decoder_input_ids": np.ones((1, 1), dtype=np.int64) * decoder_start_token_id} + for extra_inps in extra_inputs: + features = feature_extractor(data, return_tensors="pt") + decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id} - onnx_outputs = onnx_model(**features, **decoder_inputs) + with torch.no_grad(): + transformers_outputs = transformers_model(**features, **decoder_inputs, **extra_inps) + for input_type in ["pt", "np"]: + features = feature_extractor(data, return_tensors=input_type) - self.assertTrue("logits" in onnx_outputs) - self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type]) + if input_type == "np": + decoder_inputs = {"decoder_input_ids": np.ones((1, 1), dtype=np.int64) * decoder_start_token_id} - # Compare tensor outputs - self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-3)) + if "past_key_values" in extra_inps: + del extra_inps["past_key_values"] # test only with pytorch + + onnx_outputs = onnx_model(**features, **decoder_inputs, **extra_inps) + + self.assertTrue("logits" in onnx_outputs) + self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type]) + + if use_cache: + self.assertEqual( + len(onnx_outputs["past_key_values"]), len(transformers_outputs["past_key_values"]) + ) + self.assertEqual( + len(onnx_outputs["past_key_values"][0]), len(transformers_outputs["past_key_values"][0]) + ) + for i, _ in enumerate(onnx_outputs["past_key_values"]): + for j, ort_pkv in enumerate(onnx_outputs["past_key_values"][i]): + trfs_pkv = transformers_outputs["past_key_values"][i][j] + self.assertTrue( + torch.allclose(ort_pkv, trfs_pkv, atol=1e-3), + f" Maxdiff: {torch.abs(ort_pkv - trfs_pkv).max()}", + ) + + # Compare tensor outputs + self.assertTrue( + torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-3) + ) gc.collect() @@ -4285,120 +4291,3 @@ def test_find_untested_architectures(self, task: str, test_class): f"For the task `{task}`, the ONNX export supports {supported_export_models}, but only {tested_architectures} are tested.\n" f" Missing {untested_architectures}." ) - - -class ORTStableDiffusionPipelineIntegrationTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = [ - "stable-diffusion", - ] - ORTMODEL_CLASS = ORTStableDiffusionPipeline - TASK = "stable-diffusion" - - @require_diffusers - def test_load_vanilla_model_which_is_not_supported(self): - with self.assertRaises(Exception) as context: - _ = ORTStableDiffusionPipeline.from_pretrained(MODEL_NAMES["bert"], export=True) - - self.assertIn( - f"does not appear to have a file named {ORTStableDiffusionPipeline.config_name}", str(context.exception) - ) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_compare_to_diffusers(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - self._setup(model_args) - model_id = MODEL_NAMES[model_arch] - ort_pipeline = ORTStableDiffusionPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) - - self.assertIsInstance(ort_pipeline.text_encoder, ORTModelTextEncoder) - self.assertIsInstance(ort_pipeline.vae_decoder, ORTModelVaeDecoder) - self.assertIsInstance(ort_pipeline.unet, ORTModelUnet) - self.assertIsInstance(ort_pipeline.config, Dict) - - from diffusers import StableDiffusionPipeline - - diffusers_pipeline = StableDiffusionPipeline.from_pretrained(model_id) - diffusers_pipeline.safety_checker = None - num_images_per_prompt, height, width, scale_factor = 1, 512, 512, 8 - latents_shape = ( - num_images_per_prompt, - diffusers_pipeline.unet.in_channels, - height // scale_factor, - width // scale_factor, - ) - latents = np.random.randn(*latents_shape).astype(np.float32) - kwargs = { - "prompt": "sailing ship in storm by Leonardo da Vinci", - "num_inference_steps": 1, - "output_type": "np", - "num_images_per_prompt": num_images_per_prompt, - "height": height, - "width": width, - } - ort_outputs = ort_pipeline(latents=latents, **kwargs).images - self.assertIsInstance(ort_outputs, np.ndarray) - - with torch.no_grad(): - diffusers_outputs = diffusers_pipeline(latents=torch.from_numpy(latents), **kwargs).images - # Compare model outputs - self.assertTrue(np.allclose(ort_outputs, diffusers_outputs, atol=1e-4)) - # Compare model devices - self.assertEqual(diffusers_pipeline.device, ort_pipeline.device) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_num_images_per_prompt(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - self._setup(model_args) - num_images_per_prompt = 4 - batch_size = 6 - - pipeline = ORTStableDiffusionPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) - prompt = "sailing ship in storm by Leonardo da Vinci" - outputs = pipeline(prompt, num_inference_steps=2, output_type="np").images - self.assertEqual(outputs.shape, (1, 128, 128, 3)) - outputs = pipeline( - prompt, num_inference_steps=2, num_images_per_prompt=num_images_per_prompt, output_type="np" - ).images - self.assertEqual(outputs.shape, (num_images_per_prompt, 128, 128, 3)) - outputs = pipeline([prompt] * batch_size, num_inference_steps=2, output_type="np").images - self.assertEqual(outputs.shape, (batch_size, 128, 128, 3)) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_image_reproducibility(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - self._setup(model_args) - ort_pipeline = ORTStableDiffusionPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) - kwargs = { - "prompt": "sailing ship in storm by Leonardo da Vinci", - "output_type": "np", - "num_inference_steps": 2, - } - np.random.seed(0) - ort_outputs_1 = ort_pipeline(**kwargs) - np.random.seed(0) - ort_outputs_2 = ort_pipeline(**kwargs) - ort_outputs_3 = ort_pipeline(**kwargs) - - # Compare model outputs - self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) - self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) - - @parameterized.expand( - grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["CUDAExecutionProvider"]}) - ) - @require_torch_gpu - @pytest.mark.gpu_test - @require_diffusers - def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): - model_args = {"test_name": test_name, "model_arch": model_arch} - self._setup(model_args) - pipe = ORTStableDiffusionPipeline.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) - outputs = pipe("sailing ship in storm by Leonardo da Vinci", output_type="np").images - # Verify model devices - self.assertEqual(pipe.device.type.lower(), "cuda") - # Verify model outptus - self.assertIsInstance(outputs, np.ndarray) - self.assertEqual(outputs.shape, (1, 128, 128, 3)) diff --git a/tests/onnxruntime/test_stable_diffusion_pipeline.py b/tests/onnxruntime/test_stable_diffusion_pipeline.py new file mode 100644 index 0000000000..e7b3bc5ec6 --- /dev/null +++ b/tests/onnxruntime/test_stable_diffusion_pipeline.py @@ -0,0 +1,366 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +import unittest +from typing import Dict + +import numpy as np +import pytest +import torch +from diffusers import ( + OnnxStableDiffusionImg2ImgPipeline, + StableDiffusionPipeline, + StableDiffusionXLPipeline, +) +from diffusers.utils import floats_tensor, load_image +from parameterized import parameterized +from transformers.testing_utils import require_torch_gpu +from utils_onnxruntime_tests import MODEL_NAMES, SEED, ORTModelTestMixin + +from optimum.onnxruntime import ORTStableDiffusionPipeline +from optimum.onnxruntime.modeling_diffusion import ( + ORTModelTextEncoder, + ORTModelUnet, + ORTModelVaeDecoder, + ORTModelVaeEncoder, + ORTStableDiffusionImg2ImgPipeline, + ORTStableDiffusionInpaintPipeline, + ORTStableDiffusionXLImg2ImgPipeline, + ORTStableDiffusionXLPipeline, +) +from optimum.utils import logging +from optimum.utils.testing_utils import grid_parameters, require_diffusers + + +logger = logging.get_logger() + + +def _generate_inputs(): + inputs = { + "prompt": "sailing ship in storm by Leonardo da Vinci", + "num_inference_steps": 3, + "guidance_scale": 7.5, + "output_type": "np", + } + return inputs + + +class ORTStableDiffusionPipelineBase(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = [ + "stable-diffusion", + ] + ORTMODEL_CLASS = ORTStableDiffusionPipeline + TASK = "stable-diffusion" + + @require_diffusers + def test_load_vanilla_model_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"], export=True) + + self.assertIn( + f"does not appear to have a file named {self.ORTMODEL_CLASS.config_name}", str(context.exception) + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_num_images_per_prompt(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + num_images_per_prompt = 4 + batch_size = 6 + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + self.assertEqual(pipeline.vae_scale_factor, 2) + self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4) + self.assertEqual(pipeline.unet.config["in_channels"], 4) + inputs = self.generate_inputs() + outputs = pipeline(**inputs).images + self.assertEqual(outputs.shape, (1, 128, 128, 3)) + outputs = pipeline(**inputs, num_images_per_prompt=num_images_per_prompt).images + self.assertEqual(outputs.shape, (num_images_per_prompt, 128, 128, 3)) + outputs = pipeline([inputs.pop("prompt")] * batch_size, **inputs).images + self.assertEqual(outputs.shape, (batch_size, 128, 128, 3)) + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["CUDAExecutionProvider"]}) + ) + @require_torch_gpu + @pytest.mark.gpu_test + @require_diffusers + def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + inputs = self.generate_inputs() + outputs = pipeline(**inputs).images + # Verify model devices + self.assertEqual(pipeline.device.type.lower(), "cuda") + # Verify model outptus + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (1, 128, 128, 3)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_callback(self, model_arch: str): + def callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: + callback_fn.has_been_called = True + callback_fn.number_of_steps += 1 + + pipe = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) + callback_fn.has_been_called = False + callback_fn.number_of_steps = 0 + inputs = self.generate_inputs(height=64, width=64) + pipe(**inputs, callback=callback_fn, callback_steps=1) + self.assertTrue(callback_fn.has_been_called) + self.assertEqual(callback_fn.number_of_steps, inputs["num_inference_steps"]) + + def generate_inputs(self, height=128, width=128): + inputs = _generate_inputs() + inputs["height"] = height + inputs["width"] = width + return inputs + + +class ORTStableDiffusionImg2ImgPipelineTest(ORTStableDiffusionPipelineBase): + SUPPORTED_ARCHITECTURES = [ + "stable-diffusion", + ] + ORTMODEL_CLASS = ORTStableDiffusionImg2ImgPipeline + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_compare_diffusers_pipeline(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + inputs = self.generate_inputs() + inputs["prompt"] = "A painting of a squirrel eating a burger" + + output = pipeline(**inputs, generator=np.random.RandomState(0)).images[0, -3:, -3:, -1] + # https://github.com/huggingface/diffusers/blob/v0.17.1/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py#L71 + expected_slice = np.array([0.69643, 0.58484, 0.50314, 0.58760, 0.55368, 0.59643, 0.51529, 0.41217, 0.49087]) + self.assertTrue(np.allclose(output.flatten(), expected_slice, atol=1e-1)) + + # Verify it can be loaded with ORT diffusers pipeline + diffusers_pipeline = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) + diffusers_output = diffusers_pipeline(**inputs, generator=np.random.RandomState(0)).images[0, -3:, -3:, -1] + self.assertTrue(np.allclose(output, diffusers_output, atol=1e-4)) + + def generate_inputs(self, height=128, width=128): + inputs = _generate_inputs() + inputs["image"] = floats_tensor((1, 3, height, width), rng=random.Random(SEED)) + inputs["strength"] = 0.75 + return inputs + + +class ORTStableDiffusionPipelineTest(unittest.TestCase): + SUPPORTED_ARCHITECTURES = [ + "stable-diffusion", + ] + ORTMODEL_CLASS = ORTStableDiffusionPipeline + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_compare_to_diffusers(self, model_arch: str): + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) + self.assertIsInstance(ort_pipeline.text_encoder, ORTModelTextEncoder) + self.assertIsInstance(ort_pipeline.vae_decoder, ORTModelVaeDecoder) + self.assertIsInstance(ort_pipeline.vae_encoder, ORTModelVaeEncoder) + self.assertIsInstance(ort_pipeline.unet, ORTModelUnet) + self.assertIsInstance(ort_pipeline.config, Dict) + + pipeline = StableDiffusionPipeline.from_pretrained(MODEL_NAMES[model_arch]) + pipeline.safety_checker = None + batch_size, num_images_per_prompt, height, width = 1, 2, 64, 64 + + latents = ort_pipeline.prepare_latents( + batch_size * num_images_per_prompt, + ort_pipeline.unet.config["in_channels"], + height, + width, + dtype=np.float32, + generator=np.random.RandomState(0), + ) + + kwargs = { + "prompt": "sailing ship in storm by Leonardo da Vinci", + "num_inference_steps": 1, + "num_images_per_prompt": num_images_per_prompt, + "height": height, + "width": width, + "guidance_rescale": 0.1, + } + + for output_type in ["latent", "np"]: + ort_outputs = ort_pipeline(latents=latents, output_type=output_type, **kwargs).images + self.assertIsInstance(ort_outputs, np.ndarray) + with torch.no_grad(): + outputs = pipeline(latents=torch.from_numpy(latents), output_type=output_type, **kwargs).images + # Compare model outputs + self.assertTrue(np.allclose(ort_outputs, outputs, atol=1e-4)) + # Compare model devices + self.assertEqual(pipeline.device, ort_pipeline.device) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_image_reproducibility(self, model_arch: str): + pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) + inputs = _generate_inputs() + height = 64 + width = 64 + np.random.seed(0) + ort_outputs_1 = pipeline(**inputs, height=height, width=width) + np.random.seed(0) + ort_outputs_2 = pipeline(**inputs, height=height, width=width) + ort_outputs_3 = pipeline(**inputs, height=height, width=width) + # Compare model outputs + self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) + self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) + + +class ORTStableDiffusionXLPipelineTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = [ + "stable-diffusion-xl", + ] + ORTMODEL_CLASS = ORTStableDiffusionXLPipeline + TASK = "stable-diffusion-xl" + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_compare_to_diffusers(self, model_arch: str): + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) + self.assertIsInstance(ort_pipeline.text_encoder, ORTModelTextEncoder) + self.assertIsInstance(ort_pipeline.text_encoder_2, ORTModelTextEncoder) + self.assertIsInstance(ort_pipeline.vae_decoder, ORTModelVaeDecoder) + self.assertIsInstance(ort_pipeline.vae_encoder, ORTModelVaeEncoder) + self.assertIsInstance(ort_pipeline.unet, ORTModelUnet) + self.assertIsInstance(ort_pipeline.config, Dict) + + pipeline = StableDiffusionXLPipeline.from_pretrained(MODEL_NAMES[model_arch]) + batch_size, num_images_per_prompt, height, width = 2, 2, 64, 64 + latents = ort_pipeline.prepare_latents( + batch_size * num_images_per_prompt, + ort_pipeline.unet.config["in_channels"], + height, + width, + dtype=np.float32, + generator=np.random.RandomState(0), + ) + + kwargs = { + "prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size, + "num_inference_steps": 1, + "num_images_per_prompt": num_images_per_prompt, + "height": height, + "width": width, + "guidance_rescale": 0.1, + } + + for output_type in ["latent", "np"]: + ort_outputs = ort_pipeline(latents=latents, output_type=output_type, **kwargs).images + self.assertIsInstance(ort_outputs, np.ndarray) + with torch.no_grad(): + outputs = pipeline(latents=torch.from_numpy(latents), output_type=output_type, **kwargs).images + + # Compare model outputs + self.assertTrue(np.allclose(ort_outputs, outputs, atol=1e-4)) + # Compare model devices + self.assertEqual(pipeline.device, ort_pipeline.device) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_image_reproducibility(self, model_arch: str): + pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) + inputs = _generate_inputs() + height = 64 + width = 64 + np.random.seed(0) + ort_outputs_1 = pipeline(**inputs, height=height, width=width) + np.random.seed(0) + ort_outputs_2 = pipeline(**inputs, height=height, width=width) + ort_outputs_3 = pipeline(**inputs, height=height, width=width) + self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) + self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) + + +class ORTStableDiffusionInpaintPipelineTest(ORTStableDiffusionPipelineBase): + SUPPORTED_ARCHITECTURES = [ + "stable-diffusion", + ] + ORTMODEL_CLASS = ORTStableDiffusionInpaintPipeline + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_compare_diffusers_pipeline(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + height = 64 + width = 64 + latents_shape = ( + 1, + ort_pipeline.vae_decoder.config["latent_channels"], + height // ort_pipeline.vae_scale_factor, + width // ort_pipeline.vae_scale_factor, + ) + latents = np.random.randn(*latents_shape).astype(np.float32) + inputs = self.generate_inputs(height=height, width=width) + outputs = ort_pipeline(**inputs, latents=latents).images + self.assertEqual(outputs.shape, (1, height, width, 3)) + expected_slice = np.array([0.5442, 0.3002, 0.5665, 0.6485, 0.4421, 0.6441, 0.5778, 0.5076, 0.5612]) + self.assertTrue(np.allclose(outputs[0, -3:, -3:, -1].flatten(), expected_slice, atol=1e-4)) + + def generate_inputs(self, height=128, width=128): + inputs = super(ORTStableDiffusionInpaintPipelineTest, self).generate_inputs(height, width) + inputs["image"] = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ).resize((64, 64)) + + inputs["mask_image"] = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ).resize((64, 64)) + + return inputs + + +class ORTStableDiffusionXLImg2ImgPipelineTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = [ + "stable-diffusion-xl", + ] + ORTMODEL_CLASS = ORTStableDiffusionXLImg2ImgPipeline + TASK = "stable-diffusion-xl" + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_inference(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + inputs = self.generate_inputs() + output = pipeline(**inputs, generator=np.random.RandomState(0)).images[0, -3:, -3:, -1] + expected_slice = np.array([0.6515, 0.5405, 0.4858, 0.5632, 0.5174, 0.5681, 0.4948, 0.4253, 0.5080]) + + self.assertTrue(np.allclose(output.flatten(), expected_slice, atol=1e-1)) + + def generate_inputs(self, height=128, width=128): + inputs = _generate_inputs() + inputs["image"] = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ).resize((height, width)) + + inputs["strength"] = 0.75 + return inputs diff --git a/tests/onnxruntime/test_utils.py b/tests/onnxruntime/test_utils.py index dedb661124..c3dd4c0561 100644 --- a/tests/onnxruntime/test_utils.py +++ b/tests/onnxruntime/test_utils.py @@ -9,8 +9,10 @@ class ProviderAndDeviceGettersTest(unittest.TestCase): def test_get_device_for_provider(self): - self.assertEqual(get_device_for_provider("CPUExecutionProvider"), torch.device("cpu")) - self.assertEqual(get_device_for_provider("CUDAExecutionProvider"), torch.device("cuda:0")) + self.assertEqual(get_device_for_provider("CPUExecutionProvider", provider_options={}), torch.device("cpu")) + self.assertEqual( + get_device_for_provider("CUDAExecutionProvider", provider_options={"device_id": 1}), torch.device("cuda:1") + ) def test_get_provider_for_device(self): self.assertEqual(get_provider_for_device(torch.device("cpu")), "CPUExecutionProvider") diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 2d9eb71655..f83acd91e6 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -13,6 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import shutil +import tempfile +import unittest +from typing import Dict + +import numpy as np +import torch +from transformers import set_seed + +from optimum.exporters import TasksManager + + MODEL_NAMES = { "albert": "hf-internal-testing/tiny-random-AlbertModel", "audio_spectrogram_transformer": "Ericwang/tiny-random-ast", @@ -67,6 +79,7 @@ "segformer": "hf-internal-testing/tiny-random-SegformerModel", "squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel", "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", + "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", "swin": "hf-internal-testing/tiny-random-SwinModel", "t5": "hf-internal-testing/tiny-random-t5", "vit": "hf-internal-testing/tiny-random-vit", @@ -88,3 +101,54 @@ } SEED = 42 + + +class ORTModelTestMixin(unittest.TestCase): + ARCH_MODEL_MAP = {} + + TENSOR_ALIAS_TO_TYPE = { + "pt": torch.Tensor, + "np": np.ndarray, + } + + @classmethod + def setUpClass(cls): + cls.onnx_model_dirs = {} + + def _setup(self, model_args: Dict): + """ + Exports the PyTorch models to ONNX ahead of time to avoid multiple exports during the tests. + We don't use unittest setUpClass, in order to still be able to run individual tests. + """ + model_arch = model_args["model_arch"] + model_arch_and_params = model_args["test_name"] + + # TODO: this should actually be checked in ORTModel! + task = self.TASK + if "use_cache" in model_args and model_args["use_cache"] is True: + task = task + "-with-past" + + if "use_cache" in model_args and task not in TasksManager.get_supported_tasks_for_model_type( + model_arch.replace("_", "-"), exporter="onnx" + ): + self.skipTest("Unsupported export case") + + if model_arch_and_params not in self.onnx_model_dirs: + # model_args will contain kwargs to pass to ORTModel.from_pretrained() + model_args.pop("test_name") + model_args.pop("model_arch") + + model_id = ( + self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + ) + set_seed(SEED) + onnx_model = self.ORTMODEL_CLASS.from_pretrained(model_id, **model_args, use_io_binding=False, export=True) + + model_dir = tempfile.mkdtemp(prefix=f"{model_arch_and_params}_{self.TASK}_") + onnx_model.save_pretrained(model_dir) + self.onnx_model_dirs[model_arch_and_params] = model_dir + + @classmethod + def tearDownClass(cls): + for _, dir_path in cls.onnx_model_dirs.items(): + shutil.rmtree(dir_path)