From 61a123a1e083c584a333874b89828125171f7635 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 3 Dec 2023 03:31:47 -0500 Subject: [PATCH] A different way of handling multiple images passed to SVD. Previously when a list of 3 images [0, 1, 2] was used for a 6 frame video they were concated like this: [0, 1, 2, 0, 1, 2] now they are concated like this: [0, 0, 1, 1, 2, 2] --- comfy/model_base.py | 2 +- comfy/utils.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 786c9cf47ba..253ea66673b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -303,7 +303,7 @@ def extra_conds(self, **kwargs): if latent_image.shape[1:] != noise.shape[1:]: latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") - latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0]) + latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0]) out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image) diff --git a/comfy/utils.py b/comfy/utils.py index 294bbb425ff..50557704736 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -239,6 +239,26 @@ def repeat_to_batch_size(tensor, batch_size): return tensor.repeat([math.ceil(batch_size / tensor.shape[0])] + [1] * (len(tensor.shape) - 1))[:batch_size] return tensor +def resize_to_batch_size(tensor, batch_size): + in_batch_size = tensor.shape[0] + if in_batch_size == batch_size: + return tensor + + if batch_size <= 1: + return tensor[:batch_size] + + output = torch.empty([batch_size] + list(tensor.shape)[1:], dtype=tensor.dtype, device=tensor.device) + if batch_size < in_batch_size: + scale = (in_batch_size - 1) / (batch_size - 1) + for i in range(batch_size): + output[i] = tensor[min(round(i * scale), in_batch_size - 1)] + else: + scale = in_batch_size / batch_size + for i in range(batch_size): + output[i] = tensor[min(math.floor((i + 0.5) * scale), in_batch_size - 1)] + + return output + def convert_sd_to(state_dict, dtype): keys = list(state_dict.keys()) for k in keys: