diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py
index f26bea862..30a310bf8 100644
--- a/optimum/neuron/modeling_diffusion.py
+++ b/optimum/neuron/modeling_diffusion.py
@@ -1058,15 +1058,15 @@ def forward(
         inputs = (sample, timestep, encoder_hidden_states)
         if timestep_cond is not None:
             inputs = inputs + (timestep_cond,)
-        if added_cond_kwargs is not None:
-            text_embeds = added_cond_kwargs.pop("text_embeds", None)
-            time_ids = added_cond_kwargs.pop("time_ids", None)
-            inputs = inputs + (text_embeds, time_ids)
         if mid_block_additional_residual is not None:
             inputs = inputs + (mid_block_additional_residual,)
         if down_block_additional_residuals is not None:
             for idx in range(len(down_block_additional_residuals)):
                 inputs = inputs + (down_block_additional_residuals[idx],)
+        if added_cond_kwargs:
+            # Appended last to match the positional input order of the compiled SDXL UNet.
+            text_embeds = added_cond_kwargs.pop("text_embeds", None)
+            time_ids = added_cond_kwargs.pop("time_ids", None)
+            inputs = inputs + (text_embeds, time_ids)
         outputs = self.model(*inputs)
         return outputs
@@ -1139,9 +1139,15 @@ def forward(
         controlnet_cond: torch.Tensor,
         conditioning_scale: float = 1.0,
         guess_mode: bool = False,
+        added_cond_kwargs: Optional[Dict] = None,
         return_dict: bool = True,
     ) -> Union["ControlNetOutput", Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
+        # The compiled graph expects a batched int64 timestep.
+        timestep = timestep.expand((sample.shape[0],)).to(torch.long)
         inputs = (sample, timestep, encoder_hidden_states, controlnet_cond, conditioning_scale)
+        if added_cond_kwargs:
+            text_embeds = added_cond_kwargs.pop("text_embeds", None)
+            time_ids = added_cond_kwargs.pop("time_ids", None)
+            inputs += (text_embeds, time_ids)
         outputs = self.model(*inputs)
 
         if guess_mode:
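As context for the wrapper changes above: a compiled Neuron module has a fixed, flat positional signature, so keyword-style conditioning must be unpacked in the exact order used at trace time. A standalone sketch (tensor shapes are illustrative, not taken from the patch) of what the new `NeuronControlNetModel.forward` does with `added_cond_kwargs`:

```python
import torch

# Illustrative SDXL-sized dummies; only the ordering logic matters here.
sample = torch.randn(1, 4, 128, 128)
timestep = torch.tensor(981)
encoder_hidden_states = torch.randn(1, 77, 2048)
controlnet_cond = torch.randn(1, 3, 1024, 1024)
conditioning_scale = torch.tensor([1.0])
added_cond_kwargs = {"text_embeds": torch.randn(1, 1280), "time_ids": torch.randn(1, 6)}

# Mirror of the new forward: batch and long-cast the timestep, then append the SDXL
# embeddings last so positions line up with the graph compiled at export time.
timestep = timestep.expand((sample.shape[0],)).to(torch.long)
inputs = (sample, timestep, encoder_hidden_states, controlnet_cond, conditioning_scale)
if added_cond_kwargs:
    inputs += (added_cond_kwargs.pop("text_embeds", None), added_cond_kwargs.pop("time_ids", None))
print([tuple(t.shape) for t in inputs])
```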
diff --git a/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py b/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py
index b8d4824ad..dade75f07 100644
--- a/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py
+++ b/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py
@@ -15,6 +15,7 @@
 """Override some diffusers API for NeuronStableDiffusionXLControlNetPipelineMixin"""
 
+import copy
 import logging
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -32,7 +33,163 @@
 class NeuronStableDiffusionXLControlNetPipelineMixin(
     StableDiffusionXLPipelineMixin, StableDiffusionXLControlNetPipeline
 ):
+    # Adapted from https://github.com/huggingface/diffusers/blob/v0.29.2/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L625
+    # Replace class types with Neuron ones
+    def check_inputs(
+        self,
+        prompt,
+        prompt_2,
+        image,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
+        negative_pooled_prompt_embeds=None,
+        controlnet_conditioning_scale=1.0,
+        control_guidance_start=0.0,
+        control_guidance_end=1.0,
+        callback_on_step_end_tensor_inputs=None,
+    ):
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
+
+        # Check `image`
+        if self.controlnet.__class__.__name__ == "NeuronControlNetModel":
+            self.check_image(image, prompt, prompt_embeds)
+        elif self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel":
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+
+            # When `image` is a nested list:
+            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+            elif any(isinstance(i, list) for i in image):
+                raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
+            elif len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+                )
+            else:
+                for image_ in image:
+                    self.check_image(image_, prompt, prompt_embeds)
+        else:
+            assert False
+
+        # Check `controlnet_conditioning_scale`
+        if self.controlnet.__class__.__name__ == "NeuronControlNetModel":
+            if not isinstance(controlnet_conditioning_scale, float):
+                raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+        elif self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel":
+            if isinstance(controlnet_conditioning_scale, list):
+                if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+                    raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
+            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
+                self.controlnet.nets
+            ):
+                raise ValueError(
+                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+                    " the same length as the number of controlnets"
+                )
+        else:
+            assert False
+
+        if not isinstance(control_guidance_start, (tuple, list)):
+            control_guidance_start = [control_guidance_start]
+
+        if not isinstance(control_guidance_end, (tuple, list)):
+            control_guidance_end = [control_guidance_end]
+
+        if len(control_guidance_start) != len(control_guidance_end):
+            raise ValueError(
+                f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+            )
+
+        if self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel":
+            if len(control_guidance_start) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+                )
+
+        for start, end in zip(control_guidance_start, control_guidance_end):
+            if start >= end:
+                raise ValueError(
+                    f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+                )
+            if start < 0.0:
+                raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+            if end > 1.0:
+                raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
+    # Adapted from https://github.com/huggingface/diffusers/blob/v0.30.0/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L899
+    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None):
+        add_time_ids = list(original_size + crops_coords_top_left + target_size)
+        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+        return add_time_ids
+
     # Adapted from https://github.com/huggingface/diffusers/blob/1f81fbe274e67c843283e69eb8f00bb56f75ffc4/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L1001
     def __call__(
         self,
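A worked example of the simplified `_get_add_time_ids` override above (it drops upstream's `passed_add_embed_dim` validation and reduces to plain tuple concatenation):

```python
import torch

original_size = (1024, 1024)      # (height, width) of the original image
crops_coords_top_left = (0, 0)    # no cropping
target_size = (1024, 1024)        # requested output size

add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids], dtype=torch.float32)
print(add_time_ids)        # tensor([[1024., 1024., 0., 0., 1024., 1024.]])
print(add_time_ids.shape)  # torch.Size([1, 6]) -- the `time_ids` fed to the compiled UNet
```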
@@ -229,7 +386,7 @@ def __call__(
         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
         elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
-            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+            mult = len(controlnet.nets) if controlnet.__class__.__name__ == "NeuronMultiControlNetModel" else 1
             control_guidance_start, control_guidance_end = (
                 mult * [control_guidance_start],
                 mult * [control_guidance_end],
@@ -237,22 +394,21 @@ def __call__(
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt,
-            prompt_2,
-            image,
-            callback_steps,
-            negative_prompt,
-            negative_prompt_2,
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            ip_adapter_image,
-            ip_adapter_image_embeds,
-            negative_pooled_prompt_embeds,
-            controlnet_conditioning_scale,
-            control_guidance_start,
-            control_guidance_end,
-            callback_on_step_end_tensor_inputs,
+            prompt=prompt,
+            prompt_2=prompt_2,
+            image=image,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            ip_adapter_image=ip_adapter_image,
+            ip_adapter_image_embeds=ip_adapter_image_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            controlnet_conditioning_scale=controlnet_conditioning_scale,
+            control_guidance_start=control_guidance_start,
+            control_guidance_end=control_guidance_end,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )
 
         self._guidance_scale = guidance_scale
@@ -268,68 +424,74 @@ def __call__(
         else:
             batch_size = prompt_embeds.shape[0]
 
-        device = self._execution_device
-
-        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
-            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+        if isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = torch.tensor([controlnet_conditioning_scale])
+        if controlnet.__class__.__name__ == "NeuronMultiControlNetModel":
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
 
         global_pool_conditions = (
             controlnet.config.global_pool_conditions
-            if isinstance(controlnet, ControlNetModel)
+            if controlnet.__class__.__name__ == "NeuronControlNetModel"
             else controlnet.nets[0].config.global_pool_conditions
         )
         guess_mode = guess_mode or global_pool_conditions
+        # TODO: Remove after the guess mode of ControlNet is supported
+        if guess_mode:
+            logger.info("Disabling the guess mode as this is not supported yet.")
+            guess_mode = False
 
         # 3.1 Encode input prompt
-        text_encoder_lora_scale = (
-            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
-        )
+        # LoRA scale is not applied since `cross_attention_kwargs` is not supported by the compiled UNet.
+        lora_scale = None
+        do_classifier_free_guidance = guidance_scale > 1.0 and (
+            self.dynamic_batch_size or self.data_parallel_mode == "unet"
+        )
         (
             prompt_embeds,
             negative_prompt_embeds,
             pooled_prompt_embeds,
             negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
-            prompt,
-            prompt_2,
-            device,
-            num_images_per_prompt,
-            self.do_classifier_free_guidance,
-            negative_prompt,
-            negative_prompt_2,
+            prompt=prompt,
+            prompt_2=prompt_2,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
+            lora_scale=lora_scale,
             clip_skip=self.clip_skip,
         )
 
         # 3.2 Encode ip_adapter_image
+        # TODO: support ip adapter
         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
-            image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image,
-                ip_adapter_image_embeds,
-                device,
-                batch_size * num_images_per_prompt,
-                self.do_classifier_free_guidance,
+            logger.info(
+                "IP adapter is not supported yet, `ip_adapter_image` and `ip_adapter_image_embeds` will be ignored."
             )
 
         # 4. Prepare image
-        if isinstance(controlnet, ControlNetModel):
+        height = self.vae_encoder.config.neuron["static_height"]
+        width = self.vae_encoder.config.neuron["static_width"]
+        if controlnet.__class__.__name__ == "NeuronControlNetModel":
             image = self.prepare_image(
                 image=image,
                 width=width,
                 height=height,
                 batch_size=batch_size * num_images_per_prompt,
                 num_images_per_prompt=num_images_per_prompt,
-                device=device,
-                dtype=controlnet.dtype,
-                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                device=None,
+                dtype=None,
+                do_classifier_free_guidance=do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
             height, width = image.shape[-2:]
-        elif isinstance(controlnet, MultiControlNetModel):
+        elif controlnet.__class__.__name__ == "NeuronMultiControlNetModel":
             images = []
 
             for image_ in image:
@@ -339,9 +501,9 @@ def __call__(
                     height=height,
                     batch_size=batch_size * num_images_per_prompt,
                     num_images_per_prompt=num_images_per_prompt,
-                    device=device,
-                    dtype=controlnet.dtype,
-                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    device=None,
+                    dtype=None,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )
 
@@ -354,7 +516,11 @@ def __call__(
         # 5. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler, num_inference_steps, device, timesteps, sigmas
+            scheduler=self.scheduler,
+            num_inference_steps=num_inference_steps,
+            device=None,
+            timesteps=timesteps,
+            sigmas=sigmas,
         )
         self._num_timesteps = len(timesteps)
 
@@ -366,7 +532,6 @@ def __call__(
             height,
             width,
             prompt_embeds.dtype,
-            device,
             generator,
             latents,
         )
@@ -377,7 +542,7 @@ def __call__(
             guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
             timestep_cond = self.get_guidance_scale_embedding(
                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
-            ).to(device=device, dtype=latents.dtype)
+            ).to(device=None, dtype=latents.dtype)
 
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -389,7 +554,7 @@ def __call__(
                 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                 for s, e in zip(control_guidance_start, control_guidance_end)
             ]
-            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+            controlnet_keep.append(keeps[0] if controlnet.__class__.__name__ == "NeuronControlNetModel" else keeps)
 
         # 7.2 Prepare added time ids & embeddings
         if isinstance(image, list):
@@ -423,14 +588,12 @@ def __call__(
         else:
             negative_add_time_ids = add_time_ids
 
-        if self.do_classifier_free_guidance:
+        if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
 
-        prompt_embeds = prompt_embeds.to(device)
-        add_text_embeds = add_text_embeds.to(device)
-        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
 
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -451,23 +614,18 @@ def __call__(
             num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
             timesteps = timesteps[:num_inference_steps]
 
-        is_unet_compiled = is_compiled_module(self.unet)
-        is_controlnet_compiled = is_compiled_module(self.controlnet)
-        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
-                # Relevant thread:
-                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
-                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
-                    torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
                 # controlnet(s) inference
-                if guess_mode and self.do_classifier_free_guidance:
+                if guess_mode and do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
                     control_model_input = latents
                     control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -479,7 +637,7 @@ def __call__(
                 else:
                     control_model_input = latent_model_input
                     controlnet_prompt_embeds = prompt_embeds
-                    controlnet_added_cond_kwargs = added_cond_kwargs
+                    # Deep copy: the Neuron ControlNet forward pops entries from this dict.
+                    controlnet_added_cond_kwargs = copy.deepcopy(added_cond_kwargs)
 
                 if isinstance(controlnet_keep[i], list):
                     cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
@@ -500,7 +658,7 @@ def __call__(
                     return_dict=False,
                 )
 
-                if guess_mode and self.do_classifier_free_guidance:
+                if guess_mode and do_classifier_free_guidance:
                     # Infered ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
@@ -508,7 +666,9 @@ def __call__(
                     mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
 
                 if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
-                    added_cond_kwargs["image_embeds"] = image_embeds
+                    logger.info(
+                        "IP adapter is not supported yet, `ip_adapter_image` and `ip_adapter_image_embeds` will be ignored."
+                    )
 
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
                     timestep_cond=timestep_cond,
-                    cross_attention_kwargs=self.cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
                     added_cond_kwargs=added_cond_kwargs,
-                    return_dict=False,
                 )[0]
 
                 # perform guidance
-                if self.do_classifier_free_guidance:
+                if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
@@ -550,22 +710,12 @@ def __call__(
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        step_idx = i // getattr(self.scheduler, "order", 1)
-                        callback(step_idx, t, latents)
 
         if not output_type == "latent":
-            # make sure the VAE is in float32 mode, as it overflows in float16
-            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
-
-            if needs_upcasting:
-                self.upcast_vae()
-                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
-
             # unscale/denormalize the latents
             # denormalize with the mean and std if available and not None
-            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
-            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            has_latents_mean = hasattr(self.vae_decoder.config, "latents_mean") and self.vae_decoder.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae_decoder.config, "latents_std") and self.vae_decoder.config.latents_std is not None
             if has_latents_mean and has_latents_std:
                 latents_mean = (
-                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                    torch.tensor(self.vae_decoder.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                 )
@@ -573,15 +723,12 @@ def __call__(
                 latents_std = (
-                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                    torch.tensor(self.vae_decoder.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
                 )
-                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+                latents = latents * latents_std / getattr(self.vae_decoder.config, "scaling_factor", 0.18215) + latents_mean
             else:
-                latents = latents / self.vae.config.scaling_factor
+                latents = latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215)
 
-            image = self.vae.decode(latents, return_dict=False)[0]
+            image = self.vae_decoder(latents)[0]
 
-            # cast back to fp16 if needed
-            if needs_upcasting:
-                self.vae.to(dtype=torch.float16)
         else:
             image = latents
 
@@ -592,8 +739,6 @@ def __call__(
 
         image = self.image_processor.postprocess(image, output_type=output_type)
 
-        # Offload all models
-        self.maybe_free_model_hooks()
 
         if not return_dict:
             return (image,)
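For reviewers, a minimal end-to-end sketch of the path this patch enables. It assumes the mixin is exposed as `NeuronStableDiffusionXLControlNetPipeline` and that `controlnet_ids` is accepted at export time, mirroring the existing SD ControlNet pipeline; model IDs and shape arguments are illustrative, not final:

```python
import cv2
import numpy as np
from diffusers.utils import load_image
from PIL import Image

# Hypothetical public entry point for this patch, mirroring
# `NeuronStableDiffusionControlNetPipeline`; adjust to the actual exported name.
from optimum.neuron import NeuronStableDiffusionXLControlNetPipeline

# Compile once: input shapes are static on Neuron, so they are fixed at export time.
pipe = NeuronStableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet_ids="diffusers/controlnet-canny-sdxl-1.0",
    export=True,
    batch_size=1,
    height=1024,
    width=1024,
    num_images_per_prompt=1,
)
pipe.save_pretrained("sdxl_controlnet_neuron/")

# Build a canny conditioning image.
image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)
canny = cv2.Canny(np.array(image), 100, 200)
canny_image = Image.fromarray(np.concatenate([canny[:, :, None]] * 3, axis=2))

image = pipe(prompt="aerial view, a futuristic research complex", image=canny_image).images[0]
image.save("out.png")
```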