Skip to content

Commit

Permalink
Add vae image processor (#421)
Browse files Browse the repository at this point in the history
* add vae image processor

* add test

* add optimum min version
  • Loading branch information
echarlaix authored Sep 8, 2023
1 parent 375942f commit f3bb7f2
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 38 deletions.
8 changes: 0 additions & 8 deletions optimum/intel/openvino/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ class PreTrainedModel(OptimizedModel):
""",
)
class OVBaseModel(PreTrainedModel):
_AUTOMODELS_TO_TASKS = {cls_name: task for task, cls_name in TasksManager._TASKS_TO_AUTOMODELS.items()}
auto_model_class = None
export_feature = None

Expand Down Expand Up @@ -391,13 +390,6 @@ def _ensure_supported_device(self, device: str = None):
def forward(self, *args, **kwargs):
raise NotImplementedError

@classmethod
def _auto_model_to_task(cls, auto_model_class):
"""
Get the task corresponding to a class (for example AutoModelForXXX in transformers).
"""
return cls._AUTOMODELS_TO_TASKS[auto_model_class.__name__]

def can_generate(self) -> bool:
"""
Returns whether this model can generate sequences with `.generate()`.
Expand Down
22 changes: 17 additions & 5 deletions optimum/intel/openvino/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
StableDiffusionXLPipeline,
)
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME
from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available
from huggingface_hub import snapshot_download
from openvino._offline_transformations import compress_model_transformation
from openvino.runtime import Core
Expand All @@ -42,6 +42,7 @@
from optimum.pipelines.diffusers.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipelineMixin
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl import StableDiffusionXLPipelineMixin
from optimum.pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipelineMixin
from optimum.pipelines.diffusers.pipeline_utils import VaeImageProcessor
from optimum.utils import (
DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
Expand Down Expand Up @@ -106,6 +107,8 @@ def __init__(
else:
self.vae_scale_factor = 8

self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

self.tokenizer = tokenizer
self.tokenizer_2 = tokenizer_2
self.scheduler = scheduler
Expand Down Expand Up @@ -687,12 +690,21 @@ class OVStableDiffusionXLPipelineBase(OVStableDiffusionPipelineBase):
auto_model_class = StableDiffusionXLPipeline
export_feature = "stable-diffusion-xl"

def __init__(self, *args, **kwargs):
def __init__(self, *args, add_watermarker: Optional[bool] = None, **kwargs):
super().__init__(*args, **kwargs)
# additional invisible-watermark dependency for SD XL
from optimum.pipelines.diffusers.watermark import StableDiffusionXLWatermarker

self.watermark = StableDiffusionXLWatermarker()
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

if add_watermarker:
if not is_invisible_watermark_available():
raise ImportError(
"`add_watermarker` requires invisible-watermark to be installed, which can be installed with `pip install invisible-watermark`."
)
from optimum.pipelines.diffusers.watermark import StableDiffusionXLWatermarker

self.watermark = StableDiffusionXLWatermarker()
else:
self.watermark = None


class OVStableDiffusionXLPipeline(OVStableDiffusionXLPipelineBase, StableDiffusionXLPipelineMixin):
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
assert False, "Error: Could not open '%s' due %s\n" % (filepath, error)

INSTALL_REQUIRE = [
"optimum>=1.10.0",
"optimum>=1.13.0",
"transformers>=4.20.0",
"datasets>=1.4.0",
"sentencepiece",
Expand All @@ -31,6 +31,7 @@
"torchaudio",
"rjieba",
"timm",
"invisible-watermark>=0.2.0",
]

QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241"]
Expand All @@ -44,7 +45,7 @@
"openvino": ["openvino>=2023.0.0", "onnx", "onnxruntime"],
"nncf": ["nncf>=2.5.0", "openvino-dev>=2023.0.0"],
"ipex": ["transformers<4.32.0", "intel-extension-for-pytorch", "onnx"],
"diffusers": ["diffusers", "invisible-watermark>=0.2.0"],
"diffusers": ["diffusers"],
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
}
Expand Down
104 changes: 81 additions & 23 deletions tests/openvino/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Dict

import numpy as np
import PIL
import torch
from diffusers import (
StableDiffusionPipeline,
Expand Down Expand Up @@ -60,17 +61,32 @@ def _generate_inputs(batch_size=1):
return inputs


def _create_image(height=128, width=128):
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
)
return image.resize((width, height))
def _create_image(height=128, width=128, batch_size=1, channel=3, input_type="pil"):
if input_type == "pil":
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
).resize((width, height))
elif input_type == "np":
image = np.random.rand(height, width, channel)
elif input_type == "pt":
image = torch.rand((channel, height, width))

return [image] * batch_size


def to_np(image):
if isinstance(image[0], PIL.Image.Image):
return np.stack([np.array(i) for i in image], axis=0)
elif isinstance(image, torch.Tensor):
return image.cpu().numpy().transpose(0, 2, 3, 1)
return image


class OVStableDiffusionPipelineBaseTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ("stable-diffusion",)
MODEL_CLASS = OVStableDiffusionPipeline
TASK = "text-to-image"

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_num_images_per_prompt(self, model_arch: str):
Expand Down Expand Up @@ -104,6 +120,36 @@ def callback_fn(step: int, timestep: int, latents: np.ndarray) -> None:
self.assertTrue(callback_fn.has_been_called)
self.assertEqual(callback_fn.number_of_steps, inputs["num_inference_steps"])

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_shape(self, model_arch: str):
height, width, batch_size = 128, 64, 1
pipeline = self.MODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True)

if self.TASK == "image-to-image":
input_types = ["np", "pil", "pt"]
elif self.TASK == "text-to-image":
input_types = ["np"]
else:
input_types = ["pil"]

for input_type in input_types:
if self.TASK == "image-to-image":
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size, input_type=input_type)
else:
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
for output_type in ["np", "pil", "latent"]:
inputs["output_type"] = output_type
outputs = pipeline(**inputs).images
if output_type == "pil":
self.assertEqual((len(outputs), outputs[0].height, outputs[0].width), (batch_size, height, width))
elif output_type == "np":
self.assertEqual(outputs.shape, (batch_size, height, width, 3))
else:
self.assertEqual(
outputs.shape,
(batch_size, 4, height // pipeline.vae_scale_factor, width // pipeline.vae_scale_factor),
)

def generate_inputs(self, height=128, width=128, batch_size=1):
inputs = _generate_inputs(batch_size)
inputs["height"] = height
Expand All @@ -115,13 +161,16 @@ class OVStableDiffusionImg2ImgPipelineTest(OVStableDiffusionPipelineBaseTest):
SUPPORTED_ARCHITECTURES = ("stable-diffusion",)
MODEL_CLASS = OVStableDiffusionImg2ImgPipeline
ORT_MODEL_CLASS = ORTStableDiffusionImg2ImgPipeline
TASK = "image-to-image"

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_diffusers_pipeline(self, model_arch: str):
model_id = MODEL_NAMES[model_arch]
pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True)
inputs = self.generate_inputs()
height, width, batch_size = 128, 128, 1
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
inputs["prompt"] = "A painting of a squirrel eating a burger"
inputs["image"] = floats_tensor((batch_size, 3, height, width), rng=random.Random(SEED))
np.random.seed(0)
output = pipeline(**inputs).images[0, -3:, -3:, -1]
# https://github.com/huggingface/diffusers/blob/v0.17.1/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py#L71
Expand All @@ -139,16 +188,17 @@ def test_num_images_per_prompt_static_model(self, model_arch: str):
outputs = pipeline(**inputs, num_images_per_prompt=num_images, generator=np.random.RandomState(0)).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))

def generate_inputs(self, height=128, width=128, batch_size=1):
def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"):
inputs = _generate_inputs(batch_size)
inputs["image"] = floats_tensor((batch_size, 3, height, width), rng=random.Random(SEED))
inputs["image"] = _create_image(height=height, width=width, batch_size=batch_size, input_type=input_type)
inputs["strength"] = 0.75
return inputs


class OVStableDiffusionPipelineTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ("stable-diffusion",)
MODEL_CLASS = OVStableDiffusionPipeline
TASK = "text-to-image"

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_diffusers(self, model_arch: str):
Expand Down Expand Up @@ -247,6 +297,7 @@ class OVStableDiffusionInpaintPipelineTest(OVStableDiffusionPipelineBaseTest):
SUPPORTED_ARCHITECTURES = ("stable-diffusion",)
MODEL_CLASS = OVStableDiffusionInpaintPipeline
ORT_MODEL_CLASS = ORTStableDiffusionInpaintPipeline
TASK = "inpaint"

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_diffusers_pipeline(self, model_arch: str):
Expand All @@ -262,6 +313,17 @@ def test_compare_diffusers_pipeline(self, model_arch: str):
generator=np.random.RandomState(0),
)
inputs = self.generate_inputs(height=height, width=width)

inputs["image"] = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
).resize((width, height))

inputs["mask_image"] = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
).resize((width, height))

outputs = pipeline(**inputs, latents=latents).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))

Expand All @@ -285,16 +347,8 @@ def test_num_images_per_prompt_static_model(self, model_arch: str):

def generate_inputs(self, height=128, width=128, batch_size=1):
inputs = super(OVStableDiffusionInpaintPipelineTest, self).generate_inputs(height, width, batch_size)
inputs["image"] = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
).resize((width, height))

inputs["mask_image"] = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
).resize((width, height))

inputs["image"] = _create_image(height=height, width=width, batch_size=1, input_type="pil")[0]
inputs["mask_image"] = _create_image(height=height, width=width, batch_size=1, input_type="pil")[0]
return inputs


Expand All @@ -303,6 +357,7 @@ class OVtableDiffusionXLPipelineTest(unittest.TestCase):
MODEL_CLASS = OVStableDiffusionXLPipeline
ORT_MODEL_CLASS = ORTStableDiffusionXLPipeline
PT_MODEL_CLASS = StableDiffusionXLPipeline
TASK = "text-to-image"

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_diffusers(self, model_arch: str):
Expand Down Expand Up @@ -387,6 +442,7 @@ class OVStableDiffusionXLImg2ImgPipelineTest(unittest.TestCase):
MODEL_CLASS = OVStableDiffusionXLImg2ImgPipeline
ORT_MODEL_CLASS = ORTStableDiffusionXLImg2ImgPipeline
PT_MODEL_CLASS = StableDiffusionXLImg2ImgPipeline
TASK = "image-to-image"

def test_inference(self):
model_id = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
Expand All @@ -396,10 +452,12 @@ def test_inference(self):
pipeline.save_pretrained(tmp_dir)
pipeline = self.MODEL_CLASS.from_pretrained(tmp_dir)

inputs = self.generate_inputs()
batch_size, height, width = 1, 128, 128
inputs = self.generate_inputs(height=height, width=width, batch_size=batch_size)
inputs["image"] = floats_tensor((batch_size, 3, height, width), rng=random.Random(SEED))
np.random.seed(0)
output = pipeline(**inputs).images[0, -3:, -3:, -1]
expected_slice = np.array([0.5675, 0.5108, 0.4758, 0.5280, 0.5080, 0.5473, 0.4789, 0.4286, 0.4861])
expected_slice = np.array([0.5683, 0.5121, 0.4767, 0.5253, 0.5072, 0.5462, 0.4766, 0.4279, 0.4855])
self.assertTrue(np.allclose(output.flatten(), expected_slice, atol=1e-3))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
Expand All @@ -413,8 +471,8 @@ def test_num_images_per_prompt_static_model(self, model_arch: str):
outputs = pipeline(**inputs, num_images_per_prompt=num_images, generator=np.random.RandomState(0)).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))

def generate_inputs(self, height=128, width=128, batch_size=1):
def generate_inputs(self, height=128, width=128, batch_size=1, input_type="np"):
inputs = _generate_inputs(batch_size)
inputs["image"] = floats_tensor((batch_size, 3, height, width), rng=random.Random(SEED))
inputs["image"] = _create_image(height=height, width=width, batch_size=batch_size, input_type=input_type)
inputs["strength"] = 0.75
return inputs

0 comments on commit f3bb7f2

Please sign in to comment.