From c32583e7c7a8d95129c0a53fe0356f990037224e Mon Sep 17 00:00:00 2001 From: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com> Date: Wed, 21 Aug 2024 13:18:39 +0000 Subject: [PATCH] pipeline done --- .../pipelines/diffusers/pipeline_controlnet_sd_xl.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py b/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py index dade75f07..854bad722 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py @@ -647,6 +647,14 @@ def __call__( controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] + # Duplicate inputs for ddp + t = torch.tensor([t] * 2) if self.data_parallel_mode == "unet" else t + cond_scale = ( + torch.tensor([cond_scale]).repeat(2) + if self.data_parallel_mode == "unet" + else torch.tensor(cond_scale) + ) + down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, @@ -680,8 +688,6 @@ def __call__( mid_block_additional_residual=mid_block_res_sample, added_cond_kwargs=added_cond_kwargs, )[0] - import pdb - pdb.set_trace() # perform guidance if do_classifier_free_guidance: