Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add vae image processor #1219

Merged
merged 16 commits into from
Sep 5, 2023
1 change: 1 addition & 0 deletions .github/workflows/test_onnxruntime.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install git+https://github.com/huggingface/diffusers
Copy link
Collaborator Author

@echarlaix echarlaix Sep 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Waiting for the next diffusers release before we can merge.
Edit: now compatible.

pip install .[tests,onnxruntime]
- name: Test with pytest
working-directory: tests
Expand Down
15 changes: 11 additions & 4 deletions optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
StableDiffusionXLImg2ImgPipeline,
)
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME
from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available
from huggingface_hub import snapshot_download
from transformers import CLIPFeatureExtractor, CLIPTokenizer
from transformers.file_utils import add_end_docstrings
Expand All @@ -45,6 +45,7 @@
from ..pipelines.diffusers.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipelineMixin
from ..pipelines.diffusers.pipeline_stable_diffusion_xl import StableDiffusionXLPipelineMixin
from ..pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipelineMixin
from ..pipelines.diffusers.pipeline_utils import OptimumVaeImageProcessor
from ..utils import (
DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
Expand Down Expand Up @@ -171,6 +172,8 @@ def __init__(
else:
self.vae_scale_factor = 8

self.image_processor = OptimumVaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

@staticmethod
def load_model(
vae_decoder_path: Union[str, Path],
Expand Down Expand Up @@ -578,6 +581,7 @@ def __init__(
tokenizer_2: Optional[CLIPTokenizer] = None,
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
add_watermarker: Optional[bool] = None,
):
super().__init__(
vae_decoder_session=vae_decoder_session,
Expand All @@ -593,11 +597,14 @@ def __init__(
use_io_binding=use_io_binding,
model_save_dir=model_save_dir,
)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

# additional invisible-watermark dependency for SD XL
from ..pipelines.diffusers.watermark import StableDiffusionXLWatermarker
if add_watermarker:
from ..pipelines.diffusers.watermark import StableDiffusionXLWatermarker
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this import be at the top of the file?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, not without making invisible-watermark a hard dependency. I added an additional verification in dd292d8 to check that invisible-watermark is installed if needed.


self.watermark = StableDiffusionXLWatermarker()
self.watermark = StableDiffusionXLWatermarker()
else:
self.watermark = None


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
Expand Down
19 changes: 10 additions & 9 deletions optimum/pipelines/diffusers/pipeline_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,32 +380,33 @@ def __call__(
image = latents
has_nsfw_concept = None
else:
latents = 1 / self.vae_decoder.config.get("scaling_factor", 0.18215) * latents
latents /= self.vae_decoder.config.get("scaling_factor", 0.18215)
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
# TODO: add image_processor
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
image, has_nsfw_concept = self.run_safety_checker(image)

if output_type == "pil":
image = self.numpy_to_pil(image)
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

if not return_dict:
return (image, has_nsfw_concept)

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

def run_safety_checker(self, image):
def run_safety_checker(self, image: np.ndarray):
if self.safety_checker is None:
has_nsfw_concept = None
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(
self.numpy_to_pil(image), return_tensors="np"
feature_extractor_input, return_tensors="np"
).pixel_values.astype(image.dtype)

images, has_nsfw_concept = [], []
for i in range(image.shape[0]):
image_i, has_nsfw_concept_i = self.safety_checker(
Expand Down
19 changes: 8 additions & 11 deletions optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import inspect
import logging
from typing import Callable, List, Optional, Union

import numpy as np
Expand All @@ -23,10 +22,6 @@
from diffusers.utils import deprecate

from .pipeline_stable_diffusion import StableDiffusionPipelineMixin
from .pipeline_utils import preprocess


logger = logging.getLogger(__name__)


class StableDiffusionImg2ImgPipelineMixin(StableDiffusionPipelineMixin):
Expand Down Expand Up @@ -178,7 +173,7 @@ def __call__(
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

image = preprocess(image)
image = self.image_processor.preprocess(image)

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
Expand Down Expand Up @@ -287,17 +282,19 @@ def __call__(
image = latents
has_nsfw_concept = None
else:
latents = 1 / scaling_factor * latents
latents /= scaling_factor
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
image, has_nsfw_concept = self.run_safety_checker(image)

if output_type == "pil":
image = self.numpy_to_pil(image)
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

if not return_dict:
return (image, has_nsfw_concept)
Expand Down
16 changes: 7 additions & 9 deletions optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import inspect
import logging
from typing import Callable, List, Optional, Union

import numpy as np
Expand All @@ -25,9 +24,6 @@
from .pipeline_stable_diffusion import StableDiffusionPipelineMixin


logger = logging.getLogger(__name__)


def prepare_mask_and_masked_image(image, mask, latents_shape, vae_scale_factor):
image = np.array(
image.convert("RGB").resize((latents_shape[1] * vae_scale_factor, latents_shape[0] * vae_scale_factor))
Expand Down Expand Up @@ -329,17 +325,19 @@ def __call__(
image = latents
has_nsfw_concept = None
else:
latents = 1 / scaling_factor * latents
latents /= scaling_factor
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
image, has_nsfw_concept = self.run_safety_checker(image)

if output_type == "pil":
image = self.numpy_to_pil(image)
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

if not return_dict:
return (image, has_nsfw_concept)
Expand Down
13 changes: 5 additions & 8 deletions optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,18 +480,15 @@ def __call__(
if output_type == "latent":
image = latents
else:
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
latents /= self.vae_decoder.config.get("scaling_factor", 0.18215)
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
image = self.watermark.apply_watermark(image)

# TODO: add image_processor
image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1))

if output_type == "pil":
image = self.numpy_to_pil(image)
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type)

if not return_dict:
return (image,)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput

from .pipeline_utils import DiffusionPipelineMixin, preprocess, rescale_noise_cfg
from .pipeline_utils import DiffusionPipelineMixin, rescale_noise_cfg


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -400,7 +400,7 @@ def __call__(
)

# 3. Preprocess image
image = preprocess(image)
image = self.image_processor.preprocess(image)

# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps)
Expand Down Expand Up @@ -487,18 +487,15 @@ def __call__(
if output_type == "latent":
image = latents
else:
latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215)
latents /= self.vae_decoder.config.get("scaling_factor", 0.18215)
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate(
[self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
)
image = self.watermark.apply_watermark(image)

# TODO: add image_processor
image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1))

if output_type == "pil":
image = self.numpy_to_pil(image)
# apply watermark if available
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type)

if not return_dict:
return (image,)
Expand Down
Loading
Loading