diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index c1bee9a4da..24a809a977 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -100,17 +100,17 @@ def _get_submodels_for_export_stable_diffusion( """ Returns the components of a Stable Diffusion model. """ - from diffusers import StableDiffusionXLPipeline + from diffusers import StableDiffusionXLImg2ImgPipeline models_for_export = {} - if isinstance(pipeline, StableDiffusionXLPipeline): + if isinstance(pipeline, StableDiffusionXLImg2ImgPipeline): projection_dim = pipeline.text_encoder_2.config.projection_dim else: projection_dim = pipeline.text_encoder.config.projection_dim # Text encoder if pipeline.text_encoder is not None: - if isinstance(pipeline, StableDiffusionXLPipeline): + if isinstance(pipeline, StableDiffusionXLImg2ImgPipeline): pipeline.text_encoder.config.output_hidden_states = True models_for_export["text_encoder"] = pipeline.text_encoder @@ -118,6 +118,9 @@ def _get_submodels_for_export_stable_diffusion( # PyTorch does not support the ONNX export of torch.nn.functional.scaled_dot_product_attention pipeline.unet.set_attn_processor(AttnProcessor()) pipeline.unet.config.text_encoder_projection_dim = projection_dim + # The U-Net `time_ids` input shape depends on the value of `requires_aesthetics_score` + # https://github.com/huggingface/diffusers/blob/v0.18.2/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L571 + pipeline.unet.config.requires_aesthetics_score = getattr(pipeline.config, "requires_aesthetics_score", False) models_for_export["unet"] = pipeline.unet # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565 diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 2f3c432968..2f1654bbbd 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -171,7 +171,7 @@ class TasksManager: "audio-xvector": 
"AutoModelForAudioXVector", "image-to-text": "AutoModelForVision2Seq", "stable-diffusion": "StableDiffusionPipeline", - "stable-diffusion-xl": "StableDiffusionXLPipeline", + "stable-diffusion-xl": "StableDiffusionXLImg2ImgPipeline", "zero-shot-image-classification": "AutoModelForZeroShotImageClassification", "zero-shot-object-detection": "AutoModelForZeroShotObjectDetection", } diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index d88f21fd2b..d062a29d7e 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -622,6 +622,7 @@ def __init__( self.task = task self.vocab_size = normalized_config.vocab_size self.text_encoder_projection_dim = normalized_config.text_encoder_projection_dim + self.time_ids = 5 if normalized_config.requires_aesthetics_score else 6 if random_batch_size_range: low, high = random_batch_size_range self.batch_size = random.randint(low, high) @@ -634,7 +635,7 @@ def generate(self, input_name: str, framework: str = "pt"): if input_name == "timestep": return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework) - shape.append(self.text_encoder_projection_dim if input_name == "text_embeds" else 6) + shape.append(self.text_encoder_projection_dim if input_name == "text_embeds" else self.time_ids) return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework)