From f1b708c4e29a392d84d69f820bcc45bfd89cc221 Mon Sep 17 00:00:00 2001 From: Tom Savage Date: Mon, 16 Sep 2024 09:04:31 +0100 Subject: [PATCH 1/2] Fixes detection of CuPy installed with pre-built wheels (#1965) The CuPy library ships both a source distribution (`cupy`) as well as versions containing pre-built wheels (`cupy-cuda11x`, `cupy-cuda12x`, `cupy-rocm-5-0`, `cupy-rocm-4-3`). Use of `_is_package_available` to detect CuPy only works for the source distribution of CuPy and fails when using the pre-built wheels versions. This is because the `_is_package_available` will always attempt to resolve version information (even if it's not required) and in doing so assumes that the _importable_ package name matches the _installed_ distribution name. While this is usually the case, it doesn't work for CuPy and several other libraries. ONNX Runtime for example might be installed as `onnxruntime` or `onnxruntime-gpu` and thus Optimum just uses `importlib.util.find_spec` to work around the same problem. This commit replicates the same solution for CuPy. --- optimum/onnxruntime/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index ad40af92b9..985980e31b 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -13,6 +13,7 @@ # limitations under the License. """Utility functions, classes and constants for ONNX Runtime.""" +import importlib import os import re from enum import Enum @@ -31,7 +32,6 @@ import onnxruntime as ort from ..exporters.onnx import OnnxConfig, OnnxConfigWithLoss -from ..utils.import_utils import _is_package_available if TYPE_CHECKING: @@ -91,9 +91,11 @@ def is_onnxruntime_training_available(): def is_cupy_available(): """ - Checks if onnxruntime-training is available. + Checks if CuPy is available. """ - return _is_package_available("cupy") + # Don't use _is_package_available as it doesn't work with CuPy installed + # with `cupy-cuda*` and `cupy-rocm-*` package name (prebuilt wheels). + return importlib.util.find_spec("cupy") is not None class ORTConfigManager: From ca36fc4f66577cd4ac2e6cedcc204d830a1f4985 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Mon, 16 Sep 2024 11:08:34 +0200 Subject: [PATCH 2/2] Adding `ORTPipelineForxxx` entrypoints (#1960) * created auto task mappings * added correct auto classes * created auto task mappings * added correct auto classes * added ort/auto diffusion classes * fix ORTPipeline detection * start test refactoring * dynamic dtype * support torch random numbers generator * compact diffusion testing suite * fix * test * test * test * use latent-consistency architecture name instead of lcm * fix * add ort diffusion pipeline tests * added dummy objects * remove duplicate code * support testing without diffusers * remove unnecessary * revert * style * remove model parts from optimum.onnxruntime --- optimum/exporters/tasks.py | 2 +- optimum/modeling_base.py | 9 +- optimum/onnxruntime/__init__.py | 16 + optimum/onnxruntime/base.py | 50 +- optimum/onnxruntime/modeling_diffusion.py | 338 ++++++-- optimum/onnxruntime/modeling_seq2seq.py | 68 -- .../diffusers/pipeline_latent_consistency.py | 6 +- .../diffusers/pipeline_stable_diffusion.py | 16 +- .../pipeline_stable_diffusion_img2img.py | 83 +- .../pipeline_stable_diffusion_inpaint.py | 22 +- .../diffusers/pipeline_stable_diffusion_xl.py | 20 +- .../pipeline_stable_diffusion_xl_img2img.py | 28 +- optimum/pipelines/diffusers/pipeline_utils.py | 8 +- optimum/utils/dummy_diffusers_objects.py | 44 + tests/exporters/exporters_utils.py | 2 +- tests/onnxruntime/test_diffusion.py | 793 ++++++++++++++++++ tests/onnxruntime/test_modeling.py | 47 +- .../test_stable_diffusion_pipeline.py | 562 ------------- 18 files changed, 1287 insertions(+), 827 deletions(-) create mode 100644 tests/onnxruntime/test_diffusion.py delete mode 100644 tests/onnxruntime/test_stable_diffusion_pipeline.py diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 9705304087..a489f34fb0 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -308,9 +308,9 @@ class TasksManager: "image-feature-extraction": "feature-extraction", # for backward compatibility and testing (where # model task and model type are still the same) - "lcm": "text-to-image", "stable-diffusion": "text-to-image", "stable-diffusion-xl": "text-to-image", + "latent-consistency": "text-to-image", } _CUSTOM_CLASSES = { diff --git a/optimum/modeling_base.py b/optimum/modeling_base.py index 5bab0622de..3da2d9d0d2 100644 --- a/optimum/modeling_base.py +++ b/optimum/modeling_base.py @@ -85,7 +85,6 @@ class PreTrainedModel(ABC): # noqa: F811 class OptimizedModel(PreTrainedModel): config_class = AutoConfig - load_tf_weights = None base_model_prefix = "optimized_model" config_name = CONFIG_NAME @@ -378,10 +377,14 @@ def from_pretrained( ) model_id, revision = model_id.split("@") - library_name = TasksManager.infer_library_from_model(model_id, subfolder, revision, cache_dir, token=token) + library_name = TasksManager.infer_library_from_model( + model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token + ) if library_name == "timm": - config = PretrainedConfig.from_pretrained(model_id, subfolder, revision) + config = PretrainedConfig.from_pretrained( + model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token + ) if config is None: if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME: diff --git a/optimum/onnxruntime/__init__.py b/optimum/onnxruntime/__init__.py index f1d4f63a9f..09a48ec955 100644 --- a/optimum/onnxruntime/__init__.py +++ b/optimum/onnxruntime/__init__.py @@ -79,6 +79,10 @@ "ORTStableDiffusionXLPipeline", "ORTStableDiffusionXLImg2ImgPipeline", "ORTLatentConsistencyModelPipeline", + "ORTPipelineForImage2Image", + "ORTPipelineForInpainting", + "ORTPipelineForText2Image", + "ORTDiffusionPipeline", ] else: _import_structure["modeling_diffusion"] = [ @@ -88,6 +92,10 @@ "ORTStableDiffusionXLPipeline", "ORTStableDiffusionXLImg2ImgPipeline", "ORTLatentConsistencyModelPipeline", + "ORTPipelineForImage2Image", + "ORTPipelineForInpainting", + "ORTPipelineForText2Image", + "ORTDiffusionPipeline", ] @@ -137,7 +145,11 @@ raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_diffusers_objects import ( + ORTDiffusionPipeline, ORTLatentConsistencyModelPipeline, + ORTPipelineForImage2Image, + ORTPipelineForInpainting, + ORTPipelineForText2Image, ORTStableDiffusionImg2ImgPipeline, ORTStableDiffusionInpaintPipeline, ORTStableDiffusionPipeline, @@ -146,7 +158,11 @@ ) else: from .modeling_diffusion import ( + ORTDiffusionPipeline, ORTLatentConsistencyModelPipeline, + ORTPipelineForImage2Image, + ORTPipelineForInpainting, + ORTPipelineForText2Image, ORTStableDiffusionImg2ImgPipeline, ORTStableDiffusionInpaintPipeline, ORTStableDiffusionPipeline, diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index d9877670ba..0e54bafed7 100644 --- a/optimum/onnxruntime/base.py +++ b/optimum/onnxruntime/base.py @@ -41,17 +41,11 @@ class ORTModelPart: _prepare_onnx_inputs = ORTModel._prepare_onnx_inputs _prepare_onnx_outputs = ORTModel._prepare_onnx_outputs - def __init__( - self, - session: InferenceSession, - parent_model: "ORTModel", - ): + def __init__(self, session: InferenceSession, parent_model: "ORTModel"): self.session = session self.parent_model = parent_model - self.normalized_config = NormalizedConfigManager.get_normalized_config_class( - self.parent_model.config.model_type - )(self.parent_model.config) self.main_input_name = self.parent_model.main_input_name + self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()} @@ -90,12 +84,18 @@ class ORTEncoder(ORTModelPart): Encoder part of the encoder-decoder model for ONNX Runtime inference. """ - def forward( - self, - input_ids: torch.LongTensor, - attention_mask: torch.LongTensor, - **kwargs, - ) -> BaseModelOutput: + def __init__(self, session: InferenceSession, parent_model: "ORTModel"): + super().__init__(session, parent_model) + + config = ( + self.parent_model.config.encoder + if hasattr(self.parent_model.config, "encoder") + else self.parent_model.config + ) + + self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) + + def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, **kwargs) -> BaseModelOutput: use_torch = isinstance(input_ids, torch.Tensor) self.parent_model.raise_on_numpy_input_io_binding(use_torch) @@ -138,6 +138,14 @@ def __init__( ): super().__init__(session, parent_model) + config = ( + self.parent_model.config.decoder + if hasattr(self.parent_model.config, "decoder") + else self.parent_model.config + ) + + self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) + # TODO: make this less hacky. self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)] self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)] @@ -153,11 +161,7 @@ def __init__( self.use_past_in_outputs = len(self.key_value_output_names) > 0 self.use_past_in_inputs = len(self.key_value_input_names) > 0 - self.use_fp16 = False - for inp in session.get_inputs(): - if "past_key_values" in inp.name and inp.type == "tensor(float16)": - self.use_fp16 = True - break + self.use_fp16 = self.dtype == torch.float16 # We may use ORTDecoderForSeq2Seq for vision-encoder-decoder models, where models as gpt2 # can be used but do not support KV caching for the cross-attention key/values, see: @@ -461,11 +465,3 @@ def prepare_inputs_for_merged( cache_position = cache_position.to(self.device) return use_cache_branch_tensor, past_key_values, cache_position - - -class ORTDecoder(ORTDecoderForSeq2Seq): - def __init__(self, *args, **kwargs): - logger.warning( - "The class `ORTDecoder` is deprecated and will be removed in optimum v1.15.0, please use `ORTDecoderForSeq2Seq` instead." - ) - super().__init__(*args, **kwargs) diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 4bbfb2eda2..18cd38c5f2 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -17,7 +17,7 @@ import os import shutil import warnings -from abc import abstractmethod +from collections import OrderedDict from pathlib import Path from tempfile import TemporaryDirectory from typing import Any, Dict, Optional, Union @@ -25,18 +25,28 @@ import numpy as np import torch from diffusers import ( + AutoPipelineForImage2Image, + AutoPipelineForInpainting, + AutoPipelineForText2Image, + ConfigMixin, DDIMScheduler, + LatentConsistencyModelPipeline, LMSDiscreteScheduler, PNDMScheduler, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, StableDiffusionPipeline, StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLPipeline, ) from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available from huggingface_hub import snapshot_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE +from huggingface_hub.utils import validate_hf_hub_args from transformers import CLIPFeatureExtractor, CLIPTokenizer from transformers.file_utils import add_end_docstrings +from transformers.modeling_outputs import ModelOutput import onnxruntime as ort @@ -56,9 +66,10 @@ DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, ) +from .base import ORTModelPart +from .io_binding import TypeHelper from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel from .utils import ( - _ORT_TO_NP_TYPE, ONNX_WEIGHTS_NAME, get_provider_for_device, parse_device, @@ -69,23 +80,23 @@ logger = logging.getLogger(__name__) -class ORTStableDiffusionPipelineBase(ORTModel): - auto_model_class = StableDiffusionPipeline - main_input_name = "input_ids" - base_model_prefix = "onnx_model" +class ORTPipeline(ORTModel): + auto_model_class = None + model_type = "onnx_pipeline" + config_name = "model_index.json" sub_component_config_name = "config.json" def __init__( self, vae_decoder_session: ort.InferenceSession, - text_encoder_session: ort.InferenceSession, unet_session: ort.InferenceSession, - config: Dict[str, Any], tokenizer: CLIPTokenizer, + config: Dict[str, Any], scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], feature_extractor: Optional[CLIPFeatureExtractor] = None, vae_encoder_session: Optional[ort.InferenceSession] = None, + text_encoder_session: Optional[ort.InferenceSession] = None, text_encoder_2_session: Optional[ort.InferenceSession] = None, tokenizer_2: Optional[CLIPTokenizer] = None, use_io_binding: Optional[bool] = None, @@ -94,23 +105,28 @@ def __init__( """ Args: vae_decoder_session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the VAE decoder. - text_encoder_session (`ort.InferenceSession`): - The ONNX Runtime inference session associated to the text encoder. + The ONNX Runtime inference session associated to the VAE decoder unet_session (`ort.InferenceSession`): The ONNX Runtime inference session associated to the U-NET. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) + for the text encoder. config (`Dict[str, Any]`): A config dictionary from which the model components will be instantiated. Make sure to only load configuration files of compatible classes. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). scheduler (`Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]`): A scheduler to be used in combination with the U-NET component to denoise the encoded image latents. feature_extractor (`Optional[CLIPFeatureExtractor]`, defaults to `None`): A model extracting features from generated images to be used as inputs for the `safety_checker` vae_encoder_session (`Optional[ort.InferenceSession]`, defaults to `None`): The ONNX Runtime inference session associated to the VAE encoder. + text_encoder_session (`Optional[ort.InferenceSession]`, defaults to `None`): + The ONNX Runtime inference session associated to the text encoder. + tokenizer_2 (`Optional[CLIPTokenizer]`, defaults to `None`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer) + for the second text encoder. use_io_binding (`Optional[bool]`, defaults to `None`): Whether to use IOBinding during inference to avoid memory copy between the host and devices. Defaults to `True` if the device is CUDA, otherwise defaults to `False`. @@ -118,7 +134,7 @@ def __init__( The directory under which the model exported to ONNX was saved. """ self.shared_attributes_init( - vae_decoder_session, + model=vae_decoder_session, use_io_binding=use_io_binding, model_save_dir=model_save_dir, ) @@ -350,9 +366,9 @@ def _from_pretrained( text_encoder_path=new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name, unet_path=new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name, vae_encoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name, - text_encoder_2_path=new_model_save_dir - / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER - / text_encoder_2_file_name, + text_encoder_2_path=( + new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name + ), provider=provider, session_options=session_options, provider_options=provider_options, @@ -399,7 +415,7 @@ def _from_transformers( provider_options: Optional[Dict[str, Any]] = None, use_io_binding: Optional[bool] = None, task: Optional[str] = None, - ) -> "ORTStableDiffusionPipeline": + ) -> "ORTPipeline": if use_auth_token is not None: warnings.warn( "The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.", @@ -480,131 +496,142 @@ def _save_config(self, save_directory): self.save_config(save_directory) -# TODO : Use ORTModelPart once IOBinding support is added -class _ORTDiffusionModelPart: - """ - For multi-file ONNX models, represents a part of the model. - It has its own `onnxruntime.InferenceSession`, and can perform a forward pass. - """ - +class ORTPipelinePart(ORTModelPart): CONFIG_NAME = "config.json" - def __init__(self, session: ort.InferenceSession, parent_model: ORTModel): - self.session = session - self.parent_model = parent_model - self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())} - self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} + def __init__(self, session: ort.InferenceSession, parent_model: ORTPipeline): config_path = Path(session._model_path).parent / self.CONFIG_NAME - self.config = self.parent_model._dict_from_json_file(config_path) if config_path.is_file() else {} - self.input_dtype = {inputs.name: _ORT_TO_NP_TYPE[inputs.type] for inputs in self.session.get_inputs()} + + if config_path.is_file(): + # TODO: use FrozenDict + self.config = parent_model._dict_from_json_file(config_path) + else: + self.config = {} + + super().__init__(session, parent_model) @property - def device(self): - return self.parent_model.device + def input_dtype(self): + # for backward compatibility and diffusion mixins (will be standardized in the future) + return {name: TypeHelper.ort_type_to_numpy_type(ort_type) for name, ort_type in self.input_dtypes.items()} - @abstractmethod - def forward(self, *args, **kwargs): - pass - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) +class ORTModelTextEncoder(ORTPipelinePart): + def forward(self, input_ids: Union[np.ndarray, torch.Tensor]): + use_torch = isinstance(input_ids, torch.Tensor) + model_inputs = {"input_ids": input_ids} -class ORTModelTextEncoder(_ORTDiffusionModelPart): - def forward(self, input_ids: np.ndarray): - onnx_inputs = { - "input_ids": input_ids, - } - outputs = self.session.run(None, onnx_inputs) - return outputs + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + return ModelOutput(**model_outputs) -class ORTModelUnet(_ORTDiffusionModelPart): - def __init__(self, session: ort.InferenceSession, parent_model: ORTModel): - super().__init__(session, parent_model) +class ORTModelUnet(ORTPipelinePart): def forward( self, - sample: np.ndarray, - timestep: np.ndarray, - encoder_hidden_states: np.ndarray, - text_embeds: Optional[np.ndarray] = None, - time_ids: Optional[np.ndarray] = None, - timestep_cond: Optional[np.ndarray] = None, + sample: Union[np.ndarray, torch.Tensor], + timestep: Union[np.ndarray, torch.Tensor], + encoder_hidden_states: Union[np.ndarray, torch.Tensor], + text_embeds: Optional[Union[np.ndarray, torch.Tensor]] = None, + time_ids: Optional[Union[np.ndarray, torch.Tensor]] = None, + timestep_cond: Optional[Union[np.ndarray, torch.Tensor]] = None, ): - onnx_inputs = { + use_torch = isinstance(sample, torch.Tensor) + + model_inputs = { "sample": sample, "timestep": timestep, "encoder_hidden_states": encoder_hidden_states, + "text_embeds": text_embeds, + "time_ids": time_ids, + "timestep_cond": timestep_cond, } - if text_embeds is not None: - onnx_inputs["text_embeds"] = text_embeds - if time_ids is not None: - onnx_inputs["time_ids"] = time_ids - if timestep_cond is not None: - onnx_inputs["timestep_cond"] = timestep_cond - outputs = self.session.run(None, onnx_inputs) - return outputs + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + return ModelOutput(**model_outputs) -class ORTModelVaeDecoder(_ORTDiffusionModelPart): - def forward(self, latent_sample: np.ndarray): - onnx_inputs = { - "latent_sample": latent_sample, - } - outputs = self.session.run(None, onnx_inputs) - return outputs +class ORTModelVaeDecoder(ORTPipelinePart): + def forward(self, latent_sample: Union[np.ndarray, torch.Tensor]): + use_torch = isinstance(latent_sample, torch.Tensor) -class ORTModelVaeEncoder(_ORTDiffusionModelPart): - def forward(self, sample: np.ndarray): - onnx_inputs = { - "sample": sample, - } - outputs = self.session.run(None, onnx_inputs) - return outputs + model_inputs = {"latent_sample": latent_sample} + + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + + return ModelOutput(**model_outputs) + + +class ORTModelVaeEncoder(ORTPipelinePart): + def forward(self, sample: Union[np.ndarray, torch.Tensor]): + use_torch = isinstance(sample, torch.Tensor) + + model_inputs = {"sample": sample} + + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.session.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + + return ModelOutput(**model_outputs) @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTStableDiffusionPipeline(ORTStableDiffusionPipelineBase, StableDiffusionPipelineMixin): +class ORTStableDiffusionPipeline(ORTPipeline, StableDiffusionPipelineMixin): """ ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline). """ + main_input_name = "prompt" + auto_model_class = StableDiffusionPipeline + __call__ = StableDiffusionPipelineMixin.__call__ @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTStableDiffusionImg2ImgPipeline(ORTStableDiffusionPipelineBase, StableDiffusionImg2ImgPipelineMixin): +class ORTStableDiffusionImg2ImgPipeline(ORTPipeline, StableDiffusionImg2ImgPipelineMixin): """ ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img#diffusers.StableDiffusionImg2ImgPipeline). """ + main_input_name = "prompt" + auto_model_class = StableDiffusionImg2ImgPipeline + __call__ = StableDiffusionImg2ImgPipelineMixin.__call__ @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTStableDiffusionInpaintPipeline(ORTStableDiffusionPipelineBase, StableDiffusionInpaintPipelineMixin): +class ORTStableDiffusionInpaintPipeline(ORTPipeline, StableDiffusionInpaintPipelineMixin): """ ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionInpaintPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint#diffusers.StableDiffusionInpaintPipeline). """ + main_input_name = "prompt" + auto_model_class = StableDiffusionInpaintPipeline + __call__ = StableDiffusionInpaintPipelineMixin.__call__ @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTLatentConsistencyModelPipeline(ORTStableDiffusionPipelineBase, LatentConsistencyPipelineMixin): +class ORTLatentConsistencyModelPipeline(ORTPipeline, LatentConsistencyPipelineMixin): """ ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.LatentConsistencyModelPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/latent_consistency#diffusers.LatentConsistencyModelPipeline). """ - __call__ = LatentConsistencyPipelineMixin.__call__ + main_input_name = "prompt" + auto_model_class = LatentConsistencyModelPipeline + __call__ = LatentConsistencyPipelineMixin.__call__ -class ORTStableDiffusionXLPipelineBase(ORTStableDiffusionPipelineBase): - auto_model_class = StableDiffusionXLImg2ImgPipeline +class ORTStableDiffusionXLPipelineBase(ORTPipeline): def __init__( self, vae_decoder_session: ort.InferenceSession, @@ -657,6 +684,9 @@ class ORTStableDiffusionXLPipeline(ORTStableDiffusionXLPipelineBase, StableDiffu ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline). """ + main_input_name = "prompt" + auto_model_class = StableDiffusionXLPipeline + __call__ = StableDiffusionXLPipelineMixin.__call__ @@ -666,4 +696,140 @@ class ORTStableDiffusionXLImg2ImgPipeline(ORTStableDiffusionXLPipelineBase, Stab ONNX Runtime-powered stable diffusion pipeline corresponding to [diffusers.StableDiffusionXLImg2ImgPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline). """ + main_input_name = "prompt" + auto_model_class = StableDiffusionXLImg2ImgPipeline + __call__ = StableDiffusionXLImg2ImgPipelineMixin.__call__ + + +SUPPORTED_ORT_PIPELINES = [ + ORTStableDiffusionPipeline, + ORTStableDiffusionImg2ImgPipeline, + ORTStableDiffusionInpaintPipeline, + ORTLatentConsistencyModelPipeline, + ORTStableDiffusionXLPipeline, + ORTStableDiffusionXLImg2ImgPipeline, +] + + +def _get_pipeline_class(pipeline_class_name: str, throw_error_if_not_exist: bool = True): + for ort_pipeline_class in SUPPORTED_ORT_PIPELINES: + if ( + ort_pipeline_class.__name__ == pipeline_class_name + or ort_pipeline_class.auto_model_class.__name__ == pipeline_class_name + ): + return ort_pipeline_class + + if throw_error_if_not_exist: + raise ValueError(f"ORTDiffusionPipeline can't find a pipeline linked to {pipeline_class_name}") + + +class ORTDiffusionPipeline(ConfigMixin): + config_name = "model_index.json" + + @classmethod + @validate_hf_hub_args + def from_pretrained(cls, pretrained_model_or_path, **kwargs): + load_config_kwargs = { + "force_download": kwargs.get("force_download", False), + "resume_download": kwargs.get("resume_download", None), + "local_files_only": kwargs.get("local_files_only", False), + "cache_dir": kwargs.get("cache_dir", None), + "revision": kwargs.get("revision", None), + "proxies": kwargs.get("proxies", None), + "token": kwargs.get("token", None), + } + + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) + config = config[0] if isinstance(config, tuple) else config + class_name = config["_class_name"] + + ort_pipeline_class = _get_pipeline_class(class_name) + + return ort_pipeline_class.from_pretrained(pretrained_model_or_path, **kwargs) + + +ORT_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( + [ + ("stable-diffusion", ORTStableDiffusionPipeline), + ("stable-diffusion-xl", ORTStableDiffusionXLPipeline), + ("latent-consistency", ORTLatentConsistencyModelPipeline), + ] +) + +ORT_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict( + [ + ("stable-diffusion", ORTStableDiffusionImg2ImgPipeline), + ("stable-diffusion-xl", ORTStableDiffusionXLImg2ImgPipeline), + ] +) + +ORT_INPAINT_PIPELINES_MAPPING = OrderedDict( + [ + ("stable-diffusion", ORTStableDiffusionInpaintPipeline), + ] +) + +SUPPORTED_ORT_PIPELINES_MAPPINGS = [ + ORT_TEXT2IMAGE_PIPELINES_MAPPING, + ORT_IMAGE2IMAGE_PIPELINES_MAPPING, + ORT_INPAINT_PIPELINES_MAPPING, +] + + +def _get_task_class(mapping, pipeline_class_name): + def _get_model_name(pipeline_class_name): + for ort_pipelines_mapping in SUPPORTED_ORT_PIPELINES_MAPPINGS: + for model_name, ort_pipeline_class in ort_pipelines_mapping.items(): + if ( + ort_pipeline_class.__name__ == pipeline_class_name + or ort_pipeline_class.auto_model_class.__name__ == pipeline_class_name + ): + return model_name + + model_name = _get_model_name(pipeline_class_name) + + if model_name is not None: + task_class = mapping.get(model_name, None) + if task_class is not None: + return task_class + + raise ValueError(f"ORTPipelineForTask can't find a pipeline linked to {pipeline_class_name} for {model_name}") + + +class ORTPipelineForTask(ConfigMixin): + config_name = "model_index.json" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, **kwargs): + load_config_kwargs = { + "force_download": kwargs.get("force_download", False), + "resume_download": kwargs.get("resume_download", None), + "local_files_only": kwargs.get("local_files_only", False), + "cache_dir": kwargs.get("cache_dir", None), + "revision": kwargs.get("revision", None), + "proxies": kwargs.get("proxies", None), + "token": kwargs.get("token", None), + } + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) + config = config[0] if isinstance(config, tuple) else config + class_name = config["_class_name"] + + ort_pipeline_class = _get_task_class(cls.ort_pipelines_mapping, class_name) + + return ort_pipeline_class.from_pretrained(pretrained_model_or_path, **kwargs) + + +class ORTPipelineForText2Image(ORTPipelineForTask): + auto_model_class = AutoPipelineForText2Image + ort_pipelines_mapping = ORT_TEXT2IMAGE_PIPELINES_MAPPING + + +class ORTPipelineForImage2Image(ORTPipelineForTask): + auto_model_class = AutoPipelineForImage2Image + ort_pipelines_mapping = ORT_IMAGE2IMAGE_PIPELINES_MAPPING + + +class ORTPipelineForInpainting(ORTPipelineForTask): + auto_model_class = AutoPipelineForInpainting + ort_pipelines_mapping = ORT_INPAINT_PIPELINES_MAPPING diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 4ce3e4707e..3cecadafe3 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -46,7 +46,6 @@ from ..onnx.utils import _get_external_data_paths from ..utils import check_if_transformers_greater from ..utils.file_utils import validate_file_exists -from ..utils.normalized_config import NormalizedConfigManager from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from .base import ORTDecoderForSeq2Seq, ORTEncoder from .constants import ( @@ -72,16 +71,6 @@ from transformers.generation_utils import GenerationMixin -# if check_if_transformers_greater("4.37.0"): -# # starting from transformers v4.37.0, the whisper generation loop is implemented in the `WhisperGenerationMixin` -# # and it implements many new features including short and long form generation, and starts with 2 init tokens -# from transformers.models.whisper.generation_whisper import WhisperGenerationMixin -# else: - -# class WhisperGenerationMixin(WhisperForConditionalGeneration, GenerationMixin): -# pass - - if check_if_transformers_greater("4.43.0"): from transformers.cache_utils import EncoderDecoderCache else: @@ -1165,49 +1154,6 @@ class ORTModelForSeq2SeqLM(ORTModelForConditionalGeneration, GenerationMixin): auto_model_class = AutoModelForSeq2SeqLM main_input_name = "input_ids" - def __init__( - self, - encoder_session: ort.InferenceSession, - decoder_session: ort.InferenceSession, - config: "PretrainedConfig", - onnx_paths: List[str], - decoder_with_past_session: Optional[ort.InferenceSession] = None, - use_cache: bool = True, - use_io_binding: Optional[bool] = None, - model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, - preprocessors: Optional[List] = None, - generation_config: Optional[GenerationConfig] = None, - **kwargs, - ): - super().__init__( - encoder_session, - decoder_session, - config, - onnx_paths, - decoder_with_past_session, - use_cache, - use_io_binding, - model_save_dir, - preprocessors, - generation_config, - **kwargs, - ) - - # The normalized_config initialization in ORTModelPart is unfortunately wrong as the top level config is initialized. - if config.model_type == "encoder-decoder": - self.encoder.normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.encoder.model_type - )(config.encoder) - - self.decoder.normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.decoder.model_type - )(config.decoder) - - if self.decoder_with_past is not None: - self.decoder_with_past.normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.decoder.model_type - )(config.decoder) - def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder: return ORTEncoder(session, self) @@ -1521,20 +1467,6 @@ def __init__( **kwargs, ) - # The normalized_config initialization in ORTModelPart is unfortunately wrong as the top level config is initialized. - self.encoder.normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.encoder.model_type - )(config.encoder) - - self.decoder.normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.decoder.model_type - )(config.decoder) - - if self.decoder_with_past is not None: - self.decoder_with_past.normalized_config = NormalizedConfigManager.get_normalized_config_class( - config.decoder.model_type - )(config.decoder) - def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder: return ORTEncoderForVisionEncoderDecoder(session, self) diff --git a/optimum/pipelines/diffusers/pipeline_latent_consistency.py b/optimum/pipelines/diffusers/pipeline_latent_consistency.py index 41c85b5b6a..630d463de7 100644 --- a/optimum/pipelines/diffusers/pipeline_latent_consistency.py +++ b/optimum/pipelines/diffusers/pipeline_latent_consistency.py @@ -36,7 +36,7 @@ def __call__( original_inference_steps: int = None, guidance_scale: float = 8.5, num_images_per_prompt: int = 1, - generator: Optional[np.random.RandomState] = None, + generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, latents: Optional[np.ndarray] = None, prompt_embeds: Optional[np.ndarray] = None, output_type: str = "pil", @@ -66,7 +66,7 @@ def __call__( usually at the expense of lower image quality. num_images_per_prompt (`int`, defaults to 1): The number of images to generate per prompt. - generator (`Optional[np.random.RandomState]`, defaults to `None`):: + generator (`Optional[Union[np.random.RandomState, torch.Generator]]`, defaults to `None`): A np.random.RandomState to make generation deterministic. latents (`Optional[np.ndarray]`, defaults to `None`): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image @@ -121,7 +121,7 @@ def __call__( batch_size = prompt_embeds.shape[0] if generator is None: - generator = np.random + generator = np.random.RandomState() prompt_embeds = self._encode_prompt( prompt, diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion.py index 98bff0de44..6cc47fab1b 100644 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion.py +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion.py @@ -189,7 +189,15 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype ) if latents is None: - latents = generator.randn(*shape).astype(dtype) + if isinstance(generator, np.random.RandomState): + latents = generator.randn(*shape).astype(dtype) + elif isinstance(generator, torch.Generator): + latents = torch.randn(*shape, generator=generator).numpy().astype(dtype) + else: + raise ValueError( + f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" + f" {type(generator)}." + ) elif latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") @@ -209,7 +217,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, eta: float = 0.0, - generator: Optional[np.random.RandomState] = None, + generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, latents: Optional[np.ndarray] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, @@ -248,7 +256,7 @@ def __call__( eta (`float`, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`Optional[np.random.RandomState]`, defaults to `None`):: + generator (`Optional[Union[np.random.RandomState, torch.Generator]]`, defaults to `None`):: A np.random.RandomState to make generation deterministic. latents (`Optional[np.ndarray]`, defaults to `None`): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image @@ -303,7 +311,7 @@ def __call__( batch_size = prompt_embeds.shape[0] if generator is None: - generator = np.random + generator = np.random.RandomState() # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py index 81a6ffa1e0..a66035a789 100644 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py @@ -16,10 +16,9 @@ from typing import Callable, List, Optional, Union import numpy as np -import PIL +import PIL.Image import torch from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.utils import deprecate from .pipeline_stable_diffusion import StableDiffusionPipelineMixin @@ -72,6 +71,43 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, image, timesteps, batch_size, num_images_per_prompt, dtype, generator=None): + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + else: + init_latents = self.vae_encoder(sample=image)[0] * self.vae_decoder.config.get("scaling_factor", 0.18215) + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = np.concatenate([init_latents] * additional_image_per_prompt, axis=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = np.concatenate([init_latents], axis=0) + + # add noise to latents using the timesteps + if isinstance(generator, np.random.RandomState): + noise = generator.randn(*init_latents.shape).astype(dtype) + elif isinstance(generator, torch.Generator): + noise = torch.randn(*init_latents.shape, generator=generator).numpy().astype(dtype) + else: + raise ValueError( + f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" + f" {type(generator)}." + ) + + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) + ).numpy() + + return init_latents + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionImg2ImgPipeline.__call__ def __call__( self, @@ -83,7 +119,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, eta: float = 0.0, - generator: Optional[np.random.RandomState] = None, + generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, output_type: str = "pil", @@ -125,7 +161,7 @@ def __call__( eta (`float`, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`Optional[np.random.RandomState]`, defaults to `None`):: + generator (`Optional[Union[np.random.RandomState, torch.Generator]]`, defaults to `None`): A np.random.RandomState to make generation deterministic. prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not @@ -168,7 +204,7 @@ def __call__( batch_size = prompt_embeds.shape[0] if generator is None: - generator = np.random + generator = np.random.RandomState() # set timesteps self.scheduler.set_timesteps(num_inference_steps) @@ -191,31 +227,7 @@ def __call__( latents_dtype = prompt_embeds.dtype image = image.astype(latents_dtype) - # encode the init image into latents and scale the latents - init_latents = self.vae_encoder(sample=image)[0] - scaling_factor = self.vae_decoder.config.get("scaling_factor", 0.18215) - init_latents = scaling_factor * init_latents - - if isinstance(prompt, str): - prompt = [prompt] - if len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] == 0: - # expand init_latents for batch_size - deprecation_message = ( - f"You have passed {len(prompt)} text prompts (`prompt`), but only {init_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." - ) - deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = len(prompt) // init_latents.shape[0] - init_latents = np.concatenate([init_latents] * additional_image_per_prompt * num_images_per_prompt, axis=0) - elif len(prompt) > init_latents.shape[0] and len(prompt) % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {len(prompt)} text prompts." - ) - else: - init_latents = np.concatenate([init_latents] * num_images_per_prompt, axis=0) # get the original timestep using init_timestep offset = self.scheduler.config.get("steps_offset", 0) @@ -225,12 +237,8 @@ def __call__( timesteps = self.scheduler.timesteps.numpy()[-init_timestep] timesteps = np.array([timesteps] * batch_size * num_images_per_prompt) - # add noise to latents using the timesteps - noise = generator.randn(*init_latents.shape).astype(latents_dtype) - init_latents = self.scheduler.add_noise( - torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) - ) - init_latents = init_latents.numpy() + # 5. Prepare latent variables + latents = self.prepare_latents(image, timesteps, batch_size, num_images_per_prompt, latents_dtype, generator) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -241,8 +249,6 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta - latents = init_latents - t_start = max(num_inference_steps - init_timestep + offset, 0) timesteps = self.scheduler.timesteps[t_start:].numpy() @@ -276,7 +282,8 @@ def __call__( # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) if output_type == "latent": image = latents diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py index 19de793ccd..cb3c7db96e 100644 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py @@ -16,7 +16,7 @@ from typing import Callable, List, Optional, Union import numpy as np -import PIL +import PIL.Image import torch from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import PIL_INTERPOLATION @@ -108,7 +108,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, eta: float = 0.0, - generator: Optional[np.random.RandomState] = None, + generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, latents: Optional[np.ndarray] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, @@ -200,7 +200,7 @@ def __call__( batch_size = prompt_embeds.shape[0] if generator is None: - generator = np.random + generator = np.random.RandomState() # set timesteps self.scheduler.set_timesteps(num_inference_steps) @@ -229,11 +229,19 @@ def __call__( width // self.vae_scale_factor, ) latents_dtype = prompt_embeds.dtype + if latents is None: - latents = generator.randn(*latents_shape).astype(latents_dtype) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + if isinstance(generator, np.random.RandomState): + latents = generator.randn(*latents_shape).astype(latents_dtype) + elif isinstance(generator, torch.Generator): + latents = torch.randn(*latents_shape, generator=generator).numpy().astype(latents_dtype) + else: + raise ValueError( + f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" + f" {type(generator)}." + ) + elif latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") # prepare mask and masked_image mask, masked_image = prepare_mask_and_masked_image( diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py index 2a5e7bf78b..0407c16a77 100644 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py @@ -235,7 +235,15 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype ) if latents is None: - latents = generator.randn(*shape).astype(dtype) + if isinstance(generator, np.random.RandomState): + latents = generator.randn(*shape).astype(dtype) + elif isinstance(generator, torch.Generator): + latents = torch.randn(*shape, generator=generator).numpy().astype(dtype) + else: + raise ValueError( + f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" + f" {type(generator)}." + ) elif latents.shape != shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") @@ -270,7 +278,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, eta: float = 0.0, - generator: Optional[np.random.RandomState] = None, + generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, latents: Optional[np.ndarray] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, @@ -315,7 +323,7 @@ def __call__( eta (`float`, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`Optional[np.random.RandomState]`, defaults to `None`):: + generator (`Optional[Union[np.random.RandomState, torch.Generator]]`, defaults to `None`):: A np.random.RandomState to make generation deterministic. latents (`Optional[np.ndarray]`, defaults to `None`): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image @@ -383,7 +391,7 @@ def __call__( batch_size = prompt_embeds.shape[0] if generator is None: - generator = np.random + generator = np.random.RandomState() # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -440,6 +448,7 @@ def __call__( timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance @@ -475,7 +484,8 @@ def __call__( # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) if output_type == "latent": image = latents diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py index a07903a735..19988599b6 100644 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py @@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -import PIL +import PIL.Image import torch from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput @@ -222,7 +222,7 @@ def get_timesteps(self, num_inference_steps, strength): return timesteps, num_inference_steps - t_start # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None): + def prepare_latents(self, image, timesteps, batch_size, num_images_per_prompt, dtype, generator=None): batch_size = batch_size * num_images_per_prompt if image.shape[1] == 4: @@ -242,11 +242,22 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt init_latents = np.concatenate([init_latents], axis=0) # add noise to latents using the timesteps - noise = generator.randn(*init_latents.shape).astype(dtype) + if isinstance(generator, np.random.RandomState): + noise = generator.randn(*init_latents.shape).astype(dtype) + elif isinstance(generator, torch.Generator): + noise = torch.randn(*init_latents.shape, generator=generator).numpy().astype(dtype) + else: + raise ValueError( + f"Expected `generator` to be of type `np.random.RandomState` or `torch.Generator`, but got" + f" {type(generator)}." + ) + init_latents = self.scheduler.add_noise( - torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timestep) + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timesteps) ) - return init_latents.numpy() + init_latents = init_latents.numpy() + + return init_latents def _get_add_time_ids( self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype @@ -274,7 +285,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: int = 1, eta: float = 0.0, - generator: Optional[np.random.RandomState] = None, + generator: Optional[Union[np.random.RandomState, torch.Generator]] = None, latents: Optional[np.ndarray] = None, prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, @@ -375,7 +386,7 @@ def __call__( batch_size = prompt_embeds.shape[0] if generator is None: - generator = np.random + generator = np.random.RandomState() # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -482,7 +493,8 @@ def __call__( # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) if output_type == "latent": image = latents diff --git a/optimum/pipelines/diffusers/pipeline_utils.py b/optimum/pipelines/diffusers/pipeline_utils.py index 869b91ffe5..e9d5986b61 100644 --- a/optimum/pipelines/diffusers/pipeline_utils.py +++ b/optimum/pipelines/diffusers/pipeline_utils.py @@ -17,7 +17,7 @@ from typing import List, Optional, Union import numpy as np -import PIL +import PIL.Image import torch from diffusers import ConfigMixin from diffusers.image_processor import VaeImageProcessor as DiffusersVaeImageProcessor @@ -206,7 +206,7 @@ def postprocess( def get_height_width( self, - image: [PIL.Image.Image, np.ndarray], + image: Union[PIL.Image.Image, np.ndarray], height: Optional[int] = None, width: Optional[int] = None, ): @@ -264,10 +264,10 @@ def reshape(images: np.ndarray) -> np.ndarray: # TODO : remove after diffusers v0.21.0 release def resize( self, - image: [PIL.Image.Image, np.ndarray, torch.Tensor], + image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], height: Optional[int] = None, width: Optional[int] = None, - ) -> [PIL.Image.Image, np.ndarray, torch.Tensor]: + ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: """ Resize image. """ diff --git a/optimum/utils/dummy_diffusers_objects.py b/optimum/utils/dummy_diffusers_objects.py index f6914bbcd3..35d1ffe9fc 100644 --- a/optimum/utils/dummy_diffusers_objects.py +++ b/optimum/utils/dummy_diffusers_objects.py @@ -79,3 +79,47 @@ def __init__(self, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["diffusers"]) + + +class ORTDiffusionPipeline(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + +class ORTPipelineForText2Image(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + +class ORTPipelineForImage2Image(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + +class ORTPipelineForInpainting(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index a55c7a124d..c8a33b0be3 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -298,7 +298,7 @@ PYTORCH_DIFFUSION_MODEL = { "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", - "lcm": "echarlaix/tiny-random-latent-consistency", + "latent-consistency": "echarlaix/tiny-random-latent-consistency", } PYTORCH_TIMM_MODEL = { diff --git a/tests/onnxruntime/test_diffusion.py b/tests/onnxruntime/test_diffusion.py new file mode 100644 index 0000000000..9f480b2d1a --- /dev/null +++ b/tests/onnxruntime/test_diffusion.py @@ -0,0 +1,793 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import PIL +import pytest +import torch +from diffusers import ( + AutoPipelineForImage2Image, + AutoPipelineForInpainting, + AutoPipelineForText2Image, + DiffusionPipeline, +) +from diffusers.utils import load_image +from parameterized import parameterized +from transformers.testing_utils import require_torch_gpu +from utils_onnxruntime_tests import MODEL_NAMES, SEED, ORTModelTestMixin + +from optimum.onnxruntime import ( + ORTDiffusionPipeline, + ORTPipelineForImage2Image, + ORTPipelineForInpainting, + ORTPipelineForText2Image, +) +from optimum.pipelines.diffusers.pipeline_utils import VaeImageProcessor +from optimum.utils.testing_utils import grid_parameters, require_diffusers, require_ort_rocm + + +def get_generator(framework, seed): + if framework == "np": + return np.random.RandomState(seed) + elif framework == "pt": + return torch.Generator().manual_seed(seed) + else: + raise ValueError(f"Unknown framework: {framework}") + + +def _generate_prompts(batch_size=1): + inputs = { + "prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size, + "num_inference_steps": 3, + "guidance_scale": 7.5, + "output_type": "np", + } + return inputs + + +def _generate_images(height=128, width=128, batch_size=1, channel=3, input_type="pil"): + if input_type == "pil": + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ).resize((width, height)) + elif input_type == "np": + image = np.random.rand(height, width, channel) + elif input_type == "pt": + image = torch.rand((channel, height, width)) + + return [image] * batch_size + + +def to_np(image): + if isinstance(image[0], PIL.Image.Image): + return np.stack([np.array(i) for i in image], axis=0) + elif isinstance(image, torch.Tensor): + return image.cpu().numpy().transpose(0, 2, 3, 1) + return image + + +class ORTPipelineForText2ImageTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = ["latent-consistency", "stable-diffusion", "stable-diffusion-xl"] + + ORTMODEL_CLASS = ORTPipelineForText2Image + AUTOMODEL_CLASS = AutoPipelineForText2Image + + TASK = "text-to-image" + + def generate_inputs(self, height=128, width=128, batch_size=1): + inputs = _generate_prompts(batch_size=batch_size) + + inputs["height"] = height + inputs["width"] = width + + return inputs + + @require_diffusers + def test_load_vanilla_model_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"], export=True) + + self.assertIn( + f"does not appear to have a file named {self.ORTMODEL_CLASS.config_name}", str(context.exception) + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_ort_pipeline_class_dispatch(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + auto_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + + self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) + + auto_pipeline = DiffusionPipeline.from_pretrained(MODEL_NAMES[model_arch]) + ort_pipeline = ORTDiffusionPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) + + self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_num_images_per_prompt(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + self.assertEqual(pipeline.vae_scale_factor, 2) + self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4) + self.assertEqual(pipeline.unet.config["in_channels"], 4) + + height, width, batch_size = 64, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + for num_images in [1, 3]: + outputs = pipeline(**inputs, num_images_per_prompt=num_images).images + self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_compare_to_diffusers_pipeline(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 128, 128, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + + if model_arch == "latent-consistency": + # Latent Consistency Model (LCM) doesn't support deterministic outputs beyond the first inference step + # TODO: Investigate why this is the case + inputs["num_inference_steps"] = 1 + + for output_type in ["latent", "np"]: + inputs["output_type"] = output_type + + ort_output = ort_pipeline(**inputs, generator=torch.Generator().manual_seed(SEED)).images + diffusers_output = diffusers_pipeline(**inputs, generator=torch.Generator().manual_seed(SEED)).images + + self.assertTrue( + np.allclose(ort_output, diffusers_output, atol=1e-4), + np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4), + ) + self.assertEqual(ort_pipeline.device, diffusers_pipeline.device) + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["CUDAExecutionProvider"]}) + ) + @require_torch_gpu + @pytest.mark.cuda_ep_test + @require_diffusers + def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + outputs = pipeline(**inputs).images + # Verify model devices + self.assertEqual(pipeline.device.type.lower(), "cuda") + # Verify model outptus + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["ROCMExecutionProvider"]}) + ) + @require_torch_gpu + @require_ort_rocm + @pytest.mark.rocm_ep_test + @require_diffusers + def test_pipeline_on_rocm_ep(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + height, width, batch_size = 64, 32, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + outputs = pipeline(**inputs).images + # Verify model devices + self.assertEqual(pipeline.device.type.lower(), "cuda") + # Verify model outptus + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_callback(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 64, 128, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + class Callback: + def __init__(self): + self.has_been_called = False + self.number_of_steps = 0 + + def __call__(self, step: int, timestep: int, latents: np.ndarray) -> None: + self.has_been_called = True + self.number_of_steps += 1 + + ort_callback = Callback() + auto_callback = Callback() + + ort_pipe = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + auto_pipe = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + + # callback_steps=1 to trigger callback every step + ort_pipe(**inputs, callback=ort_callback, callback_steps=1) + auto_pipe(**inputs, callback=auto_callback, callback_steps=1) + + self.assertTrue(ort_callback.has_been_called) + self.assertTrue(auto_callback.has_been_called) + self.assertEqual(auto_callback.number_of_steps, ort_callback.number_of_steps) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_shape(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + height, width, batch_size = 128, 64, 1 + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + for output_type in ["np", "pil", "latent"]: + inputs["output_type"] = output_type + outputs = pipeline(**inputs).images + if output_type == "pil": + self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width)) + elif output_type == "np": + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + else: + self.assertEqual( + outputs.shape, + (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor), + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_image_reproducibility(self, model_arch: str): + if model_arch in ["latent-consistency"]: + pytest.skip("Latent Consistency Model (LCM) doesn't support deterministic outputs") + + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 64, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + + for generator_framework in ["np", "pt"]: + ort_outputs_1 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) + ort_outputs_2 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) + ort_outputs_3 = pipeline(**inputs, generator=get_generator(generator_framework, SEED + 1)) + + self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) + self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_negative_prompt(self, model_arch: str): + if model_arch in ["latent-consistency"]: + pytest.skip("Latent Consistency Model (LCM) does not support negative prompts") + + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 64, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + negative_prompt = ["This is a negative prompt"] + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + image_slice_1 = pipeline( + **inputs, negative_prompt=negative_prompt, generator=np.random.RandomState(SEED) + ).images[0, -3:, -3:, -1] + prompt = inputs.pop("prompt") + + if model_arch == "stable-diffusion-xl": + ( + inputs["prompt_embeds"], + inputs["negative_prompt_embeds"], + inputs["pooled_prompt_embeds"], + inputs["negative_pooled_prompt_embeds"], + ) = pipeline._encode_prompt(prompt, 1, False, negative_prompt) + else: + text_ids = pipeline.tokenizer( + prompt, + max_length=pipeline.tokenizer.model_max_length, + padding="max_length", + return_tensors="np", + truncation=True, + ).input_ids + negative_text_ids = pipeline.tokenizer( + negative_prompt, + max_length=pipeline.tokenizer.model_max_length, + padding="max_length", + return_tensors="np", + truncation=True, + ).input_ids + inputs["prompt_embeds"] = pipeline.text_encoder(text_ids)[0] + inputs["negative_prompt_embeds"] = pipeline.text_encoder(negative_text_ids)[0] + + image_slice_2 = pipeline(**inputs, generator=np.random.RandomState(SEED)).images[0, -3:, -3:, -1] + + self.assertTrue(np.allclose(image_slice_1, image_slice_2, rtol=1e-1)) + + +class ORTPipelineForImage2ImageTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = ["stable-diffusion", "stable-diffusion-xl"] + + AUTOMODEL_CLASS = AutoPipelineForImage2Image + ORTMODEL_CLASS = ORTPipelineForImage2Image + + TASK = "image-to-image" + + def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="np"): + inputs = _generate_prompts(batch_size=batch_size) + + inputs["image"] = _generate_images( + height=height, width=width, batch_size=batch_size, channel=channel, input_type=input_type + ) + + inputs["strength"] = 0.75 + + return inputs + + @require_diffusers + def test_load_vanilla_model_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"], export=True) + + self.assertIn( + f"does not appear to have a file named {self.ORTMODEL_CLASS.config_name}", str(context.exception) + ) + + @parameterized.expand(list(SUPPORTED_ARCHITECTURES)) + @require_diffusers + def test_ort_pipeline_class_dispatch(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + auto_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + + self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) + + # auto_pipeline = DiffusionPipeline.from_pretrained(MODEL_NAMES[model_arch]) + # ort_pipeline = ORTDiffusionPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) + + # self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_num_images_per_prompt(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + self.assertEqual(pipeline.vae_scale_factor, 2) + self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4) + self.assertEqual(pipeline.unet.config["in_channels"], 4) + + batch_size, height = 1, 32 + for width in [64, 32]: + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + for num_images in [1, 3]: + outputs = pipeline(**inputs, num_images_per_prompt=num_images).images + self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["CUDAExecutionProvider"]}) + ) + @require_torch_gpu + @pytest.mark.cuda_ep_test + @require_diffusers + def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + outputs = pipeline(**inputs).images + # Verify model devices + self.assertEqual(pipeline.device.type.lower(), "cuda") + # Verify model outptus + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["ROCMExecutionProvider"]}) + ) + @require_torch_gpu + @require_ort_rocm + @pytest.mark.rocm_ep_test + @require_diffusers + def test_pipeline_on_rocm_ep(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + outputs = pipeline(**inputs).images + # Verify model devices + self.assertEqual(pipeline.device.type.lower(), "cuda") + # Verify model outptus + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_callback(self, model_arch: str): + if model_arch in ["stable-diffusion"]: + pytest.skip( + "Stable Diffusion For Img2Img doesn't behave as expected with callbacks (doesn't call it every step with callback_steps=1)" + ) + + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + inputs["num_inference_steps"] = 3 + + class Callback: + def __init__(self): + self.has_been_called = False + self.number_of_steps = 0 + + def __call__(self, step: int, timestep: int, latents: np.ndarray) -> None: + self.has_been_called = True + self.number_of_steps += 1 + + ort_pipe = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + auto_pipe = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + + ort_callback = Callback() + auto_callback = Callback() + # callback_steps=1 to trigger callback every step + ort_pipe(**inputs, callback=ort_callback, callback_steps=1) + auto_pipe(**inputs, callback=auto_callback, callback_steps=1) + + self.assertTrue(ort_callback.has_been_called) + self.assertEqual(ort_callback.number_of_steps, auto_callback.number_of_steps) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_shape(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + height, width, batch_size = 32, 64, 1 + + for input_type in ["np", "pil", "pt"]: + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type) + + for output_type in ["np", "pil", "latent"]: + inputs["output_type"] = output_type + outputs = pipeline(**inputs).images + if output_type == "pil": + self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width)) + elif output_type == "np": + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + else: + self.assertEqual( + outputs.shape, + (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor), + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_compare_to_diffusers_pipeline(self, model_arch: str): + pytest.skip("Img2Img models do not support support output reproducibility for some reason") + + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 128, 128, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + ort_output = ort_pipeline(**inputs, generator=torch.Generator().manual_seed(SEED)).images + + diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + diffusers_output = diffusers_pipeline(**inputs, generator=torch.Generator().manual_seed(SEED)).images + + self.assertTrue(np.allclose(ort_output, diffusers_output, rtol=1e-2)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_image_reproducibility(self, model_arch: str): + pytest.skip("Img2Img models do not support support output reproducibility for some reason") + + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 64, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + + for generator_framework in ["np", "pt"]: + ort_outputs_1 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) + ort_outputs_2 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) + ort_outputs_3 = pipeline(**inputs, generator=get_generator(generator_framework, SEED + 1)) + + self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) + self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) + + +class ORTPipelineForInpaintingTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = ["stable-diffusion"] + + AUTOMODEL_CLASS = AutoPipelineForInpainting + ORTMODEL_CLASS = ORTPipelineForInpainting + + TASK = "inpainting" + + def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_type="pil"): + assert batch_size == 1, "Inpainting models only support batch_size=1" + assert input_type == "pil", "Inpainting models only support input_type='pil'" + + inputs = _generate_prompts(batch_size=batch_size) + + inputs["image"] = _generate_images( + height=height, width=width, batch_size=1, channel=channel, input_type="pil" + )[0] + inputs["mask_image"] = _generate_images( + height=height, width=width, batch_size=1, channel=channel, input_type="pil" + )[0] + + inputs["height"] = height + inputs["width"] = width + + return inputs + + @require_diffusers + def test_load_vanilla_model_which_is_not_supported(self): + with self.assertRaises(Exception) as context: + _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"], export=True) + + self.assertIn( + f"does not appear to have a file named {self.ORTMODEL_CLASS.config_name}", str(context.exception) + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_ort_pipeline_class_dispatch(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + auto_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + + self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) + + # auto_pipeline = DiffusionPipeline.from_pretrained(MODEL_NAMES[model_arch]) + # ort_pipeline = ORTDiffusionPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) + + # self.assertEqual(ort_pipeline.auto_model_class, auto_pipeline.__class__) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_num_images_per_prompt(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + self.assertEqual(pipeline.vae_scale_factor, 2) + self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4) + self.assertEqual(pipeline.unet.config["in_channels"], 4) + + batch_size, height = 1, 32 + for width in [64, 32]: + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + for num_images in [1, 3]: + outputs = pipeline(**inputs, num_images_per_prompt=num_images).images + self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["CUDAExecutionProvider"]}) + ) + @require_torch_gpu + @pytest.mark.cuda_ep_test + @require_diffusers + def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + outputs = pipeline(**inputs).images + # Verify model devices + self.assertEqual(pipeline.device.type.lower(), "cuda") + # Verify model outptus + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand( + grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["ROCMExecutionProvider"]}) + ) + @require_torch_gpu + @require_ort_rocm + @pytest.mark.rocm_ep_test + @require_diffusers + def test_pipeline_on_rocm_ep(self, test_name: str, model_arch: str, provider: str): + model_args = {"test_name": test_name, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) + outputs = pipeline(**inputs).images + # Verify model devices + self.assertEqual(pipeline.device.type.lower(), "cuda") + # Verify model outptus + self.assertIsInstance(outputs, np.ndarray) + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_callback(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 32, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + inputs["num_inference_steps"] = 3 + + class Callback: + def __init__(self): + self.has_been_called = False + self.number_of_steps = 0 + + def __call__(self, step: int, timestep: int, latents: np.ndarray) -> None: + self.has_been_called = True + self.number_of_steps += 1 + + ort_pipe = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + auto_pipe = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + + ort_callback = Callback() + auto_callback = Callback() + # callback_steps=1 to trigger callback every step + ort_pipe(**inputs, callback=ort_callback, callback_steps=1) + auto_pipe(**inputs, callback=auto_callback, callback_steps=1) + + self.assertTrue(ort_callback.has_been_called) + self.assertEqual(ort_callback.number_of_steps, auto_callback.number_of_steps) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_shape(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + height, width, batch_size = 32, 64, 1 + + for input_type in ["pil"]: + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type) + + for output_type in ["np", "pil", "latent"]: + inputs["output_type"] = output_type + outputs = pipeline(**inputs).images + if output_type == "pil": + self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width)) + elif output_type == "np": + self.assertEqual(outputs.shape, (batch_size, height, width, 3)) + else: + self.assertEqual( + outputs.shape, + (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor), + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_compare_to_diffusers_pipeline(self, model_arch: str): + if model_arch in ["stable-diffusion"]: + pytest.skip( + "Stable Diffusion For Inpainting fails, it was used to be compared to StableDiffusionPipeline for some reason which is the text-to-image variant" + ) + + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch]) + + height, width, batch_size = 64, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + latents_shape = ( + batch_size, + ort_pipeline.vae_decoder.config["latent_channels"], + height // ort_pipeline.vae_scale_factor, + width // ort_pipeline.vae_scale_factor, + ) + + np_latents = np.random.rand(*latents_shape).astype(np.float32) + torch_latents = torch.from_numpy(np_latents) + + ort_output = ort_pipeline(**inputs, latents=np_latents).images + diffusers_output = diffusers_pipeline(**inputs, latents=torch_latents).images + + self.assertTrue( + np.allclose(ort_output, diffusers_output, atol=1e-4), + np.testing.assert_allclose(ort_output, diffusers_output, atol=1e-4), + ) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_image_reproducibility(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + + height, width, batch_size = 64, 64, 1 + inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) + + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + + for generator_framework in ["np", "pt"]: + ort_outputs_1 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) + ort_outputs_2 = pipeline(**inputs, generator=get_generator(generator_framework, SEED)) + ort_outputs_3 = pipeline(**inputs, generator=get_generator(generator_framework, SEED + 1)) + + self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) + self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) + + +class ImageProcessorTest(unittest.TestCase): + def test_vae_image_processor_pt(self): + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) + input_pt = torch.stack(_generate_images(height=8, width=8, batch_size=1, input_type="pt")) + input_np = to_np(input_pt) + + for output_type in ["np", "pil"]: + out = image_processor.postprocess(image_processor.preprocess(input_pt), output_type=output_type) + out_np = to_np(out) + in_np = (input_np * 255).round() if output_type == "pil" else input_np + self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) + + def test_vae_image_processor_np(self): + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) + input_np = np.stack(_generate_images(height=8, width=8, input_type="np")) + for output_type in ["np", "pil"]: + out = image_processor.postprocess(image_processor.preprocess(input_np), output_type=output_type) + out_np = to_np(out) + in_np = (input_np * 255).round() if output_type == "pil" else input_np + self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) + + def test_vae_image_processor_pil(self): + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) + input_pil = _generate_images(height=8, width=8, batch_size=1, input_type="pil") + + for output_type in ["np", "pil"]: + out = image_processor.postprocess(image_processor.preprocess(input_pil), output_type=output_type) + for i, o in zip(input_pil, out): + in_np = np.array(i) + out_np = to_np(out) if output_type == "pil" else (to_np(out) * 255).round() + self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 4b44acb38a..199b96342e 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -89,15 +89,8 @@ ORTModelForSpeechSeq2Seq, ORTModelForTokenClassification, ORTModelForVision2Seq, - ORTStableDiffusionPipeline, ) from optimum.onnxruntime.base import ORTDecoderForSeq2Seq, ORTEncoder -from optimum.onnxruntime.modeling_diffusion import ( - ORTModelTextEncoder, - ORTModelUnet, - ORTModelVaeDecoder, - ORTModelVaeEncoder, -) from optimum.onnxruntime.modeling_ort import ORTModel from optimum.pipelines import pipeline from optimum.utils import ( @@ -108,7 +101,24 @@ DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, logging, ) -from optimum.utils.testing_utils import grid_parameters, remove_directory, require_hf_token, require_ort_rocm +from optimum.utils.import_utils import is_diffusers_available +from optimum.utils.testing_utils import ( + grid_parameters, + remove_directory, + require_diffusers, + require_hf_token, + require_ort_rocm, +) + + +if is_diffusers_available(): + from optimum.onnxruntime.modeling_diffusion import ( + ORTModelTextEncoder, + ORTModelUnet, + ORTModelVaeDecoder, + ORTModelVaeEncoder, + ORTStableDiffusionPipeline, + ) logger = logging.get_logger() @@ -205,6 +215,7 @@ def test_load_seq2seq_model_from_empty_cache(self): with self.assertRaises(Exception): _ = ORTModelForSeq2SeqLM.from_pretrained(self.TINY_ONNX_SEQ2SEQ_MODEL_ID, local_files_only=True) + @require_diffusers def test_load_stable_diffusion_model_from_cache(self): _ = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) # caching @@ -218,6 +229,7 @@ def test_load_stable_diffusion_model_from_cache(self): self.assertIsInstance(model.unet, ORTModelUnet) self.assertIsInstance(model.config, Dict) + @require_diffusers def test_load_stable_diffusion_model_from_empty_cache(self): dirpath = os.path.join( default_cache_path, "models--" + self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID.replace("/", "--") @@ -300,6 +312,7 @@ def test_load_seq2seq_model_unknown_provider(self): with self.assertRaises(ValueError): ORTModelForSeq2SeqLM.from_pretrained(self.ONNX_SEQ2SEQ_MODEL_ID, provider="FooExecutionProvider") + @require_diffusers def test_load_stable_diffusion_model_from_hub(self): model = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) self.assertIsInstance(model.text_encoder, ORTModelTextEncoder) @@ -308,6 +321,7 @@ def test_load_stable_diffusion_model_from_hub(self): self.assertIsInstance(model.unet, ORTModelUnet) self.assertIsInstance(model.config, Dict) + @require_diffusers @require_torch_gpu @pytest.mark.cuda_ep_test def test_load_stable_diffusion_model_cuda_provider(self): @@ -321,6 +335,7 @@ def test_load_stable_diffusion_model_cuda_provider(self): self.assertListEqual(model.vae_encoder.session.get_providers(), model.providers) self.assertEqual(model.device, torch.device("cuda:0")) + @require_diffusers @require_torch_gpu @require_ort_rocm @pytest.mark.rocm_ep_test @@ -335,6 +350,7 @@ def test_load_stable_diffusion_model_rocm_provider(self): self.assertListEqual(model.vae_encoder.session.get_providers(), model.providers) self.assertEqual(model.device, torch.device("cuda:0")) + @require_diffusers def test_load_stable_diffusion_model_cpu_provider(self): model = ORTStableDiffusionPipeline.from_pretrained( self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID, provider="CPUExecutionProvider" @@ -346,6 +362,7 @@ def test_load_stable_diffusion_model_cpu_provider(self): self.assertListEqual(model.vae_encoder.session.get_providers(), model.providers) self.assertEqual(model.device, torch.device("cpu")) + @require_diffusers def test_load_stable_diffusion_model_unknown_provider(self): with self.assertRaises(ValueError): ORTStableDiffusionPipeline.from_pretrained( @@ -478,6 +495,7 @@ def test_passing_session_options_seq2seq(self): self.assertEqual(model.encoder.session.get_session_options().intra_op_num_threads, 3) self.assertEqual(model.decoder.session.get_session_options().intra_op_num_threads, 3) + @require_diffusers def test_passing_session_options_stable_diffusion(self): options = onnxruntime.SessionOptions() options.intra_op_num_threads = 3 @@ -772,6 +790,7 @@ def test_seq2seq_model_on_rocm_ep_str(self): self.assertEqual(model.decoder_with_past.session.get_providers()[0], "ROCMExecutionProvider") self.assertListEqual(model.providers, ["ROCMExecutionProvider", "CPUExecutionProvider"]) + @require_diffusers @require_torch_gpu @pytest.mark.cuda_ep_test def test_passing_provider_options_stable_diffusion(self): @@ -810,6 +829,7 @@ def test_passing_provider_options_stable_diffusion(self): model.vae_encoder.session.get_provider_options()["CUDAExecutionProvider"]["do_copy_in_default_stream"], "0" ) + @require_diffusers def test_stable_diffusion_model_on_cpu(self): model = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) cpu = torch.device("cpu") @@ -825,7 +845,7 @@ def test_stable_diffusion_model_on_cpu(self): self.assertEqual(model.vae_encoder.session.get_providers()[0], "CPUExecutionProvider") self.assertListEqual(model.providers, ["CPUExecutionProvider"]) - # test string device input for to() + @require_diffusers def test_stable_diffusion_model_on_cpu_str(self): model = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) cpu = torch.device("cpu") @@ -841,6 +861,7 @@ def test_stable_diffusion_model_on_cpu_str(self): self.assertEqual(model.vae_encoder.session.get_providers()[0], "CPUExecutionProvider") self.assertListEqual(model.providers, ["CPUExecutionProvider"]) + @require_diffusers @require_torch_gpu @pytest.mark.cuda_ep_test def test_stable_diffusion_model_on_gpu(self): @@ -858,6 +879,7 @@ def test_stable_diffusion_model_on_gpu(self): self.assertEqual(model.vae_encoder.session.get_providers()[0], "CUDAExecutionProvider") self.assertListEqual(model.providers, ["CUDAExecutionProvider", "CPUExecutionProvider"]) + @require_diffusers @require_torch_gpu @require_ort_rocm @pytest.mark.rocm_ep_test @@ -876,6 +898,7 @@ def test_stable_diffusion_model_on_rocm_ep(self): self.assertEqual(model.vae_encoder.session.get_providers()[0], "ROCMExecutionProvider") self.assertListEqual(model.providers, ["ROCMExecutionProvider", "CPUExecutionProvider"]) + @require_diffusers @unittest.skipIf(get_gpu_count() <= 1, "this test requires multi-gpu") def test_stable_diffusion_model_on_gpu_id(self): model = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) @@ -899,7 +922,7 @@ def test_stable_diffusion_model_on_gpu_id(self): self.assertEqual(model.vae_decoder.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") self.assertEqual(model.vae_encoder.session.get_provider_options()["CUDAExecutionProvider"]["device_id"], "1") - # test string device input for to() + @require_diffusers @require_torch_gpu @pytest.mark.cuda_ep_test def test_stable_diffusion_model_on_gpu_str(self): @@ -916,6 +939,7 @@ def test_stable_diffusion_model_on_gpu_str(self): self.assertEqual(model.vae_encoder.session.get_providers()[0], "CUDAExecutionProvider") self.assertListEqual(model.providers, ["CUDAExecutionProvider", "CPUExecutionProvider"]) + @require_diffusers @require_torch_gpu @require_ort_rocm @pytest.mark.rocm_ep_test @@ -975,6 +999,7 @@ def test_save_seq2seq_model_without_past(self): self.assertTrue(ONNX_DECODER_WITH_PAST_NAME not in folder_contents) self.assertTrue(CONFIG_NAME in folder_contents) + @require_diffusers def test_save_stable_diffusion_model(self): with tempfile.TemporaryDirectory() as tmpdirname: model = ORTStableDiffusionPipeline.from_pretrained(self.TINY_ONNX_STABLE_DIFFUSION_MODEL_ID) @@ -1050,6 +1075,7 @@ def test_save_load_seq2seq_model_with_external_data(self, use_cache: bool): os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") remove_directory(tmpdirname) + @require_diffusers def test_save_load_stable_diffusion_model_with_external_data(self): with tempfile.TemporaryDirectory() as tmpdirname: os.environ["FORCE_ONNX_EXTERNAL_DATA"] = "1" # force exporting small model with external data @@ -1180,6 +1206,7 @@ def test_push_seq2seq_model_with_external_data_to_hub(self): ) os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") + @require_diffusers @require_hf_token def test_push_stable_diffusion_model_with_external_data_to_hub(self): with tempfile.TemporaryDirectory() as tmpdirname: diff --git a/tests/onnxruntime/test_stable_diffusion_pipeline.py b/tests/onnxruntime/test_stable_diffusion_pipeline.py deleted file mode 100644 index 44cd22ffec..0000000000 --- a/tests/onnxruntime/test_stable_diffusion_pipeline.py +++ /dev/null @@ -1,562 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import random -import unittest -from typing import Dict - -import numpy as np -import PIL -import pytest -import torch -from diffusers import ( - OnnxStableDiffusionImg2ImgPipeline, - StableDiffusionPipeline, - StableDiffusionXLPipeline, -) -from diffusers.utils import load_image -from diffusers.utils.testing_utils import floats_tensor -from packaging.version import Version, parse -from parameterized import parameterized -from transformers.testing_utils import require_torch_gpu -from utils_onnxruntime_tests import MODEL_NAMES, SEED, ORTModelTestMixin - -from optimum.onnxruntime import ( - ORTLatentConsistencyModelPipeline, - ORTStableDiffusionImg2ImgPipeline, - ORTStableDiffusionInpaintPipeline, - ORTStableDiffusionPipeline, - ORTStableDiffusionXLImg2ImgPipeline, - ORTStableDiffusionXLPipeline, -) -from optimum.onnxruntime.modeling_diffusion import ( - ORTModelTextEncoder, - ORTModelUnet, - ORTModelVaeDecoder, - ORTModelVaeEncoder, -) -from optimum.pipelines.diffusers.pipeline_utils import VaeImageProcessor -from optimum.utils.import_utils import _diffusers_version -from optimum.utils.testing_utils import grid_parameters, require_diffusers, require_ort_rocm - - -if parse(_diffusers_version) > Version("0.21.4"): - from diffusers import LatentConsistencyModelPipeline - - -def _generate_inputs(batch_size=1): - inputs = { - "prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size, - "num_inference_steps": 3, - "guidance_scale": 7.5, - "output_type": "np", - } - return inputs - - -def _create_image(height=128, width=128, batch_size=1, channel=3, input_type="pil"): - if input_type == "pil": - image = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo.png" - ).resize((width, height)) - elif input_type == "np": - image = np.random.rand(height, width, channel) - elif input_type == "pt": - image = torch.rand((channel, height, width)) - - return [image] * batch_size - - -def to_np(image): - if isinstance(image[0], PIL.Image.Image): - return np.stack([np.array(i) for i in image], axis=0) - elif isinstance(image, torch.Tensor): - return image.cpu().numpy().transpose(0, 2, 3, 1) - return image - - -class ORTStableDiffusionPipelineBase(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = [ - "stable-diffusion", - ] - ORTMODEL_CLASS = ORTStableDiffusionPipeline - TASK = "text-to-image" - - @require_diffusers - def test_load_vanilla_model_which_is_not_supported(self): - with self.assertRaises(Exception) as context: - _ = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES["bert"], export=True) - - self.assertIn( - f"does not appear to have a file named {self.ORTMODEL_CLASS.config_name}", str(context.exception) - ) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_num_images_per_prompt(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - self._setup(model_args) - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertEqual(pipeline.vae_scale_factor, 2) - self.assertEqual(pipeline.vae_decoder.config["latent_channels"], 4) - self.assertEqual(pipeline.unet.config["in_channels"], 4) - - batch_size, height = 1, 32 - for width in [64, 32]: - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - for num_images in [1, 3]: - outputs = pipeline(**inputs, num_images_per_prompt=num_images).images - self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3)) - - @parameterized.expand( - grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["CUDAExecutionProvider"]}) - ) - @require_torch_gpu - @pytest.mark.cuda_ep_test - @require_diffusers - def test_pipeline_on_gpu(self, test_name: str, model_arch: str, provider: str): - model_args = {"test_name": test_name, "model_arch": model_arch} - self._setup(model_args) - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) - height, width, batch_size = 32, 64, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - outputs = pipeline(**inputs).images - # Verify model devices - self.assertEqual(pipeline.device.type.lower(), "cuda") - # Verify model outptus - self.assertIsInstance(outputs, np.ndarray) - self.assertEqual(outputs.shape, (batch_size, height, width, 3)) - - @parameterized.expand( - grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "provider": ["ROCMExecutionProvider"]}) - ) - @require_torch_gpu - @require_ort_rocm - @pytest.mark.rocm_ep_test - @require_diffusers - def test_pipeline_on_rocm_ep(self, test_name: str, model_arch: str, provider: str): - model_args = {"test_name": test_name, "model_arch": model_arch} - self._setup(model_args) - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[test_name], provider=provider) - height, width, batch_size = 32, 64, 1 - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - outputs = pipeline(**inputs).images - # Verify model devices - self.assertEqual(pipeline.device.type.lower(), "cuda") - # Verify model outptus - self.assertIsInstance(outputs, np.ndarray) - self.assertEqual(outputs.shape, (batch_size, height, width, 3)) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_callback(self, model_arch: str): - def callback_fn(step: int, timestep: int, latents: np.ndarray) -> None: - callback_fn.has_been_called = True - callback_fn.number_of_steps += 1 - - pipe = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) - callback_fn.has_been_called = False - callback_fn.number_of_steps = 0 - inputs = self.generate_inputs(height=64, width=64) - pipe(**inputs, callback=callback_fn, callback_steps=1) - self.assertTrue(callback_fn.has_been_called) - self.assertEqual(callback_fn.number_of_steps, inputs["num_inference_steps"]) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_shape(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - self._setup(model_args) - height, width, batch_size = 128, 64, 1 - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - - if self.TASK == "image-to-image": - input_types = ["np", "pil", "pt"] - elif self.TASK == "text-to-image": - input_types = ["np"] - else: - input_types = ["pil"] - - for input_type in input_types: - if self.TASK == "image-to-image": - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type) - else: - inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size) - for output_type in ["np", "pil", "latent"]: - inputs["output_type"] = output_type - outputs = pipeline(**inputs).images - if output_type == "pil": - self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width)) - elif output_type == "np": - self.assertEqual(outputs.shape, (batch_size, height, width, 3)) - else: - self.assertEqual( - outputs.shape, - (batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor), - ) - - def generate_inputs(self, height=128, width=128, batch_size=1): - inputs = _generate_inputs(batch_size=batch_size) - inputs["height"] = height - inputs["width"] = width - return inputs - - -class ORTStableDiffusionImg2ImgPipelineTest(ORTStableDiffusionPipelineBase): - SUPPORTED_ARCHITECTURES = [ - "stable-diffusion", - ] - ORTMODEL_CLASS = ORTStableDiffusionImg2ImgPipeline - TASK = "image-to-image" - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_compare_diffusers_pipeline(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - self._setup(model_args) - height, width = 128, 128 - - inputs = self.generate_inputs(height=height, width=width) - inputs["prompt"] = "A painting of a squirrel eating a burger" - inputs["image"] = floats_tensor((1, 3, height, width), rng=random.Random(SEED)) - - ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - ort_output = ort_pipeline(**inputs, generator=np.random.RandomState(SEED)).images - - diffusers_onnx_pipeline = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(self.onnx_model_dirs[model_arch]) - diffusers_onnx_output = diffusers_onnx_pipeline(**inputs, generator=np.random.RandomState(SEED)).images - - self.assertTrue(np.allclose(ort_output, diffusers_onnx_output, atol=1e-1)) - - def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"): - inputs = _generate_inputs(batch_size=batch_size) - inputs["image"] = _create_image(height=height, width=width, batch_size=batch_size, input_type=input_type) - inputs["strength"] = 0.75 - return inputs - - -class ORTStableDiffusionPipelineTest(unittest.TestCase): - SUPPORTED_ARCHITECTURES = [ - "stable-diffusion", - ] - ORTMODEL_CLASS = ORTStableDiffusionPipeline - TASK = "text-to-image" - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_compare_to_diffusers(self, model_arch: str): - ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) - self.assertIsInstance(ort_pipeline.text_encoder, ORTModelTextEncoder) - self.assertIsInstance(ort_pipeline.vae_decoder, ORTModelVaeDecoder) - self.assertIsInstance(ort_pipeline.vae_encoder, ORTModelVaeEncoder) - self.assertIsInstance(ort_pipeline.unet, ORTModelUnet) - self.assertIsInstance(ort_pipeline.config, Dict) - - pipeline = StableDiffusionPipeline.from_pretrained(MODEL_NAMES[model_arch]) - pipeline.safety_checker = None - batch_size, num_images_per_prompt, height, width = 1, 2, 64, 32 - - latents = ort_pipeline.prepare_latents( - batch_size * num_images_per_prompt, - ort_pipeline.unet.config["in_channels"], - height, - width, - dtype=np.float32, - generator=np.random.RandomState(0), - ) - - kwargs = { - "prompt": "sailing ship in storm by Leonardo da Vinci", - "num_inference_steps": 1, - "num_images_per_prompt": num_images_per_prompt, - "height": height, - "width": width, - "guidance_rescale": 0.1, - } - - for output_type in ["latent", "np"]: - ort_outputs = ort_pipeline(latents=latents, output_type=output_type, **kwargs).images - with torch.no_grad(): - outputs = pipeline(latents=torch.from_numpy(latents), output_type=output_type, **kwargs).images - - self.assertIsInstance(ort_outputs, np.ndarray) - # Compare model outputs - self.assertTrue(np.allclose(ort_outputs, outputs, atol=1e-4)) - # Compare model devices - self.assertEqual(pipeline.device, ort_pipeline.device) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_image_reproducibility(self, model_arch: str): - pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) - inputs = _generate_inputs() - height, width = 64, 32 - np.random.seed(0) - ort_outputs_1 = pipeline(**inputs, height=height, width=width) - np.random.seed(0) - ort_outputs_2 = pipeline(**inputs, height=height, width=width) - ort_outputs_3 = pipeline(**inputs, height=height, width=width) - # Compare model outputs - self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) - self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - def test_negative_prompt(self, model_arch: str): - pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) - inputs = _generate_inputs() - inputs["height"], inputs["width"] = 64, 32 - negative_prompt = ["This is a negative prompt"] - np.random.seed(0) - image_slice_1 = pipeline(**inputs, negative_prompt=negative_prompt).images[0, -3:, -3:, -1] - prompt = inputs.pop("prompt") - embeds = [] - for p in [prompt, negative_prompt]: - text_inputs = pipeline.tokenizer( - p, - padding="max_length", - max_length=pipeline.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_inputs = text_inputs["input_ids"].astype(pipeline.text_encoder.input_dtype.get("input_ids", np.int32)) - embeds.append(pipeline.text_encoder(text_inputs)[0]) - - inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds - np.random.seed(0) - image_slice_2 = pipeline(**inputs).images[0, -3:, -3:, -1] - self.assertTrue(np.allclose(image_slice_1, image_slice_2, atol=1e-4)) - - -class ORTStableDiffusionXLPipelineTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = [ - "stable-diffusion-xl", - ] - ORTMODEL_CLASS = ORTStableDiffusionXLPipeline - TASK = "text-to-image" - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_compare_to_diffusers(self, model_arch: str): - ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) - self.assertIsInstance(ort_pipeline.text_encoder, ORTModelTextEncoder) - self.assertIsInstance(ort_pipeline.text_encoder_2, ORTModelTextEncoder) - self.assertIsInstance(ort_pipeline.vae_decoder, ORTModelVaeDecoder) - self.assertIsInstance(ort_pipeline.vae_encoder, ORTModelVaeEncoder) - self.assertIsInstance(ort_pipeline.unet, ORTModelUnet) - self.assertIsInstance(ort_pipeline.config, Dict) - - pipeline = StableDiffusionXLPipeline.from_pretrained(MODEL_NAMES[model_arch]) - batch_size, num_images_per_prompt, height, width = 2, 2, 64, 32 - latents = ort_pipeline.prepare_latents( - batch_size * num_images_per_prompt, - ort_pipeline.unet.config["in_channels"], - height, - width, - dtype=np.float32, - generator=np.random.RandomState(0), - ) - - kwargs = { - "prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size, - "num_inference_steps": 1, - "num_images_per_prompt": num_images_per_prompt, - "height": height, - "width": width, - "guidance_rescale": 0.1, - } - - for output_type in ["latent", "np"]: - ort_outputs = ort_pipeline(latents=latents, output_type=output_type, **kwargs).images - self.assertIsInstance(ort_outputs, np.ndarray) - with torch.no_grad(): - outputs = pipeline(latents=torch.from_numpy(latents), output_type=output_type, **kwargs).images - - # Compare model outputs - self.assertTrue(np.allclose(ort_outputs, outputs, atol=1e-4)) - # Compare model devices - self.assertEqual(pipeline.device, ort_pipeline.device) - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_image_reproducibility(self, model_arch: str): - pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) - inputs = _generate_inputs() - height, width = 64, 32 - np.random.seed(0) - ort_outputs_1 = pipeline(**inputs, height=height, width=width) - np.random.seed(0) - ort_outputs_2 = pipeline(**inputs, height=height, width=width) - ort_outputs_3 = pipeline(**inputs, height=height, width=width) - self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) - self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) - - -class ORTStableDiffusionInpaintPipelineTest(ORTStableDiffusionPipelineBase): - SUPPORTED_ARCHITECTURES = [ - "stable-diffusion", - ] - ORTMODEL_CLASS = ORTStableDiffusionInpaintPipeline - TASK = "inpainting" - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_compare_diffusers_pipeline(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - self._setup(model_args) - ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - diffusers_pipeline = self.ORTMODEL_CLASS.auto_model_class.from_pretrained(MODEL_NAMES[model_arch]) - height, width = 64, 64 - latents_shape = ( - 1, - ort_pipeline.vae_decoder.config["latent_channels"], - height // ort_pipeline.vae_scale_factor, - width // ort_pipeline.vae_scale_factor, - ) - inputs = self.generate_inputs(height=height, width=width) - - np_latents = np.random.rand(*latents_shape).astype(np.float32) - torch_latents = torch.from_numpy(np_latents) - - ort_outputs = ort_pipeline(**inputs, latents=np_latents).images - self.assertEqual(ort_outputs.shape, (1, height, width, 3)) - - diffusers_outputs = diffusers_pipeline(**inputs, latents=torch_latents).images - self.assertEqual(diffusers_outputs.shape, (1, height, width, 3)) - - self.assertTrue(np.allclose(ort_outputs, diffusers_outputs, atol=1e-4)) - - def generate_inputs(self, height=128, width=128, batch_size=1): - inputs = super(ORTStableDiffusionInpaintPipelineTest, self).generate_inputs(height, width) - inputs["image"] = _create_image(height=height, width=width, batch_size=1, input_type="pil")[0] - inputs["mask_image"] = _create_image(height=height, width=width, batch_size=1, input_type="pil")[0] - return inputs - - -class ORTStableDiffusionXLImg2ImgPipelineTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = [ - "stable-diffusion-xl", - ] - ORTMODEL_CLASS = ORTStableDiffusionXLImg2ImgPipeline - TASK = "image-to-image" - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - def test_inference(self, model_arch: str): - model_args = {"test_name": model_arch, "model_arch": model_arch} - self._setup(model_args) - pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) - - height, width = 128, 128 - inputs = self.generate_inputs(height=height, width=width) - inputs["image"] = load_image( - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo.png" - ).resize((width, height)) - output = pipeline(**inputs, generator=np.random.RandomState(0)).images[0, -3:, -3:, -1] - expected_slice = np.array([0.6515, 0.5405, 0.4858, 0.5632, 0.5174, 0.5681, 0.4948, 0.4253, 0.5080]) - - self.assertTrue(np.allclose(output.flatten(), expected_slice, atol=1e-1)) - - def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"): - inputs = _generate_inputs(batch_size=batch_size) - inputs["image"] = _create_image(height=height, width=width, batch_size=batch_size, input_type=input_type) - inputs["strength"] = 0.75 - return inputs - - -class ImageProcessorTest(unittest.TestCase): - def test_vae_image_processor_pt(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) - input_pt = torch.stack(_create_image(height=8, width=8, batch_size=1, input_type="pt")) - input_np = to_np(input_pt) - - for output_type in ["np", "pil"]: - out = image_processor.postprocess(image_processor.preprocess(input_pt), output_type=output_type) - out_np = to_np(out) - in_np = (input_np * 255).round() if output_type == "pil" else input_np - self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) - - def test_vae_image_processor_np(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) - input_np = np.stack(_create_image(height=8, width=8, input_type="np")) - for output_type in ["np", "pil"]: - out = image_processor.postprocess(image_processor.preprocess(input_np), output_type=output_type) - out_np = to_np(out) - in_np = (input_np * 255).round() if output_type == "pil" else input_np - self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) - - def test_vae_image_processor_pil(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) - input_pil = _create_image(height=8, width=8, batch_size=1, input_type="pil") - - for output_type in ["np", "pil"]: - out = image_processor.postprocess(image_processor.preprocess(input_pil), output_type=output_type) - for i, o in zip(input_pil, out): - in_np = np.array(i) - out_np = to_np(out) if output_type == "pil" else (to_np(out) * 255).round() - self.assertTrue(np.allclose(in_np, out_np, atol=1e-6)) - - -class ORTLatentConsistencyModelPipelineTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = [ - "latent-consistency", - ] - ORTMODEL_CLASS = ORTLatentConsistencyModelPipeline - TASK = "text-to-image" - - @parameterized.expand(SUPPORTED_ARCHITECTURES) - @require_diffusers - @unittest.skipIf( - parse(_diffusers_version) <= Version("0.21.4"), - "not supported with this diffusers version, needs diffusers>=v0.22.0", - ) - def test_compare_to_diffusers(self, model_arch: str): - ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) - self.assertIsInstance(ort_pipeline.text_encoder, ORTModelTextEncoder) - self.assertIsInstance(ort_pipeline.vae_decoder, ORTModelVaeDecoder) - self.assertIsInstance(ort_pipeline.vae_encoder, ORTModelVaeEncoder) - self.assertIsInstance(ort_pipeline.unet, ORTModelUnet) - self.assertIsInstance(ort_pipeline.config, Dict) - - pipeline = LatentConsistencyModelPipeline.from_pretrained(MODEL_NAMES[model_arch]) - batch_size, num_images_per_prompt, height, width = 2, 2, 64, 32 - latents = ort_pipeline.prepare_latents( - batch_size * num_images_per_prompt, - ort_pipeline.unet.config["in_channels"], - height, - width, - dtype=np.float32, - generator=np.random.RandomState(0), - ) - - kwargs = { - "prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size, - "num_inference_steps": 1, - "num_images_per_prompt": num_images_per_prompt, - "height": height, - "width": width, - "guidance_scale": 8.5, - } - - for output_type in ["latent", "np"]: - ort_outputs = ort_pipeline(latents=latents, output_type=output_type, **kwargs).images - self.assertIsInstance(ort_outputs, np.ndarray) - with torch.no_grad(): - outputs = pipeline(latents=torch.from_numpy(latents), output_type=output_type, **kwargs).images - - # Compare model outputs - self.assertTrue(np.allclose(ort_outputs, outputs, atol=1e-4)) - # Compare model devices - self.assertEqual(pipeline.device, ort_pipeline.device)