Add vae image processor (#1219)
* add vae image processor

* tests refactoring

* make watermark optional

* format

* add shape test

* add tests

* format

* raise error if watermark not installed but needed

* fix image resizing for diffusers < v0.21.0
echarlaix authored Sep 5, 2023
1 parent 7f8e606 commit 3bac338
Showing 8 changed files with 408 additions and 133 deletions.
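
As a rough usage sketch of what this commit enables (not part of the diff itself): loading an SDXL ONNX pipeline and opting out of the invisible watermark via the new add_watermarker argument. The model ID, the export=True flag, and the assumption that from_pretrained forwards add_watermarker to the pipeline constructor are illustrative, not confirmed by this diff.

    from optimum.onnxruntime import ORTStableDiffusionXLPipeline

    # add_watermarker=None (default) watermarks only if invisible-watermark is installed;
    # add_watermarker=True without the package raises ImportError (see modeling_diffusion.py below).
    pipeline = ORTStableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",  # assumed model ID for illustration
        export=True,
        add_watermarker=False,  # skip the invisible watermark on generated images
    )
    image = pipeline("a photo of an astronaut riding a horse").images[0]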
21 changes: 17 additions & 4 deletions optimum/onnxruntime/modeling_diffusion.py
@@ -31,7 +31,7 @@
StableDiffusionXLImg2ImgPipeline,
)
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME
from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available
from huggingface_hub import snapshot_download
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from transformers.file_utils import add_end_docstrings
@@ -45,6 +45,7 @@
from ..pipelines.diffusers.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipelineMixin
from ..pipelines.diffusers.pipeline_stable_diffusion_xl import StableDiffusionXLPipelineMixin
from ..pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipelineMixin
from ..pipelines.diffusers.pipeline_utils import VaeImageProcessor
from ..utils import (
DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
@@ -171,6 +172,8 @@ def __init__(
else:
self.vae_scale_factor = 8

self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

@staticmethod
def load_model(
vae_decoder_path: Union[str, Path],
@@ -578,6 +581,7 @@ def __init__(
tokenizer_2: Optional[CLIPTokenizer] = None,
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
add_watermarker: Optional[bool] = None,
):
super().__init__(
vae_decoder_session=vae_decoder_session,
@@ -594,10 +598,19 @@
model_save_dir=model_save_dir,
)

# additional invisible-watermark dependency for SD XL
from ..pipelines.diffusers.watermark import StableDiffusionXLWatermarker
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

if add_watermarker:
if not is_invisible_watermark_available():
raise ImportError(
"`add_watermarker` requires invisible-watermark to be installed, which can be installed with `pip install invisible-watermark`."
)

self.watermark = StableDiffusionXLWatermarker()
from ..pipelines.diffusers.watermark import StableDiffusionXLWatermarker

self.watermark = StableDiffusionXLWatermarker()
else:
self.watermark = None


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
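To make the new component concrete, here is a minimal numpy stand-in sketching what the VaeImageProcessor instantiated above is expected to do; the real class lives in optimum/pipelines/diffusers/pipeline_utils.py (not shown in this diff) and handles more input types and options, so treat the names and details below as illustrative only.

    import numpy as np
    from PIL import Image

    VAE_SCALE_FACTOR = 8  # fallback value used by the constructor above

    def preprocess(image):
        # Snap dimensions down to a multiple of the VAE scale factor, then
        # convert to a normalized NCHW float array in [-1, 1].
        w, h = (x - x % VAE_SCALE_FACTOR for x in image.size)
        arr = np.array(image.convert("RGB").resize((w, h)), dtype=np.float32) / 255.0
        return 2.0 * arr[None].transpose(0, 3, 1, 2) - 1.0

    def postprocess(arr):
        # Denormalize back to [0, 1], move channels last, and return PIL images.
        arr = np.clip(arr / 2 + 0.5, 0, 1)
        arr = (arr.transpose(0, 2, 3, 1) * 255).round().astype("uint8")
        return [Image.fromarray(a) for a in arr]

    images = postprocess(preprocess(Image.new("RGB", (515, 517))))  # one 512x512 image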
19 changes: 10 additions & 9 deletions optimum/pipelines/diffusers/pipeline_stable_diffusion.py
@@ -380,32 +380,33 @@ def __call__(
image = latents
has_nsfw_concept = None
else:
latents = 1 / self.vae_decoder.config.get("scaling_factor", 0.18215) * latents
latents /= self.vae_decoder.config.get("scaling_factor", 0.18215)
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
# TODO: add image_processor
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
image, has_nsfw_concept = self.run_safety_checker(image)

if output_type == "pil":
image = self.numpy_to_pil(image)
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

if not return_dict:
return (image, has_nsfw_concept)

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

def run_safety_checker(self, image):
def run_safety_checker(self, image: np.ndarray):
if self.safety_checker is None:
has_nsfw_concept = None
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
feature_extractor_input, return_tensors="np"
).pixel_values.astype(image.dtype)

images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
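The do_denormalize list replaces the previous unconditional image / 2 + 0.5 denormalization so that images flagged by the safety checker (which are typically returned as all zeros, i.e. black) are not shifted to mid-gray. A small self-contained illustration with made-up data rather than real pipeline output:

    import numpy as np

    flagged = np.zeros((1, 64, 64, 3), dtype=np.float32)   # stand-in for a blacked-out image
    denormalized = np.clip(flagged / 2 + 0.5, 0, 1)         # old unconditional path -> mid-gray
    kept = np.clip(flagged, 0, 1)                           # do_denormalize=False keeps it black
    print(denormalized.mean(), kept.mean())                 # 0.5 0.0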
19 changes: 8 additions & 11 deletions optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py
@@ -13,7 +13,6 @@
# limitations under the License.

import inspect
import logging
from typing import Callable, List, Optional, Union

import numpy as np
@@ -23,10 +22,6 @@
from diffusers.utils import deprecate

from .pipeline_stable_diffusion import StableDiffusionPipelineMixin
from .pipeline_utils import preprocess


logger = logging.getLogger(__name__)


class StableDiffusionImg2ImgPipelineMixin(StableDiffusionPipelineMixin):
@@ -178,7 +173,7 @@ def __call__(
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

image = preprocess(image)
image = self.image_processor.preprocess(image)

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -287,17 +282,19 @@ def __call__(
image = latents
has_nsfw_concept = None
else:
latents = 1 / scaling_factor * latents
latents /= scaling_factor
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
image, has_nsfw_concept = self.run_safety_checker(image)

if output_type == "pil":
image = self.numpy_to_pil(image)
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

if not return_dict:
return (image, has_nsfw_concept)
16 changes: 7 additions & 9 deletions optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py
@@ -13,7 +13,6 @@
# limitations under the License.

import inspect
import logging
from typing import Callable, List, Optional, Union

import numpy as np
@@ -25,9 +24,6 @@
from .pipeline_stable_diffusion import StableDiffusionPipelineMixin


logger = logging.getLogger(__name__)


def prepare_mask_and_masked_image(image, mask, latents_shape, vae_scale_factor):
image = np.array(
image.convert("RGB").resize((latents_shape[1] * vae_scale_factor, latents_shape[0] * vae_scale_factor))
@@ -329,17 +325,19 @@ def __call__(
image = latents
has_nsfw_concept = None
else:
latents = 1 / scaling_factor * latents
latents /= scaling_factor
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
image, has_nsfw_concept = self.run_safety_checker(image)

if output_type == "pil":
image = self.numpy_to_pil(image)
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

if not return_dict:
return (image, has_nsfw_concept)
13 changes: 5 additions & 8 deletions optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py
@@ -480,18 +480,15 @@ def __call__(
if output_type == "latent":
image = latents
else:
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
latents /= self.vae_decoder.config.get("scaling_factor", 0.18215)
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
image = self.watermark.apply_watermark(image)

# TODO: add image_processor
image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1))

if output_type == "pil":
image = self.numpy_to_pil(image)
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type)

if not return_dict:
return (image,)
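For background on the dependency that add_watermarker toggles: the invisible-watermark package embeds a bit pattern into the image in a frequency domain, which is roughly what StableDiffusionXLWatermarker wraps. A hedged sketch of typical usage of that package; the bit pattern and the exact wiring inside optimum/pipelines/diffusers/watermark.py are assumptions, not taken from this diff.

    import numpy as np
    from imwatermark import WatermarkEncoder  # provided by `pip install invisible-watermark`

    encoder = WatermarkEncoder()
    encoder.set_watermark("bits", [0, 1, 1, 0] * 12)             # illustrative bit pattern only
    bgr = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)   # encoder expects a BGR uint8 image
    marked = encoder.encode(bgr, "dwtDct")                       # embed the bits via DWT-DCT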
optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py
@@ -21,7 +21,7 @@
import torch
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput

from .pipeline_utils import DiffusionPipelineMixin, preprocess, rescale_noise_cfg
from .pipeline_utils import DiffusionPipelineMixin, rescale_noise_cfg


logger = logging.getLogger(__name__)
@@ -400,7 +400,7 @@ def __call__(
)

# 3. Preprocess image
image = preprocess(image)
image = self.image_processor.preprocess(image)

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps)
@@ -487,18 +487,15 @@ def __call__(
if output_type == "latent":
image = latents
else:
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
latents /= self.vae_decoder.config.get("scaling_factor", 0.18215)
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
image = self.watermark.apply_watermark(image)

# TODO: add image_processor
image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1))

if output_type == "pil":
image = self.numpy_to_pil(image)
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type)

if not return_dict:
return (image,)