Skip to content

Commit

Permalink
fix SD XL ONNX export for img2img task
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Jul 17, 2023
1 parent a9ffe07 commit 4069ae2
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
9 changes: 6 additions & 3 deletions optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,24 +100,27 @@ def _get_submodels_for_export_stable_diffusion(
"""
Returns the components of a Stable Diffusion model.
"""
from diffusers import StableDiffusionXLPipeline
from diffusers import StableDiffusionXLImg2ImgPipeline

models_for_export = {}
if isinstance(pipeline, StableDiffusionXLPipeline):
if isinstance(pipeline, StableDiffusionXLImg2ImgPipeline):
projection_dim = pipeline.text_encoder_2.config.projection_dim
else:
projection_dim = pipeline.text_encoder.config.projection_dim

# Text encoder
if pipeline.text_encoder is not None:
if isinstance(pipeline, StableDiffusionXLPipeline):
if isinstance(pipeline, StableDiffusionXLImg2ImgPipeline):
pipeline.text_encoder.config.output_hidden_states = True
models_for_export["text_encoder"] = pipeline.text_encoder

# U-NET
# PyTorch does not support the ONNX export of torch.nn.functional.scaled_dot_product_attention
pipeline.unet.set_attn_processor(AttnProcessor())
pipeline.unet.config.text_encoder_projection_dim = projection_dim
# The shape of the U-NET `time_ids` input depends on the value of `requires_aesthetics_score`
# https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571
pipeline.unet.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False)
models_for_export["unet"] = pipeline.unet

# VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class TasksManager:
"audio-xvector": "AutoModelForAudioXVector",
"image-to-text": "AutoModelForVision2Seq",
"stable-diffusion": "StableDiffusionPipeline",
"stable-diffusion-xl": "StableDiffusionXLPipeline",
"stable-diffusion-xl": "StableDiffusionXLImg2ImgPipeline",
"zero-shot-image-classification": "AutoModelForZeroShotImageClassification",
"zero-shot-object-detection": "AutoModelForZeroShotObjectDetection",
}
Expand Down
3 changes: 2 additions & 1 deletion optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,7 @@ def __init__(
self.task = task
self.vocab_size = normalized_config.vocab_size
self.text_encoder_projection_dim = normalized_config.text_encoder_projection_dim
self.time_ids = 5 if normalized_config.requires_aesthetics_score else 6
if random_batch_size_range:
low, high = random_batch_size_range
self.batch_size = random.randint(low, high)
Expand All @@ -634,7 +635,7 @@ def generate(self, input_name: str, framework: str = "pt"):
if input_name == "timestep":
return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework)

shape.append(self.text_encoder_projection_dim if input_name == "text_embeds" else 6)
shape.append(self.text_encoder_projection_dim if input_name == "text_embeds" else self.time_ids)
return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework)


Expand Down

0 comments on commit 4069ae2

Please sign in to comment.