diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index f982d648ce4..9a63202ab07 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -27,7 +27,6 @@ def __init__( model_channels, hint_channels, num_res_blocks, - attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, @@ -52,6 +51,7 @@ def __init__( use_linear_in_transformer=False, adm_in_channels=None, transformer_depth_middle=None, + transformer_depth_output=None, device=None, operations=comfy.ops, ): @@ -79,10 +79,7 @@ def __init__( self.image_size = image_size self.in_channels = in_channels self.model_channels = model_channels - if isinstance(transformer_depth, int): - transformer_depth = len(channel_mult) * [transformer_depth] - if transformer_depth_middle is None: - transformer_depth_middle = transformer_depth[-1] + if isinstance(num_res_blocks, int): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: @@ -90,18 +87,16 @@ def __init__( raise ValueError("provide num_res_blocks either as an int (globally constant) or " "as a list/tuple (per-level) with the same length as channel_mult") self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") - self.attention_resolutions = attention_resolutions + transformer_depth = transformer_depth[:] + self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample @@ -180,11 +175,14 @@ def __init__( dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, - operations=operations + dtype=self.dtype, + device=device, + operations=operations, ) ] ch = mult * model_channels - if ds in attention_resolutions: + num_transformers = transformer_depth.pop(0) + if num_transformers > 0: if num_head_channels == -1: dim_head = ch // num_heads else: @@ -201,9 +199,9 @@ def __init__( if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: layers.append( SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, + ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, operations=operations + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) @@ -223,11 +221,13 @@ def __init__( use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, + dtype=self.dtype, + device=device, operations=operations ) if resblock_updown else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations + ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations ) ) ) @@ -245,7 +245,7 @@ def __init__( if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - self.middle_block = TimestepEmbedSequential( + mid_block = [ ResBlock( ch, time_embed_dim, @@ -253,12 +253,15 @@ def __init__( dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + device=device, operations=operations - ), - SpatialTransformer( # always uses a self-attn + )] + if transformer_depth_middle >= 0: + mid_block += [SpatialTransformer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, operations=operations + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ), ResBlock( ch, @@ -267,9 +270,11 @@ def __init__( dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + device=device, operations=operations - ), - ) + )] + self.middle_block = TimestepEmbedSequential(*mid_block) self.middle_block_out = self.make_zero_conv(ch, operations=operations) self._feature_size += ch diff --git a/comfy/cli_args.py b/comfy/cli_args.py index d86557646f1..e79b89c0f0d 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -36,6 +36,8 @@ def __call__(self, parser, namespace, values, option_string=None): parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.") +parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.") + parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.") parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.") parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).") diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index e085186ef68..9e2e03d7238 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,5 +1,5 @@ -from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils -from .utils import load_torch_file, transformers_convert +from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils +from .utils import load_torch_file, transformers_convert, common_upscale import os import torch import contextlib @@ -7,6 +7,18 @@ import comfy.ops import comfy.model_patcher import comfy.model_management +import comfy.utils + +def clip_preprocess(image, size=224): + mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) + std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) + scale = (size / min(image.shape[1], image.shape[2])) + image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True) + h = (image.shape[2] - size)//2 + w = (image.shape[3] - size)//2 + image = image[:,:,h:h+size,w:w+size] + image = torch.clip((255. * image), 0, 255).round() / 255.0 + return (image - mean.view([3,1,1])) / std.view([3,1,1]) class ClipVisionModel(): def __init__(self, json_config): @@ -23,25 +35,12 @@ def __init__(self, json_config): self.model.to(self.dtype) self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) - self.processor = CLIPImageProcessor(crop_size=224, - do_center_crop=True, - do_convert_rgb=True, - do_normalize=True, - do_resize=True, - image_mean=[ 0.48145466,0.4578275,0.40821073], - image_std=[0.26862954,0.26130258,0.27577711], - resample=3, #bicubic - size=224) - def load_sd(self, sd): return self.model.load_state_dict(sd, strict=False) def encode_image(self, image): - img = torch.clip((255. * image), 0, 255).round().int() - img = list(map(lambda a: a, img)) - inputs = self.processor(images=img, return_tensors="pt") comfy.model_management.load_model_gpu(self.patcher) - pixel_values = inputs['pixel_values'].to(self.load_device) + pixel_values = clip_preprocess(image.to(self.load_device)) if self.dtype != torch.float32: precision_scope = torch.autocast diff --git a/comfy/conds.py b/comfy/conds.py new file mode 100644 index 00000000000..1e3111baff8 --- /dev/null +++ b/comfy/conds.py @@ -0,0 +1,64 @@ +import enum +import torch +import math +import comfy.utils + + +def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) + return abs(a*b) // math.gcd(a, b) + +class CONDRegular: + def __init__(self, cond): + self.cond = cond + + def _copy_with(self, cond): + return self.__class__(cond) + + def process_cond(self, batch_size, device, **kwargs): + return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device)) + + def can_concat(self, other): + if self.cond.shape != other.cond.shape: + return False + return True + + def concat(self, others): + conds = [self.cond] + for x in others: + conds.append(x.cond) + return torch.cat(conds) + +class CONDNoiseShape(CONDRegular): + def process_cond(self, batch_size, device, area, **kwargs): + data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device)) + + +class CONDCrossAttn(CONDRegular): + def can_concat(self, other): + s1 = self.cond.shape + s2 = other.cond.shape + if s1 != s2: + if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen + return False + + mult_min = lcm(s1[1], s2[1]) + diff = mult_min // min(s1[1], s2[1]) + if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much + return False + return True + + def concat(self, others): + conds = [self.cond] + crossattn_max_len = self.cond.shape[1] + for x in others: + c = x.cond + crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) + conds.append(c) + + out = [] + for c in conds: + if c.shape[1] < crossattn_max_len: + c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result + out.append(c) + return torch.cat(out) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index f1355e64e9d..09868158287 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -132,6 +132,7 @@ def __init__(self, control_model, global_average_pooling=False, device=None): self.control_model = control_model self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) self.global_average_pooling = global_average_pooling + self.model_sampling_current = None def get_control(self, x_noisy, t, cond, batched_number): control_prev = None @@ -156,10 +157,13 @@ def get_control(self, x_noisy, t, cond, batched_number): context = cond['c_crossattn'] - y = cond.get('c_adm', None) + y = cond.get('y', None) if y is not None: y = y.to(self.control_model.dtype) - control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) + timestep = self.model_sampling_current.timestep(t) + x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) + + control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y) return self.control_merge(None, control, control_prev, output_dtype) def copy(self): @@ -172,6 +176,14 @@ def get_models(self): out.append(self.control_model_wrapped) return out + def pre_run(self, model, percent_to_timestep_function): + super().pre_run(model, percent_to_timestep_function) + self.model_sampling_current = model.model_sampling + + def cleanup(self): + self.model_sampling_current = None + super().cleanup() + class ControlLoraOps: class Linear(torch.nn.Module): def __init__(self, in_features: int, out_features: int, bias: bool = True, diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 58e030d0439..1a7a8392902 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -852,6 +852,12 @@ def marginal_lambda(self, t): log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) return log_mean_coeff - log_std +def predict_eps_sigma(model, input, sigma_in, **kwargs): + sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1)) + input = input * ((sigma ** 2 + 1.0) ** 0.5) + return (input - model(input, sigma_in, **kwargs)) / sigma + + def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): timesteps = sigmas.clone() if sigmas[-1] == 0: @@ -874,14 +880,14 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex model_type = "noise" model_fn = model_wrapper( - model.predict_eps_sigma, + lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs), ns, model_type=model_type, guidance_type="uncond", model_kwargs=extra_args, ) - order = min(3, len(timesteps) - 1) + order = min(3, len(timesteps) - 2) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) x /= ns.marginal_alpha(timesteps[-1]) diff --git a/comfy/k_diffusion/external.py b/comfy/k_diffusion/external.py deleted file mode 100644 index 953d3db2c9f..00000000000 --- a/comfy/k_diffusion/external.py +++ /dev/null @@ -1,194 +0,0 @@ -import math - -import torch -from torch import nn - -from . import sampling, utils - - -class VDenoiser(nn.Module): - """A v-diffusion-pytorch model wrapper for k-diffusion.""" - - def __init__(self, inner_model): - super().__init__() - self.inner_model = inner_model - self.sigma_data = 1. - - def get_scalings(self, sigma): - c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - return c_skip, c_out, c_in - - def sigma_to_t(self, sigma): - return sigma.atan() / math.pi * 2 - - def t_to_sigma(self, t): - return (t * math.pi / 2).tan() - - def loss(self, input, noise, sigma, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - noised_input = input + noise * utils.append_dims(sigma, input.ndim) - model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) - target = (input - c_skip * noised_input) / c_out - return (model_output - target).pow(2).flatten(1).mean(1) - - def forward(self, input, sigma, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip - - -class DiscreteSchedule(nn.Module): - """A mapping between continuous noise levels (sigmas) and a list of discrete noise - levels.""" - - def __init__(self, sigmas, quantize): - super().__init__() - self.register_buffer('sigmas', sigmas) - self.register_buffer('log_sigmas', sigmas.log()) - self.quantize = quantize - - @property - def sigma_min(self): - return self.sigmas[0] - - @property - def sigma_max(self): - return self.sigmas[-1] - - def get_sigmas(self, n=None): - if n is None: - return sampling.append_zero(self.sigmas.flip(0)) - t_max = len(self.sigmas) - 1 - t = torch.linspace(t_max, 0, n, device=self.sigmas.device) - return sampling.append_zero(self.t_to_sigma(t)) - - def sigma_to_discrete_timestep(self, sigma): - log_sigma = sigma.log() - dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] - return dists.abs().argmin(dim=0).view(sigma.shape) - - def sigma_to_t(self, sigma, quantize=None): - quantize = self.quantize if quantize is None else quantize - if quantize: - return self.sigma_to_discrete_timestep(sigma) - log_sigma = sigma.log() - dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] - low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx] - w = (low - log_sigma) / (low - high) - w = w.clamp(0, 1) - t = (1 - w) * low_idx + w * high_idx - return t.view(sigma.shape) - - def t_to_sigma(self, t): - t = t.float() - low_idx = t.floor().long() - high_idx = t.ceil().long() - w = t-low_idx if t.device.type == 'mps' else t.frac() - log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] - return log_sigma.exp() - - def predict_eps_discrete_timestep(self, input, t, **kwargs): - if t.dtype != torch.int64 and t.dtype != torch.int32: - t = t.round() - sigma = self.t_to_sigma(t) - input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5) - return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim) - - def predict_eps_sigma(self, input, sigma, **kwargs): - input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5) - return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim) - -class DiscreteEpsDDPMDenoiser(DiscreteSchedule): - """A wrapper for discrete schedule DDPM models that output eps (the predicted - noise).""" - - def __init__(self, model, alphas_cumprod, quantize): - super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) - self.inner_model = model - self.sigma_data = 1. - - def get_scalings(self, sigma): - c_out = -sigma - c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - return c_out, c_in - - def get_eps(self, *args, **kwargs): - return self.inner_model(*args, **kwargs) - - def loss(self, input, noise, sigma, **kwargs): - c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - noised_input = input + noise * utils.append_dims(sigma, input.ndim) - eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) - return (eps - noise).pow(2).flatten(1).mean(1) - - def forward(self, input, sigma, **kwargs): - c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) - return input + eps * c_out - - -class OpenAIDenoiser(DiscreteEpsDDPMDenoiser): - """A wrapper for OpenAI diffusion models.""" - - def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'): - alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32) - super().__init__(model, alphas_cumprod, quantize=quantize) - self.has_learned_sigmas = has_learned_sigmas - - def get_eps(self, *args, **kwargs): - model_output = self.inner_model(*args, **kwargs) - if self.has_learned_sigmas: - return model_output.chunk(2, dim=1)[0] - return model_output - - -class CompVisDenoiser(DiscreteEpsDDPMDenoiser): - """A wrapper for CompVis diffusion models.""" - - def __init__(self, model, quantize=False, device='cpu'): - super().__init__(model, model.alphas_cumprod, quantize=quantize) - - def get_eps(self, *args, **kwargs): - return self.inner_model.apply_model(*args, **kwargs) - - -class DiscreteVDDPMDenoiser(DiscreteSchedule): - """A wrapper for discrete schedule DDPM models that output v.""" - - def __init__(self, model, alphas_cumprod, quantize): - super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) - self.inner_model = model - self.sigma_data = 1. - - def get_scalings(self, sigma): - c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 - return c_skip, c_out, c_in - - def get_v(self, *args, **kwargs): - return self.inner_model(*args, **kwargs) - - def loss(self, input, noise, sigma, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - noised_input = input + noise * utils.append_dims(sigma, input.ndim) - model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) - target = (input - c_skip * noised_input) / c_out - return (model_output - target).pow(2).flatten(1).mean(1) - - def forward(self, input, sigma, **kwargs): - c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] - return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip - - -class CompVisVDenoiser(DiscreteVDDPMDenoiser): - """A wrapper for CompVis diffusion models that output v.""" - - def __init__(self, model, quantize=False, device='cpu'): - super().__init__(model, model.alphas_cumprod, quantize=quantize) - - def get_v(self, x, t, cond, **kwargs): - return self.inner_model.apply_model(x, t, cond) diff --git a/comfy/ldm/models/diffusion/__init__.py b/comfy/ldm/models/diffusion/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py deleted file mode 100644 index 433d48e3064..00000000000 --- a/comfy/ldm/models/diffusion/ddim.py +++ /dev/null @@ -1,418 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm - -from comfy.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor - - -class DDIMSampler(object): - def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - self.device = device - self.parameterization = kwargs.get("parameterization", "eps") - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != self.device: - attr = attr.float().to(self.device) - setattr(self, name, attr) - - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) - self.make_schedule_timesteps(ddim_timesteps, ddim_eta=ddim_eta, verbose=verbose) - - def make_schedule_timesteps(self, ddim_timesteps, ddim_eta=0., verbose=True): - self.ddim_timesteps = torch.tensor(ddim_timesteps) - alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device) - - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) - - @torch.no_grad() - def sample_custom(self, - ddim_timesteps, - conditioning=None, - callback=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - dynamic_threshold=None, - ucg_schedule=None, - denoise_function=None, - extra_args=None, - to_zero=True, - end_step=None, - disable_pbar=False, - **kwargs - ): - self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose) - samples, intermediates = self.ddim_sampling(conditioning, x_T.shape, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule, - denoise_function=denoise_function, - extra_args=extra_args, - to_zero=to_zero, - end_step=end_step, - disable_pbar=disable_pbar - ) - return samples, intermediates - - - @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - dynamic_threshold=None, - ucg_schedule=None, - **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] - cbs = ctmp.shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - - elif isinstance(conditioning, list): - for ctmp in conditioning: - if ctmp.shape[0] != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - print(f'Data shape for DDIM sampling is {size}, eta {eta}') - - samples, intermediates = self.ddim_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule, - denoise_function=None, - extra_args=None - ) - return samples, intermediates - - def q_sample(self, x_start, t, noise=None): - if noise is None: - noise = torch.randn_like(x_start) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) - - @torch.no_grad() - def ddim_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False): - device = self.model.alphas_cumprod.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else timesteps.flip(0) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - # print(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar) - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - - if mask is not None: - assert x0 is not None - img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. - mask) * img - - if ucg_schedule is not None: - assert len(ucg_schedule) == len(time_range) - unconditional_guidance_scale = ucg_schedule[i] - - outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args) - img, pred_x0 = outs - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - - if to_zero: - img = pred_x0 - else: - if ddim_use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - img /= sqrt_alphas_cumprod[index - 1] - - return img, intermediates - - @torch.no_grad() - def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None, denoise_function=None, extra_args=None): - b, *_, device = *x.shape, x.device - - if denoise_function is not None: - model_output = denoise_function(x, t, **extra_args) - elif unconditional_conditioning is None or unconditional_guidance_scale == 1.: - model_output = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - if isinstance(c, dict): - assert isinstance(unconditional_conditioning, dict) - c_in = dict() - for k in c: - if isinstance(c[k], list): - c_in[k] = [torch.cat([ - unconditional_conditioning[k][i], - c[k][i]]) for i in range(len(c[k]))] - else: - c_in[k] = torch.cat([ - unconditional_conditioning[k], - c[k]]) - elif isinstance(c, list): - c_in = list() - assert isinstance(unconditional_conditioning, list) - for i in range(len(c)): - c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) - else: - c_in = torch.cat([unconditional_conditioning, c]) - model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) - - if self.parameterization == "v": - e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x - else: - e_t = model_output - - if score_corrector is not None: - assert self.parameterization == "eps", 'not implemented' - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - if self.parameterization != "v": - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - else: - pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output - - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - - if dynamic_threshold is not None: - raise NotImplementedError() - - # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - @torch.no_grad() - def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, - unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): - num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] - - assert t_enc <= num_reference_steps - num_steps = t_enc - - if use_original_steps: - alphas_next = self.alphas_cumprod[:num_steps] - alphas = self.alphas_cumprod_prev[:num_steps] - else: - alphas_next = self.ddim_alphas[:num_steps] - alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) - - x_next = x0 - intermediates = [] - inter_steps = [] - for i in tqdm(range(num_steps), desc='Encoding Image'): - t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) - if unconditional_guidance_scale == 1.: - noise_pred = self.model.apply_model(x_next, t, c) - else: - assert unconditional_conditioning is not None - e_t_uncond, noise_pred = torch.chunk( - self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), - torch.cat((unconditional_conditioning, c))), 2) - noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) - - xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next - weighted_noise_pred = alphas_next[i].sqrt() * ( - (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred - x_next = xt_weighted + weighted_noise_pred - if return_intermediates and i % ( - num_steps // return_intermediates) == 0 and i < num_steps - 1: - intermediates.append(x_next) - inter_steps.append(i) - elif return_intermediates and i >= num_steps - 2: - intermediates.append(x_next) - inter_steps.append(i) - if callback: callback(i) - - out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} - if return_intermediates: - out.update({'intermediates': intermediates}) - return x_next, out - - @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None, max_denoise=False): - # fast, but does not allow for exact reconstruction - # t serves as an index to gather the correct alphas - if use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas - - if noise is None: - noise = torch.randn_like(x0) - if max_denoise: - noise_multiplier = 1.0 - else: - noise_multiplier = extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) - - return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + noise_multiplier * noise) - - @torch.no_grad() - def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False, callback=None): - - timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps - timesteps = timesteps[:t_start] - - time_range = np.flip(timesteps) - total_steps = timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='Decoding image', total=total_steps) - x_dec = x_latent - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) - x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) - if callback: callback(i) - return x_dec \ No newline at end of file diff --git a/comfy/ldm/models/diffusion/dpm_solver/__init__.py b/comfy/ldm/models/diffusion/dpm_solver/__init__.py deleted file mode 100644 index 7427f38c075..00000000000 --- a/comfy/ldm/models/diffusion/dpm_solver/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py b/comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py deleted file mode 100644 index da8d41f9c5e..00000000000 --- a/comfy/ldm/models/diffusion/dpm_solver/dpm_solver.py +++ /dev/null @@ -1,1163 +0,0 @@ -import torch -import torch.nn.functional as F -import math -from tqdm import tqdm - - -class NoiseScheduleVP: - def __init__( - self, - schedule='discrete', - betas=None, - alphas_cumprod=None, - continuous_beta_0=0.1, - continuous_beta_1=20., - ): - """Create a wrapper class for the forward SDE (VP type). - *** - Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. - We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. - *** - The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). - We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). - Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: - log_alpha_t = self.marginal_log_mean_coeff(t) - sigma_t = self.marginal_std(t) - lambda_t = self.marginal_lambda(t) - Moreover, as lambda(t) is an invertible function, we also support its inverse function: - t = self.inverse_lambda(lambda_t) - =============================================================== - We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). - 1. For discrete-time DPMs: - For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: - t_i = (i + 1) / N - e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. - We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. - Args: - betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) - alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) - Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. - **Important**: Please pay special attention for the args for `alphas_cumprod`: - The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that - q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). - Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have - alpha_{t_n} = \sqrt{\hat{alpha_n}}, - and - log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). - 2. For continuous-time DPMs: - We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise - schedule are the default settings in DDPM and improved-DDPM: - Args: - beta_min: A `float` number. The smallest beta for the linear schedule. - beta_max: A `float` number. The largest beta for the linear schedule. - cosine_s: A `float` number. The hyperparameter in the cosine schedule. - cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. - T: A `float` number. The ending time of the forward process. - =============================================================== - Args: - schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, - 'linear' or 'cosine' for continuous-time DPMs. - Returns: - A wrapper object of the forward SDE (VP type). - - =============================================================== - Example: - # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): - >>> ns = NoiseScheduleVP('discrete', betas=betas) - # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): - >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) - # For continuous-time DPMs (VPSDE), linear schedule: - >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) - """ - - if schedule not in ['discrete', 'linear', 'cosine']: - raise ValueError( - "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( - schedule)) - - self.schedule = schedule - if schedule == 'discrete': - if betas is not None: - log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) - else: - assert alphas_cumprod is not None - log_alphas = 0.5 * torch.log(alphas_cumprod) - self.total_N = len(log_alphas) - self.T = 1. - self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) - self.log_alpha_array = log_alphas.reshape((1, -1,)) - else: - self.total_N = 1000 - self.beta_0 = continuous_beta_0 - self.beta_1 = continuous_beta_1 - self.cosine_s = 0.008 - self.cosine_beta_max = 999. - self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s - self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) - self.schedule = schedule - if schedule == 'cosine': - # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. - # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. - self.T = 0.9946 - else: - self.T = 1. - - def marginal_log_mean_coeff(self, t): - """ - Compute log(alpha_t) of a given continuous-time label t in [0, T]. - """ - if self.schedule == 'discrete': - return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), - self.log_alpha_array.to(t.device)).reshape((-1)) - elif self.schedule == 'linear': - return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 - elif self.schedule == 'cosine': - log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) - log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 - return log_alpha_t - - def marginal_alpha(self, t): - """ - Compute alpha_t of a given continuous-time label t in [0, T]. - """ - return torch.exp(self.marginal_log_mean_coeff(t)) - - def marginal_std(self, t): - """ - Compute sigma_t of a given continuous-time label t in [0, T]. - """ - return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) - - def marginal_lambda(self, t): - """ - Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. - """ - log_mean_coeff = self.marginal_log_mean_coeff(t) - log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) - return log_mean_coeff - log_std - - def inverse_lambda(self, lamb): - """ - Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. - """ - if self.schedule == 'linear': - tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - Delta = self.beta_0 ** 2 + tmp - return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) - elif self.schedule == 'discrete': - log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) - t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), - torch.flip(self.t_array.to(lamb.device), [1])) - return t.reshape((-1,)) - else: - log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s - t = t_fn(log_alpha) - return t - - -def model_wrapper( - model, - noise_schedule, - model_type="noise", - model_kwargs={}, - guidance_type="uncond", - condition=None, - unconditional_condition=None, - guidance_scale=1., - classifier_fn=None, - classifier_kwargs={}, -): - """Create a wrapper function for the noise prediction model. - DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to - firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. - We support four types of the diffusion model by setting `model_type`: - 1. "noise": noise prediction model. (Trained by predicting noise). - 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). - 3. "v": velocity prediction model. (Trained by predicting the velocity). - The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. - [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." - arXiv preprint arXiv:2202.00512 (2022). - [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." - arXiv preprint arXiv:2210.02303 (2022). - - 4. "score": marginal score function. (Trained by denoising score matching). - Note that the score function and the noise prediction model follows a simple relationship: - ``` - noise(x_t, t) = -sigma_t * score(x_t, t) - ``` - We support three types of guided sampling by DPMs by setting `guidance_type`: - 1. "uncond": unconditional sampling by DPMs. - The input `model` has the following format: - `` - model(x, t_input, **model_kwargs) -> noise | x_start | v | score - `` - 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. - The input `model` has the following format: - `` - model(x, t_input, **model_kwargs) -> noise | x_start | v | score - `` - The input `classifier_fn` has the following format: - `` - classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) - `` - [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," - in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. - 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. - The input `model` has the following format: - `` - model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score - `` - And if cond == `unconditional_condition`, the model output is the unconditional DPM output. - [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." - arXiv preprint arXiv:2207.12598 (2022). - - The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) - or continuous-time labels (i.e. epsilon to T). - We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: - `` - def model_fn(x, t_continuous) -> noise: - t_input = get_model_input_time(t_continuous) - return noise_pred(model, x, t_input, **model_kwargs) - `` - where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. - =============================================================== - Args: - model: A diffusion model with the corresponding format described above. - noise_schedule: A noise schedule object, such as NoiseScheduleVP. - model_type: A `str`. The parameterization type of the diffusion model. - "noise" or "x_start" or "v" or "score". - model_kwargs: A `dict`. A dict for the other inputs of the model function. - guidance_type: A `str`. The type of the guidance for sampling. - "uncond" or "classifier" or "classifier-free". - condition: A pytorch tensor. The condition for the guided sampling. - Only used for "classifier" or "classifier-free" guidance type. - unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. - Only used for "classifier-free" guidance type. - guidance_scale: A `float`. The scale for the guided sampling. - classifier_fn: A classifier function. Only used for the classifier guidance. - classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. - Returns: - A noise prediction model that accepts the noised data and the continuous time as the inputs. - """ - - def get_model_input_time(t_continuous): - """ - Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. - For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. - For continuous-time DPMs, we just use `t_continuous`. - """ - if noise_schedule.schedule == 'discrete': - return (t_continuous - 1. / noise_schedule.total_N) * 1000. - else: - return t_continuous - - def noise_pred_fn(x, t_continuous, cond=None): - if t_continuous.reshape((-1,)).shape[0] == 1: - t_continuous = t_continuous.expand((x.shape[0])) - t_input = get_model_input_time(t_continuous) - if cond is None: - output = model(x, t_input, **model_kwargs) - else: - output = model(x, t_input, cond, **model_kwargs) - if model_type == "noise": - return output - elif model_type == "x_start": - alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) - elif model_type == "v": - alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x - elif model_type == "score": - sigma_t = noise_schedule.marginal_std(t_continuous) - dims = x.dim() - return -expand_dims(sigma_t, dims) * output - - def cond_grad_fn(x, t_input): - """ - Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). - """ - with torch.enable_grad(): - x_in = x.detach().requires_grad_(True) - log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) - return torch.autograd.grad(log_prob.sum(), x_in)[0] - - def model_fn(x, t_continuous): - """ - The noise predicition model function that is used for DPM-Solver. - """ - if t_continuous.reshape((-1,)).shape[0] == 1: - t_continuous = t_continuous.expand((x.shape[0])) - if guidance_type == "uncond": - return noise_pred_fn(x, t_continuous) - elif guidance_type == "classifier": - assert classifier_fn is not None - t_input = get_model_input_time(t_continuous) - cond_grad = cond_grad_fn(x, t_input) - sigma_t = noise_schedule.marginal_std(t_continuous) - noise = noise_pred_fn(x, t_continuous) - return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad - elif guidance_type == "classifier-free": - if guidance_scale == 1. or unconditional_condition is None: - return noise_pred_fn(x, t_continuous, cond=condition) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t_continuous] * 2) - if isinstance(condition, dict): - assert isinstance(unconditional_condition, dict) - c_in = dict() - for k in condition: - if isinstance(condition[k], list): - c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))] - else: - c_in[k] = torch.cat([unconditional_condition[k], condition[k]]) - else: - c_in = torch.cat([unconditional_condition, condition]) - noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) - return noise_uncond + guidance_scale * (noise - noise_uncond) - - assert model_type in ["noise", "x_start", "v"] - assert guidance_type in ["uncond", "classifier", "classifier-free"] - return model_fn - - -class DPM_Solver: - def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): - """Construct a DPM-Solver. - We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). - If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). - If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). - In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. - The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. - Args: - model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): - `` - def model_fn(x, t_continuous): - return noise - `` - noise_schedule: A noise schedule object, such as NoiseScheduleVP. - predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. - thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. - max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. - - [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. - """ - self.model = model_fn - self.noise_schedule = noise_schedule - self.predict_x0 = predict_x0 - self.thresholding = thresholding - self.max_val = max_val - - def noise_prediction_fn(self, x, t): - """ - Return the noise prediction model. - """ - return self.model(x, t) - - def data_prediction_fn(self, x, t): - """ - Return the data prediction model (with thresholding). - """ - noise = self.noise_prediction_fn(x, t) - dims = x.dim() - alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) - if self.thresholding: - p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. - s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) - s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) - x0 = torch.clamp(x0, -s, s) / s - return x0 - - def model_fn(self, x, t): - """ - Convert the model to the noise prediction model or the data prediction model. - """ - if self.predict_x0: - return self.data_prediction_fn(x, t) - else: - return self.noise_prediction_fn(x, t) - - def get_time_steps(self, skip_type, t_T, t_0, N, device): - """Compute the intermediate time steps for sampling. - Args: - skip_type: A `str`. The type for the spacing of the time steps. We support three types: - - 'logSNR': uniform logSNR for the time steps. - - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) - - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - N: A `int`. The total number of the spacing of the time steps. - device: A torch device. - Returns: - A pytorch tensor of the time steps, with the shape (N + 1,). - """ - if skip_type == 'logSNR': - lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) - lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) - logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) - return self.noise_schedule.inverse_lambda(logSNR_steps) - elif skip_type == 'time_uniform': - return torch.linspace(t_T, t_0, N + 1).to(device) - elif skip_type == 'time_quadratic': - t_order = 2 - t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) - return t - else: - raise ValueError( - "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) - - def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): - """ - Get the order of each step for sampling by the singlestep DPM-Solver. - We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". - Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: - - If order == 1: - We take `steps` of DPM-Solver-1 (i.e. DDIM). - - If order == 2: - - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. - - If steps % 2 == 0, we use K steps of DPM-Solver-2. - - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. - - If order == 3: - - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. - - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. - - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. - - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. - ============================================ - Args: - order: A `int`. The max order for the solver (2 or 3). - steps: A `int`. The total number of function evaluations (NFE). - skip_type: A `str`. The type for the spacing of the time steps. We support three types: - - 'logSNR': uniform logSNR for the time steps. - - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) - - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - device: A torch device. - Returns: - orders: A list of the solver order of each step. - """ - if order == 3: - K = steps // 3 + 1 - if steps % 3 == 0: - orders = [3, ] * (K - 2) + [2, 1] - elif steps % 3 == 1: - orders = [3, ] * (K - 1) + [1] - else: - orders = [3, ] * (K - 1) + [2] - elif order == 2: - if steps % 2 == 0: - K = steps // 2 - orders = [2, ] * K - else: - K = steps // 2 + 1 - orders = [2, ] * (K - 1) + [1] - elif order == 1: - K = 1 - orders = [1, ] * steps - else: - raise ValueError("'order' must be '1' or '2' or '3'.") - if skip_type == 'logSNR': - # To reproduce the results in DPM-Solver paper - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) - else: - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ - torch.cumsum(torch.tensor([0, ] + orders)).to(device)] - return timesteps_outer, orders - - def denoise_to_zero_fn(self, x, s): - """ - Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. - """ - return self.data_prediction_fn(x, s) - - def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): - """ - DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s`. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - if self.predict_x0: - phi_1 = torch.expm1(-h) - if model_s is None: - model_s = self.model_fn(x, s) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - ) - if return_intermediate: - return x_t, {'model_s': model_s} - else: - return x_t - else: - phi_1 = torch.expm1(h) - if model_s is None: - model_s = self.model_fn(x, s) - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - ) - if return_intermediate: - return x_t, {'model_s': model_s} - else: - return x_t - - def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, - solver_type='dpm_solver'): - """ - Singlestep solver DPM-Solver-2 from time `s` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - r1: A `float`. The hyperparameter of the second-order solver. - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - if r1 is None: - r1 = 0.5 - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - lambda_s1 = lambda_s + r1 * h - s1 = ns.inverse_lambda(lambda_s1) - log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( - s1), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) - alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) - - if self.predict_x0: - phi_11 = torch.expm1(-r1 * h) - phi_1 = torch.expm1(-h) - - if model_s is None: - model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( - model_s1 - model_s) - ) - else: - phi_11 = torch.expm1(r1 * h) - phi_1 = torch.expm1(h) - - if model_s is None: - model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) - ) - if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1} - else: - return x_t - - def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, - return_intermediate=False, solver_type='dpm_solver'): - """ - Singlestep solver DPM-Solver-3 from time `s` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - r1: A `float`. The hyperparameter of the third-order solver. - r2: A `float`. The hyperparameter of the third-order solver. - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). - If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - if r1 is None: - r1 = 1. / 3. - if r2 is None: - r2 = 2. / 3. - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - lambda_s1 = lambda_s + r1 * h - lambda_s2 = lambda_s + r2 * h - s1 = ns.inverse_lambda(lambda_s1) - s2 = ns.inverse_lambda(lambda_s2) - log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( - s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( - s2), ns.marginal_std(t) - alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) - - if self.predict_x0: - phi_11 = torch.expm1(-r1 * h) - phi_12 = torch.expm1(-r2 * h) - phi_1 = torch.expm1(-h) - phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. - phi_2 = phi_1 / h + 1. - phi_3 = phi_2 / h - 0.5 - - if model_s is None: - model_s = self.model_fn(x, s) - if model_s1 is None: - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - expand_dims(sigma_s2 / sigma_s, dims) * x - - expand_dims(alpha_s2 * phi_12, dims) * model_s - + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) - ) - model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) - ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) - D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + expand_dims(alpha_t * phi_2, dims) * D1 - - expand_dims(alpha_t * phi_3, dims) * D2 - ) - else: - phi_11 = torch.expm1(r1 * h) - phi_12 = torch.expm1(r2 * h) - phi_1 = torch.expm1(h) - phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. - phi_2 = phi_1 / h - 1. - phi_3 = phi_2 / h - 0.5 - - if model_s is None: - model_s = self.model_fn(x, s) - if model_s1 is None: - x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x - - expand_dims(sigma_s2 * phi_12, dims) * model_s - - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) - ) - model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) - ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) - D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - expand_dims(sigma_t * phi_2, dims) * D1 - - expand_dims(sigma_t * phi_3, dims) * D2 - ) - - if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} - else: - return x_t - - def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): - """ - Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - ns = self.noise_schedule - dims = x.dim() - model_prev_1, model_prev_0 = model_prev_list - t_prev_1, t_prev_0 = t_prev_list - lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( - t_prev_0), ns.marginal_lambda(t) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - h_0 = lambda_prev_0 - lambda_prev_1 - h = lambda_t - lambda_prev_0 - r0 = h_0 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - if self.predict_x0: - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 - ) - else: - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 - ) - return x_t - - def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): - """ - Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - ns = self.noise_schedule - dims = x.dim() - model_prev_2, model_prev_1, model_prev_0 = model_prev_list - t_prev_2, t_prev_1, t_prev_0 = t_prev_list - lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( - t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - h_1 = lambda_prev_1 - lambda_prev_2 - h_0 = lambda_prev_0 - lambda_prev_1 - h = lambda_t - lambda_prev_0 - r0, r1 = h_0 / h, h_1 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) - D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) - D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) - if self.predict_x0: - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 - - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 - ) - else: - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 - - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 - ) - return x_t - - def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, - r2=None): - """ - Singlestep DPM-Solver with the order `order` from time `s` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - r1: A `float`. The hyperparameter of the second-order or third-order solver. - r2: A `float`. The hyperparameter of the third-order solver. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if order == 1: - return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) - elif order == 2: - return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1) - elif order == 3: - return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1, r2=r2) - else: - raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - - def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): - """ - Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if order == 1: - return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) - elif order == 2: - return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) - elif order == 3: - return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) - else: - raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - - def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, - solver_type='dpm_solver'): - """ - The adaptive step size solver based on singlestep DPM-Solver. - Args: - x: A pytorch tensor. The initial value at time `t_T`. - order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - h_init: A `float`. The initial step size (for logSNR). - atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. - rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. - theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. - t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the - current time and `t_0` is less than `t_err`. The default setting is 1e-5. - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_0: A pytorch tensor. The approximated solution at time `t_0`. - [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. - """ - ns = self.noise_schedule - s = t_T * torch.ones((x.shape[0],)).to(x) - lambda_s = ns.marginal_lambda(s) - lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) - h = h_init * torch.ones_like(s).to(x) - x_prev = x - nfe = 0 - if order == 2: - r1 = 0.5 - lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - solver_type=solver_type, - **kwargs) - elif order == 3: - r1, r2 = 1. / 3., 2. / 3. - lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - return_intermediate=True, - solver_type=solver_type) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, - solver_type=solver_type, - **kwargs) - else: - raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) - while torch.abs((s - t_0)).mean() > t_err: - t = ns.inverse_lambda(lambda_s + h) - x_lower, lower_noise_kwargs = lower_update(x, s, t) - x_higher = higher_update(x, s, t, **lower_noise_kwargs) - delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) - norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) - E = norm_fn((x_higher - x_lower) / delta).max() - if torch.all(E <= 1.): - x = x_higher - s = t - x_prev = x_lower - lambda_s = ns.marginal_lambda(s) - h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) - nfe += order - print('adaptive solver nfe', nfe) - return x - - def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', - method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', - atol=0.0078, rtol=0.05, - ): - """ - Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. - ===================================================== - We support the following algorithms for both noise prediction model and data prediction model: - - 'singlestep': - Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. - We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). - The total number of function evaluations (NFE) == `steps`. - Given a fixed NFE == `steps`, the sampling procedure is: - - If `order` == 1: - - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). - - If `order` == 2: - - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. - - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. - - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. - - If `order` == 3: - - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. - - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. - - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. - - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. - - 'multistep': - Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. - We initialize the first `order` values by lower order multistep solvers. - Given a fixed NFE == `steps`, the sampling procedure is: - Denote K = steps. - - If `order` == 1: - - We use K steps of DPM-Solver-1 (i.e. DDIM). - - If `order` == 2: - - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. - - If `order` == 3: - - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. - - 'singlestep_fixed': - Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). - We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. - - 'adaptive': - Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). - We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. - You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs - (NFE) and the sample quality. - - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. - - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. - ===================================================== - Some advices for choosing the algorithm: - - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: - Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. - e.g. - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, - skip_type='time_uniform', method='singlestep') - - For **guided sampling with large guidance scale** by DPMs: - Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. - e.g. - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, - skip_type='time_uniform', method='multistep') - We support three types of `skip_type`: - - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** - - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. - - 'time_quadratic': quadratic time for the time steps. - ===================================================== - Args: - x: A pytorch tensor. The initial value at time `t_start` - e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. - steps: A `int`. The total number of function evaluations (NFE). - t_start: A `float`. The starting time of the sampling. - If `T` is None, we use self.noise_schedule.T (default is 1.0). - t_end: A `float`. The ending time of the sampling. - If `t_end` is None, we use 1. / self.noise_schedule.total_N. - e.g. if total_N == 1000, we have `t_end` == 1e-3. - For discrete-time DPMs: - - We recommend `t_end` == 1. / self.noise_schedule.total_N. - For continuous-time DPMs: - - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. - order: A `int`. The order of DPM-Solver. - skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. - method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. - denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. - Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). - This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and - score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID - for diffusion models sampling by diffusion SDEs for low-resolutional images - (such as CIFAR-10). However, we observed that such trick does not matter for - high-resolutional images. As it needs an additional NFE, we do not recommend - it for high-resolutional images. - lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. - Only valid for `method=multistep` and `steps < 15`. We empirically find that - this trick is a key to stabilizing the sampling by DPM-Solver with very few steps - (especially for steps <= 10). So we recommend to set it to be `True`. - solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. - atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. - rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. - Returns: - x_end: A pytorch tensor. The approximated solution at time `t_end`. - """ - t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end - t_T = self.noise_schedule.T if t_start is None else t_start - device = x.device - if method == 'adaptive': - with torch.no_grad(): - x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, - solver_type=solver_type) - elif method == 'multistep': - assert steps >= order - timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) - assert timesteps.shape[0] - 1 == steps - with torch.no_grad(): - vec_t = timesteps[0].expand((x.shape[0])) - model_prev_list = [self.model_fn(x, vec_t)] - t_prev_list = [vec_t] - # Init the first `order` values by lower order multistep DPM-Solver. - for init_order in tqdm(range(1, order), desc="DPM init order"): - vec_t = timesteps[init_order].expand(x.shape[0]) - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, - solver_type=solver_type) - model_prev_list.append(self.model_fn(x, vec_t)) - t_prev_list.append(vec_t) - # Compute the remaining values by `order`-th order multistep DPM-Solver. - for step in tqdm(range(order, steps + 1), desc="DPM multistep"): - vec_t = timesteps[step].expand(x.shape[0]) - if lower_order_final and steps < 15: - step_order = min(order, steps + 1 - step) - else: - step_order = order - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, - solver_type=solver_type) - for i in range(order - 1): - t_prev_list[i] = t_prev_list[i + 1] - model_prev_list[i] = model_prev_list[i + 1] - t_prev_list[-1] = vec_t - # We do not need to evaluate the final model value. - if step < steps: - model_prev_list[-1] = self.model_fn(x, vec_t) - elif method in ['singlestep', 'singlestep_fixed']: - if method == 'singlestep': - timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, - skip_type=skip_type, - t_T=t_T, t_0=t_0, - device=device) - elif method == 'singlestep_fixed': - K = steps // order - orders = [order, ] * K - timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) - for i, order in enumerate(orders): - t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] - timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), - N=order, device=device) - lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) - vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) - h = lambda_inner[-1] - lambda_inner[0] - r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h - r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h - x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) - if denoise_to_zero: - x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) - return x - - -############################################################# -# other utility functions -############################################################# - -def interpolate_fn(x, xp, yp): - """ - A piecewise linear function y = f(x), using xp and yp as keypoints. - We implement f(x) in a differentiable way (i.e. applicable for autograd). - The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) - Args: - x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). - xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. - yp: PyTorch tensor with shape [C, K]. - Returns: - The function values f(x), with shape [N, C]. - """ - N, K = x.shape[0], xp.shape[1] - all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) - sorted_all_x, x_indices = torch.sort(all_x, dim=2) - x_idx = torch.argmin(x_indices, dim=2) - cand_start_idx = x_idx - 1 - start_idx = torch.where( - torch.eq(x_idx, 0), - torch.tensor(1, device=x.device), - torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, - ), - ) - end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) - start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) - end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) - start_idx2 = torch.where( - torch.eq(x_idx, 0), - torch.tensor(0, device=x.device), - torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, - ), - ) - y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) - start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) - end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) - cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) - return cand - - -def expand_dims(v, dims): - """ - Expand the tensor `v` to the dim `dims`. - Args: - `v`: a PyTorch tensor with shape [N]. - `dim`: a `int`. - Returns: - a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. - """ - return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/comfy/ldm/models/diffusion/dpm_solver/sampler.py b/comfy/ldm/models/diffusion/dpm_solver/sampler.py deleted file mode 100644 index e4d0d0a3875..00000000000 --- a/comfy/ldm/models/diffusion/dpm_solver/sampler.py +++ /dev/null @@ -1,96 +0,0 @@ -"""SAMPLING ONLY.""" -import torch - -from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver - -MODEL_TYPES = { - "eps": "noise", - "v": "v" -} - - -class DPMSolverSampler(object): - def __init__(self, model, device=torch.device("cuda"), **kwargs): - super().__init__() - self.model = model - self.device = device - to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) - self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != self.device: - attr = attr.to(self.device) - setattr(self, name, attr) - - @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] - if isinstance(ctmp, torch.Tensor): - cbs = ctmp.shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - elif isinstance(conditioning, list): - for ctmp in conditioning: - if ctmp.shape[0] != batch_size: - print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}") - else: - if isinstance(conditioning, torch.Tensor): - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - - print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') - - device = self.model.betas.device - if x_T is None: - img = torch.randn(size, device=device) - else: - img = x_T - - ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) - - model_fn = model_wrapper( - lambda x, t, c: self.model.apply_model(x, t, c), - ns, - model_type=MODEL_TYPES[self.model.parameterization], - guidance_type="classifier-free", - condition=conditioning, - unconditional_condition=unconditional_conditioning, - guidance_scale=unconditional_guidance_scale, - ) - - dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) - x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, - lower_order_final=True) - - return x.to(device), None diff --git a/comfy/ldm/models/diffusion/plms.py b/comfy/ldm/models/diffusion/plms.py deleted file mode 100644 index 9d31b3994ed..00000000000 --- a/comfy/ldm/models/diffusion/plms.py +++ /dev/null @@ -1,245 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm -from functools import partial - -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like -from ldm.models.diffusion.sampling_util import norm_thresholding - - -class PLMSSampler(object): - def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - self.device = device - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != self.device: - attr = attr.to(self.device) - setattr(self, name, attr) - - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - if ddim_eta != 0: - raise ValueError('ddim_eta must be 0 for PLMS') - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) - alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) - - @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - dynamic_threshold=None, - **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - print(f'Data shape for PLMS sampling is {size}') - - samples, intermediates = self.plms_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ) - return samples, intermediates - - @torch.no_grad() - def plms_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running PLMS Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) - old_eps = [] - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) - - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. - mask) * img - - outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, t_next=ts_next, - dynamic_threshold=dynamic_threshold) - img, pred_x0, e_t = outs - old_eps.append(e_t) - if len(old_eps) >= 4: - old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, - dynamic_threshold=None): - b, *_, device = *x.shape, x.device - - def get_model_output(x, t): - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - e_t = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - return e_t - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - - def get_x_prev_and_pred_x0(e_t, index): - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - if dynamic_threshold is not None: - pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) - # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - e_t = get_model_output(x, t) - if len(old_eps) == 0: - # Pseudo Improved Euler (2nd order) - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) - e_t_next = get_model_output(x_prev, t_next) - e_t_prime = (e_t + e_t_next) / 2 - elif len(old_eps) == 1: - # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (3 * e_t - old_eps[-1]) / 2 - elif len(old_eps) == 2: - # 3nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 - elif len(old_eps) >= 3: - # 4nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 - - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) - - return x_prev, pred_x0, e_t diff --git a/comfy/ldm/models/diffusion/sampling_util.py b/comfy/ldm/models/diffusion/sampling_util.py deleted file mode 100644 index 7eff02be6d7..00000000000 --- a/comfy/ldm/models/diffusion/sampling_util.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch -import numpy as np - - -def append_dims(x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions. - From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') - return x[(...,) + (None,) * dims_to_append] - - -def norm_thresholding(x0, value): - s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) - return x0 * (value / s) - - -def spatial_norm_thresholding(x0, value): - # b c h w - s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) - return x0 * (value / s) \ No newline at end of file diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index dcf467489fe..016795a5974 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -160,32 +160,19 @@ def attention_sub_quad(query, key, value, heads, mask=None): mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True) - chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD - kv_chunk_size_min = None + kv_chunk_size = None + query_chunk_size = None + + for x in [4096, 2048, 1024, 512, 256]: + count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0) + if count >= k_tokens: + kv_chunk_size = k_tokens + query_chunk_size = x + break - #not sure at all about the math here - #TODO: tweak this - if mem_free_total > 8192 * 1024 * 1024 * 1.3: - query_chunk_size_x = 1024 * 4 - elif mem_free_total > 4096 * 1024 * 1024 * 1.3: - query_chunk_size_x = 1024 * 2 - else: - query_chunk_size_x = 1024 - kv_chunk_size_min_x = None - kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024 - if kv_chunk_size_x < 1024: - kv_chunk_size_x = None - - if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: - # the big matmul fits into our memory limit; do everything in 1 chunk, - # i.e. send it down the unchunked fast-path - query_chunk_size = q_tokens - kv_chunk_size = k_tokens - else: - query_chunk_size = query_chunk_size_x - kv_chunk_size = kv_chunk_size_x - kv_chunk_size_min = kv_chunk_size_min_x + if query_chunk_size is None: + query_chunk_size = 512 hidden_states = efficient_dot_product_attention( query, @@ -222,9 +209,14 @@ def attention_split(q, k, v, heads, mask=None): mem_free_total = model_management.get_free_memory(q.device) + if _ATTN_PRECISION =="fp32": + element_size = 4 + else: + element_size = q.element_size() + gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size + modifier = 3 mem_required = tensor_size * modifier steps = 1 @@ -252,10 +244,10 @@ def attention_split(q, k, v, heads, mask=None): s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale else: s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale - first_op_done = True s2 = s1.softmax(dim=-1).to(v.dtype) del s1 + first_op_done = True r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index bf58a4045f7..7dfdfc0a29c 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -259,10 +259,6 @@ class UNetModel(nn.Module): :param model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param conv_resample: if True, use learned convolutions for upsampling and @@ -289,7 +285,6 @@ def __init__( model_channels, out_channels, num_res_blocks, - attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, @@ -314,6 +309,7 @@ def __init__( use_linear_in_transformer=False, adm_in_channels=None, transformer_depth_middle=None, + transformer_depth_output=None, device=None, operations=comfy.ops, ): @@ -341,10 +337,7 @@ def __init__( self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels - if isinstance(transformer_depth, int): - transformer_depth = len(channel_mult) * [transformer_depth] - if transformer_depth_middle is None: - transformer_depth_middle = transformer_depth[-1] + if isinstance(num_res_blocks, int): self.num_res_blocks = len(channel_mult) * [num_res_blocks] else: @@ -352,18 +345,16 @@ def __init__( raise ValueError("provide num_res_blocks either as an int (globally constant) or " "as a list/tuple (per-level) with the same length as channel_mult") self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) - print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set.") - self.attention_resolutions = attention_resolutions + transformer_depth = transformer_depth[:] + transformer_depth_output = transformer_depth_output[:] + self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample @@ -428,7 +419,8 @@ def __init__( ) ] ch = mult * model_channels - if ds in attention_resolutions: + num_transformers = transformer_depth.pop(0) + if num_transformers > 0: if num_head_channels == -1: dim_head = ch // num_heads else: @@ -444,7 +436,7 @@ def __init__( if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: layers.append(SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, + ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ) @@ -488,7 +480,7 @@ def __init__( if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - self.middle_block = TimestepEmbedSequential( + mid_block = [ ResBlock( ch, time_embed_dim, @@ -499,8 +491,9 @@ def __init__( dtype=self.dtype, device=device, operations=operations - ), - SpatialTransformer( # always uses a self-attn + )] + if transformer_depth_middle >= 0: + mid_block += [SpatialTransformer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations @@ -515,8 +508,8 @@ def __init__( dtype=self.dtype, device=device, operations=operations - ), - ) + )] + self.middle_block = TimestepEmbedSequential(*mid_block) self._feature_size += ch self.output_blocks = nn.ModuleList([]) @@ -538,7 +531,8 @@ def __init__( ) ] ch = model_channels * mult - if ds in attention_resolutions: + num_transformers = transformer_depth_output.pop() + if num_transformers > 0: if num_head_channels == -1: dim_head = ch // num_heads else: @@ -555,7 +549,7 @@ def __init__( if not exists(num_attention_blocks) or i < num_attention_blocks[level]: layers.append( SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim, + ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations ) diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index 4d42059b5a8..8e8e8054dfd 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -83,7 +83,8 @@ def _summarize_chunk( ) max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() - torch.exp(attn_weights - max_score, out=attn_weights) + attn_weights -= max_score + torch.exp(attn_weights, out=attn_weights) exp_weights = attn_weights.to(value.dtype) exp_values = torch.bmm(exp_weights, value) max_score = max_score.squeeze(-1) diff --git a/comfy/lora.py b/comfy/lora.py index 3009a1c9e0c..d4cf94c9599 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -141,9 +141,9 @@ def model_lora_keys_clip(model, key_map={}): text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" clip_l_present = False - for b in range(32): + for b in range(32): #TODO: clean up for c in LORA_CLIP_MAP: - k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) + k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) if k in sdk: lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) key_map[lora_key] = k @@ -154,6 +154,8 @@ def model_lora_keys_clip(model, key_map={}): k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) if k in sdk: + lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) + key_map[lora_key] = k lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base key_map[lora_key] = k clip_l_present = True diff --git a/comfy/model_base.py b/comfy/model_base.py index cda6765e43a..41d464e523c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -4,6 +4,7 @@ from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep import comfy.model_management +import comfy.conds import numpy as np from enum import Enum from . import utils @@ -12,6 +13,96 @@ class ModelType(Enum): EPS = 1 V_PREDICTION = 2 + +#NOTE: all this sampling stuff will be moved +class EPS: + def calculate_input(self, sigma, noise): + sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1)) + return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + +class V_PREDICTION(EPS): + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + + +class ModelSamplingDiscrete(torch.nn.Module): + def __init__(self, model_config=None): + super().__init__() + beta_schedule = "linear" + if model_config is not None: + beta_schedule = model_config.beta_schedule + self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + self.sigma_data = 1.0 + + def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if given_betas is not None: + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32) + # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + + # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32)) + # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) + # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) + + sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + + self.register_buffer('sigmas', sigmas) + self.register_buffer('log_sigmas', sigmas.log()) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + log_sigma = sigma.log() + dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) + + def sigma(self, timestep): + t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1)) + low_idx = t.floor().long() + high_idx = t.ceil().long() + w = t.frac() + log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] + return log_sigma.exp() + + def percent_to_sigma(self, percent): + return self.sigma(torch.tensor(percent * 999.0)) + +def model_sampling(model_config, model_type): + if model_type == ModelType.EPS: + c = EPS + elif model_type == ModelType.V_PREDICTION: + c = V_PREDICTION + + s = ModelSamplingDiscrete + + class ModelSampling(s, c): + pass + + return ModelSampling(model_config) + + + class BaseModel(torch.nn.Module): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__() @@ -19,10 +110,12 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None): unet_config = model_config.unet_config self.latent_format = model_config.latent_format self.model_config = model_config - self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + if not unet_config.get("disable_unet_model_creation", False): self.diffusion_model = UNetModel(**unet_config, device=device) self.model_type = model_type + self.model_sampling = model_sampling(model_config, model_type) + self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 @@ -30,38 +123,22 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None): print("model_type", model_type.name) print("adm", self.adm_channels) - def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - if given_betas is not None: - betas = given_betas - else: - betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) - alphas = 1. - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) - - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - - self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32)) - self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) - self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) - - def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}): + def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): + sigma = t + xc = self.model_sampling.calculate_input(sigma, x) if c_concat is not None: - xc = torch.cat([x] + [c_concat], dim=1) - else: - xc = x + xc = torch.cat([xc] + [c_concat], dim=1) + context = c_crossattn dtype = self.get_dtype() xc = xc.to(dtype) - t = t.to(dtype) + t = self.model_sampling.timestep(t).float() context = context.to(dtype) - if c_adm is not None: - c_adm = c_adm.to(dtype) - return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options).float() + extra_conds = {} + for o in kwargs: + extra_conds[o] = kwargs[o].to(dtype) + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() + return self.model_sampling.calculate_denoised(sigma, model_output, x) def get_dtype(self): return self.diffusion_model.dtype @@ -72,7 +149,8 @@ def is_adm(self): def encode_adm(self, **kwargs): return None - def cond_concat(self, **kwargs): + def extra_conds(self, **kwargs): + out = {} if self.inpaint_model: concat_keys = ("mask", "masked_image") cond_concat = [] @@ -101,8 +179,12 @@ def blank_inpaint_image_like(latent_image): cond_concat.append(torch.ones_like(noise)[:,:1]) elif ck == "masked_image": cond_concat.append(blank_inpaint_image_like(noise)) - return cond_concat - return None + data = torch.cat(cond_concat, dim=1) + out['c_concat'] = comfy.conds.CONDNoiseShape(data) + adm = self.encode_adm(**kwargs) + if adm is not None: + out['y'] = comfy.conds.CONDRegular(adm) + return out def load_model_weights(self, sd, unet_prefix=""): to_load = {} diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 0ff2e7fb53f..4f4e0b3b7f0 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -14,6 +14,19 @@ def count_blocks(state_dict_keys, prefix_string): count += 1 return count +def calculate_transformer_depth(prefix, state_dict_keys, state_dict): + context_dim = None + use_linear_in_transformer = False + + transformer_prefix = prefix + "1.transformer_blocks." + transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys))) + if len(transformer_keys) > 0: + last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}') + context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1] + use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2 + return last_transformer_depth, context_dim, use_linear_in_transformer + return None + def detect_unet_config(state_dict, key_prefix, dtype): state_dict_keys = list(state_dict.keys()) @@ -40,6 +53,7 @@ def detect_unet_config(state_dict, key_prefix, dtype): channel_mult = [] attention_resolutions = [] transformer_depth = [] + transformer_depth_output = [] context_dim = None use_linear_in_transformer = False @@ -48,60 +62,67 @@ def detect_unet_config(state_dict, key_prefix, dtype): count = 0 last_res_blocks = 0 - last_transformer_depth = 0 last_channel_mult = 0 - while True: + input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.') + for count in range(input_block_count): prefix = '{}input_blocks.{}.'.format(key_prefix, count) + prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1) + block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys))) if len(block_keys) == 0: break + block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys))) + if "{}0.op.weight".format(prefix) in block_keys: #new layer - if last_transformer_depth > 0: - attention_resolutions.append(current_res) - transformer_depth.append(last_transformer_depth) num_res_blocks.append(last_res_blocks) channel_mult.append(last_channel_mult) current_res *= 2 last_res_blocks = 0 - last_transformer_depth = 0 last_channel_mult = 0 + out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict) + if out is not None: + transformer_depth_output.append(out[0]) + else: + transformer_depth_output.append(0) else: res_block_prefix = "{}0.in_layers.0.weight".format(prefix) if res_block_prefix in block_keys: last_res_blocks += 1 last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels - transformer_prefix = prefix + "1.transformer_blocks." - transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys))) - if len(transformer_keys) > 0: - last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}') - if context_dim is None: - context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1] - use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2 + out = calculate_transformer_depth(prefix, state_dict_keys, state_dict) + if out is not None: + transformer_depth.append(out[0]) + if context_dim is None: + context_dim = out[1] + use_linear_in_transformer = out[2] + else: + transformer_depth.append(0) + + res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output) + if res_block_prefix in block_keys_output: + out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict) + if out is not None: + transformer_depth_output.append(out[0]) + else: + transformer_depth_output.append(0) - count += 1 - if last_transformer_depth > 0: - attention_resolutions.append(current_res) - transformer_depth.append(last_transformer_depth) num_res_blocks.append(last_res_blocks) channel_mult.append(last_channel_mult) - transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') - - if len(set(num_res_blocks)) == 1: - num_res_blocks = num_res_blocks[0] - - if len(set(transformer_depth)) == 1: - transformer_depth = transformer_depth[0] + if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys: + transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}') + else: + transformer_depth_middle = -1 unet_config["in_channels"] = in_channels unet_config["model_channels"] = model_channels unet_config["num_res_blocks"] = num_res_blocks - unet_config["attention_resolutions"] = attention_resolutions unet_config["transformer_depth"] = transformer_depth + unet_config["transformer_depth_output"] = transformer_depth_output unet_config["channel_mult"] = channel_mult unet_config["transformer_depth_middle"] = transformer_depth_middle unet_config['use_linear_in_transformer'] = use_linear_in_transformer @@ -124,6 +145,45 @@ def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_ma else: return model_config +def convert_config(unet_config): + new_config = unet_config.copy() + num_res_blocks = new_config.get("num_res_blocks", None) + channel_mult = new_config.get("channel_mult", None) + + if isinstance(num_res_blocks, int): + num_res_blocks = len(channel_mult) * [num_res_blocks] + + if "attention_resolutions" in new_config: + attention_resolutions = new_config.pop("attention_resolutions") + transformer_depth = new_config.get("transformer_depth", None) + transformer_depth_middle = new_config.get("transformer_depth_middle", None) + + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + if transformer_depth_middle is None: + transformer_depth_middle = transformer_depth[-1] + t_in = [] + t_out = [] + s = 1 + for i in range(len(num_res_blocks)): + res = num_res_blocks[i] + d = 0 + if s in attention_resolutions: + d = transformer_depth[i] + + t_in += [d] * res + t_out += [d] * (res + 1) + s *= 2 + transformer_depth = t_in + transformer_depth_output = t_out + new_config["transformer_depth"] = t_in + new_config["transformer_depth_output"] = t_out + new_config["transformer_depth_middle"] = transformer_depth_middle + + new_config["num_res_blocks"] = num_res_blocks + return new_config + + def unet_config_from_diffusers_unet(state_dict, dtype): match = {} attention_resolutions = [] @@ -200,7 +260,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype): matches = False break if matches: - return unet_config + return convert_config(unet_config) return None def model_config_from_diffusers_unet(state_dict, dtype): diff --git a/comfy/sample.py b/comfy/sample.py index e6a69973d93..b3fcd1658a5 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -1,6 +1,7 @@ import torch import comfy.model_management import comfy.samplers +import comfy.conds import comfy.utils import math import numpy as np @@ -33,22 +34,24 @@ def prepare_mask(noise_mask, shape, device): noise_mask = noise_mask.to(device) return noise_mask -def broadcast_cond(cond, batch, device): - """broadcasts conditioning to the batch size""" - copy = [] - for p in cond: - t = comfy.utils.repeat_to_batch_size(p[0], batch) - t = t.to(device) - copy += [[t] + p[1:]] - return copy - def get_models_from_cond(cond, model_type): models = [] for c in cond: - if model_type in c[1]: - models += [c[1][model_type]] + if model_type in c: + models += [c[model_type]] return models +def convert_cond(cond): + out = [] + for c in cond: + temp = c[1].copy() + model_conds = temp.get("model_conds", {}) + if c[0] is not None: + model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) + temp["model_conds"] = model_conds + out.append(temp) + return out + def get_additional_models(positive, negative, dtype): """loads additional models in positive and negative conditioning""" control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")) @@ -72,6 +75,8 @@ def cleanup_additional_models(models): def prepare_sampling(model, noise_shape, positive, negative, noise_mask): device = model.load_device + positive = convert_cond(positive) + negative = convert_cond(negative) if noise_mask is not None: noise_mask = prepare_mask(noise_mask, noise_shape, device) @@ -81,9 +86,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask): comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory) real_model = model.model - positive_copy = broadcast_cond(positive, noise_shape[0], device) - negative_copy = broadcast_cond(negative, noise_shape[0], device) - return real_model, positive_copy, negative_copy, noise_mask, models + return real_model, positive, negative, noise_mask, models def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): diff --git a/comfy/samplers.py b/comfy/samplers.py index 0b38fbd1e86..964febb262e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,48 +1,42 @@ from .k_diffusion import sampling as k_diffusion_sampling -from .k_diffusion import external as k_diffusion_external from .extra_samplers import uni_pc import torch +import enum from comfy import model_management -from .ldm.models.diffusion.ddim import DDIMSampler -from .ldm.modules.diffusionmodules.util import make_ddim_timesteps import math from comfy import model_base import comfy.utils +import comfy.conds -def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) - return abs(a*b) // math.gcd(a, b) #The main sampling function shared by all the samplers -#Returns predicted noise +#Returns denoised def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): - def get_area_and_mult(cond, x_in, timestep_in): + def get_area_and_mult(conds, x_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 - if 'timestep_start' in cond[1]: - timestep_start = cond[1]['timestep_start'] + + if 'timestep_start' in conds: + timestep_start = conds['timestep_start'] if timestep_in[0] > timestep_start: return None - if 'timestep_end' in cond[1]: - timestep_end = cond[1]['timestep_end'] + if 'timestep_end' in conds: + timestep_end = conds['timestep_end'] if timestep_in[0] < timestep_end: return None - if 'area' in cond[1]: - area = cond[1]['area'] - if 'strength' in cond[1]: - strength = cond[1]['strength'] - - adm_cond = None - if 'adm_encoded' in cond[1]: - adm_cond = cond[1]['adm_encoded'] + if 'area' in conds: + area = conds['area'] + if 'strength' in conds: + strength = conds['strength'] input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - if 'mask' in cond[1]: + if 'mask' in conds: # Scale the mask to the size of the input # The mask should have been resized as we began the sampling process mask_strength = 1.0 - if "mask_strength" in cond[1]: - mask_strength = cond[1]["mask_strength"] - mask = cond[1]['mask'] + if "mask_strength" in conds: + mask_strength = conds["mask_strength"] + mask = conds['mask'] assert(mask.shape[1] == x_in.shape[2]) assert(mask.shape[2] == x_in.shape[3]) mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength @@ -51,7 +45,7 @@ def get_area_and_mult(cond, x_in, timestep_in): mask = torch.ones_like(input_x) mult = mask * strength - if 'mask' not in cond[1]: + if 'mask' not in conds: rr = 8 if area[2] != 0: for t in range(rr): @@ -67,27 +61,17 @@ def get_area_and_mult(cond, x_in, timestep_in): mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) conditionning = {} - conditionning['c_crossattn'] = cond[0] - - if 'concat' in cond[1]: - cond_concat_in = cond[1]['concat'] - if cond_concat_in is not None and len(cond_concat_in) > 0: - cropped = [] - for x in cond_concat_in: - cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - cropped.append(cr) - conditionning['c_concat'] = torch.cat(cropped, dim=1) - - if adm_cond is not None: - conditionning['c_adm'] = adm_cond + model_conds = conds["model_conds"] + for c in model_conds: + conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) control = None - if 'control' in cond[1]: - control = cond[1]['control'] + if 'control' in conds: + control = conds['control'] patches = None - if 'gligen' in cond[1]: - gligen = cond[1]['gligen'] + if 'gligen' in conds: + gligen = conds['gligen'] patches = {} gligen_type = gligen[0] gligen_model = gligen[1] @@ -105,22 +89,8 @@ def cond_equal_size(c1, c2): return True if c1.keys() != c2.keys(): return False - if 'c_crossattn' in c1: - s1 = c1['c_crossattn'].shape - s2 = c2['c_crossattn'].shape - if s1 != s2: - if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen - return False - - mult_min = lcm(s1[1], s2[1]) - diff = mult_min // min(s1[1], s2[1]) - if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much - return False - if 'c_concat' in c1: - if c1['c_concat'].shape != c2['c_concat'].shape: - return False - if 'c_adm' in c1: - if c1['c_adm'].shape != c2['c_adm'].shape: + for k in c1: + if not c1[k].can_concat(c2[k]): return False return True @@ -149,39 +119,27 @@ def cond_cat(c_list): c_concat = [] c_adm = [] crossattn_max_len = 0 + + temp = {} for x in c_list: - if 'c_crossattn' in x: - c = x['c_crossattn'] - if crossattn_max_len == 0: - crossattn_max_len = c.shape[1] - else: - crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) - c_crossattn.append(c) - if 'c_concat' in x: - c_concat.append(x['c_concat']) - if 'c_adm' in x: - c_adm.append(x['c_adm']) + for k in x: + cur = temp.get(k, []) + cur.append(x[k]) + temp[k] = cur + out = {} - c_crossattn_out = [] - for c in c_crossattn: - if c.shape[1] < crossattn_max_len: - c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result - c_crossattn_out.append(c) - - if len(c_crossattn_out) > 0: - out['c_crossattn'] = torch.cat(c_crossattn_out) - if len(c_concat) > 0: - out['c_concat'] = torch.cat(c_concat) - if len(c_adm) > 0: - out['c_adm'] = torch.cat(c_adm) + for k in temp: + conds = temp[k] + out[k] = conds[0].concat(conds[1:]) + return out def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options): out_cond = torch.zeros_like(x_in) - out_count = torch.ones_like(x_in)/100000.0 + out_count = torch.ones_like(x_in) * 1e-37 out_uncond = torch.zeros_like(x_in) - out_uncond_count = torch.ones_like(x_in)/100000.0 + out_uncond_count = torch.ones_like(x_in) * 1e-37 COND = 0 UNCOND = 1 @@ -281,7 +239,6 @@ def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_tot del out_count out_uncond /= out_uncond_count del out_uncond_count - return out_cond, out_uncond @@ -291,29 +248,20 @@ def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_tot cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options) if "sampler_cfg_function" in model_options: - args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep} - return model_options["sampler_cfg_function"](args) + args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x} + return x - model_options["sampler_cfg_function"](args) else: return uncond + (cond - uncond) * cond_scale - -class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser): - def __init__(self, model, quantize=False, device='cpu'): - super().__init__(model, model.alphas_cumprod, quantize=quantize) - - def get_v(self, x, t, cond, **kwargs): - return self.inner_model.apply_model(x, t, cond, **kwargs) - - class CFGNoisePredictor(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - self.alphas_cumprod = model.alphas_cumprod def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None): out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed) return out - + def forward(self, *args, **kwargs): + return self.apply_model(*args, **kwargs) class KSamplerX0Inpaint(torch.nn.Module): def __init__(self, model): @@ -332,32 +280,40 @@ def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_option return out def simple_scheduler(model, steps): + s = model.model_sampling sigs = [] - ss = len(model.sigmas) / steps + ss = len(s.sigmas) / steps for x in range(steps): - sigs += [float(model.sigmas[-(1 + int(x * ss))])] + sigs += [float(s.sigmas[-(1 + int(x * ss))])] sigs += [0.0] return torch.FloatTensor(sigs) def ddim_scheduler(model, steps): + s = model.model_sampling sigs = [] - ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False) - for x in range(len(ddim_timesteps) - 1, -1, -1): - ts = ddim_timesteps[x] - if ts > 999: - ts = 999 - sigs.append(model.t_to_sigma(torch.tensor(ts))) + ss = len(s.sigmas) // steps + x = 1 + while x < len(s.sigmas): + sigs += [float(s.sigmas[x])] + x += ss + sigs = sigs[::-1] sigs += [0.0] return torch.FloatTensor(sigs) -def sgm_scheduler(model, steps): +def normal_scheduler(model, steps, sgm=False, floor=False): + s = model.model_sampling + start = s.timestep(s.sigma_max) + end = s.timestep(s.sigma_min) + + if sgm: + timesteps = torch.linspace(start, end, steps + 1)[:-1] + else: + timesteps = torch.linspace(start, end, steps) + sigs = [] - timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int) for x in range(len(timesteps)): ts = timesteps[x] - if ts > 999: - ts = 999 - sigs.append(model.t_to_sigma(torch.tensor(ts))) + sigs.append(s.sigma(ts)) sigs += [0.0] return torch.FloatTensor(sigs) @@ -389,19 +345,19 @@ def resolve_areas_and_cond_masks(conditions, h, w, device): # While we're doing this, we can also resolve the mask device and scaling for performance reasons for i in range(len(conditions)): c = conditions[i] - if 'area' in c[1]: - area = c[1]['area'] + if 'area' in c: + area = c['area'] if area[0] == "percentage": - modified = c[1].copy() + modified = c.copy() area = (max(1, round(area[1] * h)), max(1, round(area[2] * w)), round(area[3] * h), round(area[4] * w)) modified['area'] = area - c = [c[0], modified] + c = modified conditions[i] = c - if 'mask' in c[1]: - mask = c[1]['mask'] + if 'mask' in c: + mask = c['mask'] mask = mask.to(device=device) - modified = c[1].copy() + modified = c.copy() if len(mask.shape) == 2: mask = mask.unsqueeze(0) if mask.shape[1] != h or mask.shape[2] != w: @@ -422,66 +378,70 @@ def resolve_areas_and_cond_masks(conditions, h, w, device): modified['area'] = area modified['mask'] = mask - conditions[i] = [c[0], modified] + conditions[i] = modified def create_cond_with_same_area_if_none(conds, c): - if 'area' not in c[1]: + if 'area' not in c: return - c_area = c[1]['area'] + c_area = c['area'] smallest = None for x in conds: - if 'area' in x[1]: - a = x[1]['area'] + if 'area' in x: + a = x['area'] if c_area[2] >= a[2] and c_area[3] >= a[3]: if a[0] + a[2] >= c_area[0] + c_area[2]: if a[1] + a[3] >= c_area[1] + c_area[3]: if smallest is None: smallest = x - elif 'area' not in smallest[1]: + elif 'area' not in smallest: smallest = x else: - if smallest[1]['area'][0] * smallest[1]['area'][1] > a[0] * a[1]: + if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]: smallest = x else: if smallest is None: smallest = x if smallest is None: return - if 'area' in smallest[1]: - if smallest[1]['area'] == c_area: + if 'area' in smallest: + if smallest['area'] == c_area: return - n = c[1].copy() - conds += [[smallest[0], n]] + + out = c.copy() + out['model_conds'] = smallest['model_conds'].copy() #TODO: which fields should be copied? + conds += [out] def calculate_start_end_timesteps(model, conds): + s = model.model_sampling for t in range(len(conds)): x = conds[t] timestep_start = None timestep_end = None - if 'start_percent' in x[1]: - timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['start_percent'] * 999.0))) - if 'end_percent' in x[1]: - timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x[1]['end_percent'] * 999.0))) + if 'start_percent' in x: + timestep_start = s.percent_to_sigma(x['start_percent']) + if 'end_percent' in x: + timestep_end = s.percent_to_sigma(x['end_percent']) if (timestep_start is not None) or (timestep_end is not None): - n = x[1].copy() + n = x.copy() if (timestep_start is not None): n['timestep_start'] = timestep_start if (timestep_end is not None): n['timestep_end'] = timestep_end - conds[t] = [x[0], n] + conds[t] = n def pre_run_control(model, conds): + s = model.model_sampling for t in range(len(conds)): x = conds[t] timestep_start = None timestep_end = None - percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0)) - if 'control' in x[1]: - x[1]['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function) + percent_to_timestep_function = lambda a: s.percent_to_sigma(a) + if 'control' in x: + x['control'].pre_run(model, percent_to_timestep_function) def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] @@ -490,16 +450,16 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): uncond_other = [] for t in range(len(conds)): x = conds[t] - if 'area' not in x[1]: - if name in x[1] and x[1][name] is not None: - cond_cnets.append(x[1][name]) + if 'area' not in x: + if name in x and x[name] is not None: + cond_cnets.append(x[name]) else: cond_other.append((x, t)) for t in range(len(uncond)): x = uncond[t] - if 'area' not in x[1]: - if name in x[1] and x[1][name] is not None: - uncond_cnets.append(x[1][name]) + if 'area' not in x: + if name in x and x[name] is not None: + uncond_cnets.append(x[name]) else: uncond_other.append((x, t)) @@ -509,47 +469,35 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): for x in range(len(cond_cnets)): temp = uncond_other[x % len(uncond_other)] o = temp[0] - if name in o[1] and o[1][name] is not None: - n = o[1].copy() + if name in o and o[name] is not None: + n = o.copy() n[name] = uncond_fill_func(cond_cnets, x) - uncond += [[o[0], n]] + uncond += [n] else: - n = o[1].copy() + n = o.copy() n[name] = uncond_fill_func(cond_cnets, x) - uncond[temp[1]] = [o[0], n] - -def encode_adm(model, conds, batch_size, width, height, device, prompt_type): - for t in range(len(conds)): - x = conds[t] - adm_out = None - if 'adm' in x[1]: - adm_out = x[1]["adm"] - else: - params = x[1].copy() - params["width"] = params.get("width", width * 8) - params["height"] = params.get("height", height * 8) - params["prompt_type"] = params.get("prompt_type", prompt_type) - adm_out = model.encode_adm(device=device, **params) - - if adm_out is not None: - x[1] = x[1].copy() - x[1]["adm_encoded"] = comfy.utils.repeat_to_batch_size(adm_out, batch_size).to(device) - - return conds + uncond[temp[1]] = n -def encode_cond(model_function, key, conds, device, **kwargs): +def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs): for t in range(len(conds)): x = conds[t] - params = x[1].copy() + params = x.copy() params["device"] = device + params["noise"] = noise + params["width"] = params.get("width", noise.shape[3] * 8) + params["height"] = params.get("height", noise.shape[2] * 8) + params["prompt_type"] = params.get("prompt_type", prompt_type) for k in kwargs: if k not in params: params[k] = kwargs[k] out = model_function(**params) - if out is not None: - x[1] = x[1].copy() - x[1][key] = out + x = x.copy() + model_conds = x['model_conds'].copy() + for k in out: + model_conds[k] = out[k] + x['model_conds'] = model_conds + conds[t] = x return conds class Sampler: @@ -557,42 +505,9 @@ def sample(self): pass def max_denoise(self, model_wrap, sigmas): - return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05) - -class DDIM(Sampler): - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - timesteps = [] - for s in range(sigmas.shape[0]): - timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s])) - noise_mask = None - if denoise_mask is not None: - noise_mask = 1.0 - denoise_mask - - ddim_callback = None - if callback is not None: - total_steps = len(timesteps) - 1 - ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps) - - max_denoise = self.max_denoise(model_wrap, sigmas) - - ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device) - ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False) - z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise) - samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps, - batch_size=noise.shape[0], - shape=noise.shape[1:], - verbose=False, - eta=0.0, - x_T=z_enc, - x0=latent_image, - img_callback=ddim_callback, - denoise_function=model_wrap.predict_eps_discrete_timestep, - extra_args=extra_args, - mask=noise_mask, - to_zero=sigmas[-1]==0, - end_step=sigmas.shape[0] - 1, - disable_pbar=disable_pbar) - return samples + max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma class UNIPC(Sampler): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): @@ -606,13 +521,17 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"] -def ksampler(sampler_name, extra_options={}): +def ksampler(sampler_name, extra_options={}, inpaint_options={}): class KSAMPLER(Sampler): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): extra_args["denoise_mask"] = denoise_mask model_k = KSamplerX0Inpaint(model_wrap) model_k.latent_image = latent_image - model_k.noise = noise + if inpaint_options.get("random", False): #TODO: Should this be the default? + generator = torch.manual_seed(extra_args.get("seed", 41) + 1) + model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device) + else: + model_k.noise = noise if self.max_denoise(model_wrap, sigmas): noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0) @@ -641,11 +560,7 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N def wrap_model(model): model_denoise = CFGNoisePredictor(model) - if model.model_type == model_base.ModelType.V_PREDICTION: - model_wrap = CompVisVDenoiser(model_denoise, quantize=True) - else: - model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True) - return model_wrap + return model_denoise def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): positive = positive[:] @@ -656,8 +571,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model model_wrap = wrap_model(model) - calculate_start_end_timesteps(model_wrap, negative) - calculate_start_end_timesteps(model_wrap, positive) + calculate_start_end_timesteps(model, negative) + calculate_start_end_timesteps(model, positive) #make sure each cond area has an opposite one with the same area for c in positive: @@ -665,21 +580,17 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model for c in negative: create_cond_with_same_area_if_none(positive, c) - pre_run_control(model_wrap, negative + positive) + pre_run_control(model, negative + positive) - apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) + apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) if latent_image is not None: latent_image = model.process_latent_in(latent_image) - if model.is_adm(): - positive = encode_adm(model, positive, noise.shape[0], noise.shape[3], noise.shape[2], device, "positive") - negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative") - - if hasattr(model, 'cond_concat'): - positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) - negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) + if hasattr(model, 'extra_conds'): + positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) + negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed} @@ -690,19 +601,18 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] def calculate_sigmas_scheduler(model, scheduler_name, steps): - model_wrap = wrap_model(model) if scheduler_name == "karras": - sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max)) + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) elif scheduler_name == "exponential": - sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max)) + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max)) elif scheduler_name == "normal": - sigmas = model_wrap.get_sigmas(steps) + sigmas = normal_scheduler(model, steps) elif scheduler_name == "simple": - sigmas = simple_scheduler(model_wrap, steps) + sigmas = simple_scheduler(model, steps) elif scheduler_name == "ddim_uniform": - sigmas = ddim_scheduler(model_wrap, steps) + sigmas = ddim_scheduler(model, steps) elif scheduler_name == "sgm_uniform": - sigmas = sgm_scheduler(model_wrap, steps) + sigmas = normal_scheduler(model, steps, sgm=True) else: print("error invalid scheduler", self.scheduler) return sigmas @@ -713,7 +623,7 @@ def sampler_class(name): elif name == "uni_pc_bh2": sampler = UNIPCBH2 elif name == "ddim": - sampler = DDIM + sampler = ksampler("euler", inpaint_options={"random": True}) else: sampler = ksampler(name) return sampler diff --git a/comfy/sd.py b/comfy/sd.py index c364b723cb9..65a61343be1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -55,13 +55,26 @@ def load_clip_weights(model, sd): def load_lora_for_models(model, clip, lora, strength_model, strength_clip): - key_map = comfy.lora.model_lora_keys_unet(model.model) - key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + key_map = {} + if model is not None: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + if clip is not None: + key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + loaded = comfy.lora.load_lora(lora, key_map) - new_modelpatcher = model.clone() - k = new_modelpatcher.add_patches(loaded, strength_model) - new_clip = clip.clone() - k1 = new_clip.add_patches(loaded, strength_clip) + if model is not None: + new_modelpatcher = model.clone() + k = new_modelpatcher.add_patches(loaded, strength_model) + else: + k = () + new_modelpatcher = None + + if clip is not None: + new_clip = clip.clone() + k1 = new_clip.add_patches(loaded, strength_clip) + else: + k1 = () + new_clip = None k = set(k) k1 = set(k1) for x in loaded: @@ -360,7 +373,7 @@ class EmptyClass: from . import latent_formats model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor) - model_config.unet_config = unet_config + model_config.unet_config = model_detection.convert_config(unet_config) if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"): model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type) @@ -388,11 +401,13 @@ class EmptyClass: if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"): clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model.clip_h elif clip_config["target"].endswith("FrozenCLIPEmbedder"): clip_target.clip = sd1_clip.SD1ClipModel clip_target.tokenizer = sd1_clip.SD1Tokenizer - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model.clip_l load_clip_weights(w, state_dict) return (comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 9978b6c35c6..fdaa1e6c76e 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -35,7 +35,7 @@ def encode_token_weights(self, token_weight_pairs): return z_empty.cpu(), first_pooled.cpu() return torch.cat(output, dim=-2).cpu(), first_pooled.cpu() -class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): +class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" LAYERS = [ "last", @@ -278,7 +278,13 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No valid_file = None for embed_dir in embedding_directory: - embed_path = os.path.join(embed_dir, embedding_name) + embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name)) + embed_dir = os.path.abspath(embed_dir) + try: + if os.path.commonpath((embed_dir, embed_path)) != embed_dir: + continue + except: + continue if not os.path.isfile(embed_path): extensions = ['.safetensors', '.pt', '.bin'] for x in extensions: @@ -336,7 +342,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No embed_out = next(iter(values)) return embed_out -class SD1Tokenizer: +class SDTokenizer: def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") @@ -448,3 +454,40 @@ def tokenize_with_weights(self, text:str, return_word_ids=False): def untokenize(self, token_weight_pair): return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair)) + + +class SD1Tokenizer: + def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer): + self.clip_name = clip_name + self.clip = "clip_{}".format(self.clip_name) + setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory)) + + def tokenize_with_weights(self, text:str, return_word_ids=False): + out = {} + out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids) + return out + + def untokenize(self, token_weight_pair): + return getattr(self, self.clip).untokenize(token_weight_pair) + + +class SD1ClipModel(torch.nn.Module): + def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs): + super().__init__() + self.clip_name = clip_name + self.clip = "clip_{}".format(self.clip_name) + setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs)) + + def clip_layer(self, layer_idx): + getattr(self, self.clip).clip_layer(layer_idx) + + def reset_clip_layer(self): + getattr(self, self.clip).reset_clip_layer() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs = token_weight_pairs[self.clip_name] + out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs) + return out, pooled + + def load_sd(self, sd): + return getattr(self, self.clip).load_sd(sd) diff --git a/comfy/sd2_clip.py b/comfy/sd2_clip.py index 05e50a0057b..ebabf7ccd51 100644 --- a/comfy/sd2_clip.py +++ b/comfy/sd2_clip.py @@ -2,7 +2,7 @@ import torch import os -class SD2ClipModel(sd1_clip.SD1ClipModel): +class SD2ClipHModel(sd1_clip.SDClipModel): def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): if layer == "penultimate": layer="hidden" @@ -12,6 +12,14 @@ def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, la super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) self.empty_tokens = [[49406] + [49407] + [0] * 75] -class SD2Tokenizer(sd1_clip.SD1Tokenizer): +class SD2ClipHTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024) + +class SD2Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None): + super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer) + +class SD2ClipModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, **kwargs): + super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs) diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index e3ac2ee0b4a..4c508a0ea88 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -2,7 +2,7 @@ import torch import os -class SDXLClipG(sd1_clip.SD1ClipModel): +class SDXLClipG(sd1_clip.SDClipModel): def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None): if layer == "penultimate": layer="hidden" @@ -16,14 +16,14 @@ def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate" def load_sd(self, sd): return super().load_sd(sd) -class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer): +class SDXLClipGTokenizer(sd1_clip.SDTokenizer): def __init__(self, tokenizer_path=None, embedding_directory=None): super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') -class SDXLTokenizer(sd1_clip.SD1Tokenizer): +class SDXLTokenizer: def __init__(self, embedding_directory=None): - self.clip_l = sd1_clip.SD1Tokenizer(embedding_directory=embedding_directory) + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory) self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory) def tokenize_with_weights(self, text:str, return_word_ids=False): @@ -38,7 +38,7 @@ def untokenize(self, token_weight_pair): class SDXLClipModel(torch.nn.Module): def __init__(self, device="cpu", dtype=None): super().__init__() - self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype) + self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype) self.clip_l.layer_norm_hidden_state = False self.clip_g = SDXLClipG(device=device, dtype=dtype) @@ -63,21 +63,6 @@ def load_sd(self, sd): else: return self.clip_l.load_sd(sd) -class SDXLRefinerClipModel(torch.nn.Module): +class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None): - super().__init__() - self.clip_g = SDXLClipG(device=device, dtype=dtype) - - def clip_layer(self, layer_idx): - self.clip_g.clip_layer(layer_idx) - - def reset_clip_layer(self): - self.clip_g.reset_clip_layer() - - def encode_token_weights(self, token_weight_pairs): - token_weight_pairs_g = token_weight_pairs["g"] - g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g) - return g_out, g_pooled - - def load_sd(self, sd): - return self.clip_g.load_sd(sd) + super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index bb8ae2148fd..fdd4ea4f5c2 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -38,8 +38,15 @@ def process_clip_state_dict(self, state_dict): if ids.dtype == torch.float32: state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() + replace_prefix = {} + replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l." + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) return state_dict + def process_clip_state_dict_for_saving(self, state_dict): + replace_prefix = {"clip_l.": "cond_stage_model."} + return utils.state_dict_prefix_replace(state_dict, replace_prefix) + def clip_target(self): return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel) @@ -62,12 +69,12 @@ def model_type(self, state_dict, prefix=""): return model_base.ModelType.EPS def process_clip_state_dict(self, state_dict): - state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24) + state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24) return state_dict def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {} - replace_prefix[""] = "cond_stage_model.model." + replace_prefix["clip_h"] = "cond_stage_model.model" state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict) return state_dict @@ -104,7 +111,7 @@ class SDXLRefiner(supported_models_base.BASE): "use_linear_in_transformer": True, "context_dim": 1280, "adm_in_channels": 2560, - "transformer_depth": [0, 4, 4, 0], + "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0], } latent_format = latent_formats.SDXL @@ -139,7 +146,7 @@ class SDXL(supported_models_base.BASE): unet_config = { "model_channels": 320, "use_linear_in_transformer": True, - "transformer_depth": [0, 2, 10], + "transformer_depth": [0, 0, 2, 2, 10, 10], "context_dim": 2048, "adm_in_channels": 2816 } @@ -165,6 +172,7 @@ def process_clip_state_dict(self, state_dict): replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model" state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32) keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection" + keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection" keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale" state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) @@ -189,5 +197,14 @@ def process_clip_state_dict_for_saving(self, state_dict): def clip_target(self): return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel) +class SSD1B(SDXL): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "transformer_depth": [0, 0, 2, 2, 4, 4], + "context_dim": 2048, + "adm_in_channels": 2816 + } + -models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL] +models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B] diff --git a/comfy/utils.py b/comfy/utils.py index a1807aa1d47..6a0c54e8098 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -170,25 +170,12 @@ def transformers_convert(sd, prefix_from, prefix_to, number): def unet_to_diffusers(unet_config): num_res_blocks = unet_config["num_res_blocks"] - attention_resolutions = unet_config["attention_resolutions"] channel_mult = unet_config["channel_mult"] - transformer_depth = unet_config["transformer_depth"] + transformer_depth = unet_config["transformer_depth"][:] + transformer_depth_output = unet_config["transformer_depth_output"][:] num_blocks = len(channel_mult) - if isinstance(num_res_blocks, int): - num_res_blocks = [num_res_blocks] * num_blocks - if isinstance(transformer_depth, int): - transformer_depth = [transformer_depth] * num_blocks - - transformers_per_layer = [] - res = 1 - for i in range(num_blocks): - transformers = 0 - if res in attention_resolutions: - transformers = transformer_depth[i] - transformers_per_layer.append(transformers) - res *= 2 - - transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1]) + + transformers_mid = unet_config.get("transformer_depth_middle", None) diffusers_unet_map = {} for x in range(num_blocks): @@ -196,10 +183,11 @@ def unet_to_diffusers(unet_config): for i in range(num_res_blocks[x]): for b in UNET_MAP_RESNET: diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b) - if transformers_per_layer[x] > 0: + num_transformers = transformer_depth.pop(0) + if num_transformers > 0: for b in UNET_MAP_ATTENTIONS: diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b) - for t in range(transformers_per_layer[x]): + for t in range(num_transformers): for b in TRANSFORMER_BLOCKS: diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) n += 1 @@ -218,7 +206,6 @@ def unet_to_diffusers(unet_config): diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b) num_res_blocks = list(reversed(num_res_blocks)) - transformers_per_layer = list(reversed(transformers_per_layer)) for x in range(num_blocks): n = (num_res_blocks[x] + 1) * x l = num_res_blocks[x] + 1 @@ -227,11 +214,12 @@ def unet_to_diffusers(unet_config): for b in UNET_MAP_RESNET: diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b) c += 1 - if transformers_per_layer[x] > 0: + num_transformers = transformer_depth_output.pop() + if num_transformers > 0: c += 1 for b in UNET_MAP_ATTENTIONS: diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b) - for t in range(transformers_per_layer[x]): + for t in range(num_transformers): for b in TRANSFORMER_BLOCKS: diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b) if i == l - 1: diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 3f651e59456..324cfe105f2 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -126,7 +126,7 @@ def INPUT_TYPES(s): "max": 256, "step": 1 }), - "dither": (["none", "floyd-steinberg"],), + "dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],), }, } @@ -135,19 +135,47 @@ def INPUT_TYPES(s): CATEGORY = "image/postprocessing" - def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"): + def bayer(im, pal_im, order): + def normalized_bayer_matrix(n): + if n == 0: + return np.zeros((1,1), "float32") + else: + q = 4 ** n + m = q * normalized_bayer_matrix(n - 1) + return np.bmat(((m-1.5, m+0.5), (m+1.5, m-0.5))) / q + + num_colors = len(pal_im.getpalette()) // 3 + spread = 2 * 256 / num_colors + bayer_n = int(math.log2(order)) + bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5) + + result = torch.from_numpy(np.array(im).astype(np.float32)) + tw = math.ceil(result.shape[0] / bayer_matrix.shape[0]) + th = math.ceil(result.shape[1] / bayer_matrix.shape[1]) + tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1) + result.add_(tiled_matrix[:result.shape[0],:result.shape[1]]).clamp_(0, 255) + result = result.to(dtype=torch.uint8) + + im = Image.fromarray(result.cpu().numpy()) + im = im.quantize(palette=pal_im, dither=Image.Dither.NONE) + return im + + def quantize(self, image: torch.Tensor, colors: int, dither: str): batch_size, height, width, _ = image.shape result = torch.zeros_like(image) - dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE - for b in range(batch_size): - tensor_image = image[b] - img = (tensor_image * 255).to(torch.uint8).numpy() - pil_image = Image.fromarray(img, mode='RGB') + im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB') + + pal_im = im.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836 - palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836 - quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option) + if dither == "none": + quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE) + elif dither == "floyd-steinberg": + quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG) + elif dither.startswith("bayer"): + order = int(dither.split('-')[-1]) + quantized_image = Quantize.bayer(im, pal_im, order) quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255 result[b] = quantized_array diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py index 0a9daf27276..88a4ebe29f6 100644 --- a/comfy_extras/nodes_rebatch.py +++ b/comfy_extras/nodes_rebatch.py @@ -4,7 +4,7 @@ class LatentRebatch: @classmethod def INPUT_TYPES(s): return {"required": { "latents": ("LATENT",), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), }} RETURN_TYPES = ("LATENT",) INPUT_IS_LIST = True diff --git a/latent_preview.py b/latent_preview.py index e1553c85cac..6e758a1a9d1 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -22,7 +22,7 @@ def __init__(self, taesd): self.taesd = taesd def decode_latent_to_preview(self, x0): - x_sample = self.taesd.decoder(x0)[0].detach() + x_sample = self.taesd.decoder(x0[:1])[0].detach() # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2] x_sample = x_sample.sub(0.5).mul(2) diff --git a/server.py b/server.py index 63f337a873f..11bd2a0fb44 100644 --- a/server.py +++ b/server.py @@ -82,7 +82,8 @@ def __init__(self, loop): if args.enable_cors_header: middlewares.append(create_cors_middleware(args.enable_cors_header)) - self.app = web.Application(client_max_size=104857600, middlewares=middlewares) + max_upload_size = round(args.max_upload_size * 1024 * 1024) + self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares) self.sockets = dict() self.web_root = os.path.join(os.path.dirname( os.path.realpath(__file__)), "web") diff --git a/web/extensions/core/contextMenuFilter.js b/web/extensions/core/contextMenuFilter.js index 152cd7043de..0a305391a4e 100644 --- a/web/extensions/core/contextMenuFilter.js +++ b/web/extensions/core/contextMenuFilter.js @@ -25,7 +25,7 @@ const ext = { requestAnimationFrame(() => { const currentNode = LGraphCanvas.active_canvas.current_node; const clickedComboValue = currentNode.widgets - .filter(w => w.type === "combo" && w.options.values.length === values.length) + ?.filter(w => w.type === "combo" && w.options.values.length === values.length) .find(w => w.options.values.every((v, i) => v === values[i])) ?.value; diff --git a/web/extensions/core/nodeTemplates.js b/web/extensions/core/nodeTemplates.js index a1a8c8251bf..81b4be9357b 100644 --- a/web/extensions/core/nodeTemplates.js +++ b/web/extensions/core/nodeTemplates.js @@ -15,6 +15,9 @@ import { GROUP_DATA, IS_GROUP_NODE, registerGroupNodes } from "./groupNode.js"; // To delete/rename: // Right click the canvas // Node templates -> Manage +// +// To rearrange: +// Open the manage dialog and Drag and drop elements using the "Name:" label as handle const id = "Comfy.NodeTemplates"; @@ -23,6 +26,10 @@ class ManageTemplates extends ComfyDialog { super(); this.element.classList.add("comfy-manage-templates"); this.templates = this.load(); + this.draggedEl = null; + this.saveVisualCue = null; + this.emptyImg = new Image(); + this.emptyImg.src = 'data:image/gif;base64,R0lGODlhAQABAIAAAAUEBAAAACwAAAAAAQABAAACAkQBADs='; this.importInput = $el("input", { type: "file", @@ -36,14 +43,11 @@ class ManageTemplates extends ComfyDialog { createButtons() { const btns = super.createButtons(); - btns[0].textContent = "Cancel"; - btns.unshift( - $el("button", { - type: "button", - textContent: "Save", - onclick: () => this.save(), - }) - ); + btns[0].textContent = "Close"; + btns[0].onclick = (e) => { + clearTimeout(this.saveVisualCue); + this.close(); + }; btns.unshift( $el("button", { type: "button", @@ -72,25 +76,6 @@ class ManageTemplates extends ComfyDialog { } } - save() { - // Find all visible inputs and save them as our new list - const inputs = this.element.querySelectorAll("input"); - const updated = []; - - for (let i = 0; i < inputs.length; i++) { - const input = inputs[i]; - if (input.parentElement.style.display !== "none") { - const t = this.templates[i]; - t.name = input.value.trim() || input.getAttribute("data-name"); - updated.push(t); - } - } - - this.templates = updated; - this.store(); - this.close(); - } - store() { localStorage.setItem(id, JSON.stringify(this.templates)); } @@ -146,71 +131,155 @@ class ManageTemplates extends ComfyDialog { super.show( $el( "div", - { - style: { - display: "grid", - gridTemplateColumns: "1fr auto", - gap: "5px", - }, - }, - this.templates.flatMap((t) => { + {}, + this.templates.flatMap((t,i) => { let nameInput; return [ $el( - "label", + "div", { - textContent: "Name: ", + dataset: { id: i }, + className: "tempateManagerRow", + style: { + display: "grid", + gridTemplateColumns: "1fr auto", + border: "1px dashed transparent", + gap: "5px", + backgroundColor: "var(--comfy-menu-bg)" + }, + ondragstart: (e) => { + this.draggedEl = e.currentTarget; + e.currentTarget.style.opacity = "0.6"; + e.currentTarget.style.border = "1px dashed yellow"; + e.dataTransfer.effectAllowed = 'move'; + e.dataTransfer.setDragImage(this.emptyImg, 0, 0); + }, + ondragend: (e) => { + e.target.style.opacity = "1"; + e.currentTarget.style.border = "1px dashed transparent"; + e.currentTarget.removeAttribute("draggable"); + + // rearrange the elements in the localStorage + this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => { + var prev_i = el.dataset.id; + + if ( el == this.draggedEl && prev_i != i ) { + [this.templates[i], this.templates[prev_i]] = [this.templates[prev_i], this.templates[i]]; + } + el.dataset.id = i; + }); + this.store(); + }, + ondragover: (e) => { + e.preventDefault(); + if ( e.currentTarget == this.draggedEl ) + return; + + let rect = e.currentTarget.getBoundingClientRect(); + if (e.clientY > rect.top + rect.height / 2) { + e.currentTarget.parentNode.insertBefore(this.draggedEl, e.currentTarget.nextSibling); + } else { + e.currentTarget.parentNode.insertBefore(this.draggedEl, e.currentTarget); + } + } }, [ - $el("input", { - value: t.name, - dataset: { name: t.name }, - $: (el) => (nameInput = el), - }), - ] - ), - $el( - "div", - {}, - [ - $el("button", { - textContent: "Export", - style: { - fontSize: "12px", - fontWeight: "normal", - }, - onclick: (e) => { - const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string - const blob = new Blob([json], {type: "application/json"}); - const url = URL.createObjectURL(blob); - const a = $el("a", { - href: url, - download: (nameInput.value || t.name) + ".json", - style: {display: "none"}, - parent: document.body, - }); - a.click(); - setTimeout(function () { - a.remove(); - window.URL.revokeObjectURL(url); - }, 0); - }, - }), - $el("button", { - textContent: "Delete", - style: { - fontSize: "12px", - color: "red", - fontWeight: "normal", - }, - onclick: (e) => { - nameInput.value = ""; - e.target.parentElement.style.display = "none"; - e.target.parentElement.previousElementSibling.style.display = "none"; + $el( + "label", + { + textContent: "Name: ", + style: { + cursor: "grab", + }, + onmousedown: (e) => { + // enable dragging only from the label + if (e.target.localName == 'label') + e.currentTarget.parentNode.draggable = 'true'; + } }, - }), + [ + $el("input", { + value: t.name, + dataset: { name: t.name }, + style: { + transitionProperty: 'background-color', + transitionDuration: '0s', + }, + onchange: (e) => { + clearTimeout(this.saveVisualCue); + var el = e.target; + var row = el.parentNode.parentNode; + this.templates[row.dataset.id].name = el.value.trim() || 'untitled'; + this.store(); + el.style.backgroundColor = 'rgb(40, 95, 40)'; + el.style.transitionDuration = '0s'; + this.saveVisualCue = setTimeout(function () { + el.style.transitionDuration = '.7s'; + el.style.backgroundColor = 'var(--comfy-input-bg)'; + }, 15); + }, + onkeypress: (e) => { + var el = e.target; + clearTimeout(this.saveVisualCue); + el.style.transitionDuration = '0s'; + el.style.backgroundColor = 'var(--comfy-input-bg)'; + }, + $: (el) => (nameInput = el), + }) + ] + ), + $el( + "div", + {}, + [ + $el("button", { + textContent: "Export", + style: { + fontSize: "12px", + fontWeight: "normal", + }, + onclick: (e) => { + const json = JSON.stringify({templates: [t]}, null, 2); // convert the data to a JSON string + const blob = new Blob([json], {type: "application/json"}); + const url = URL.createObjectURL(blob); + const a = $el("a", { + href: url, + download: (nameInput.value || t.name) + ".json", + style: {display: "none"}, + parent: document.body, + }); + a.click(); + setTimeout(function () { + a.remove(); + window.URL.revokeObjectURL(url); + }, 0); + }, + }), + $el("button", { + textContent: "Delete", + style: { + fontSize: "12px", + color: "red", + fontWeight: "normal", + }, + onclick: (e) => { + const item = e.target.parentNode.parentNode; + item.parentNode.removeChild(item); + this.templates.splice(item.dataset.id*1, 1); + this.store(); + // update the rows index, setTimeout ensures that the list is updated + var that = this; + setTimeout(function (){ + that.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => { + el.dataset.id = i; + }); + }, 0); + }, + }), + ] + ), ] - ), + ) ]; }) ) diff --git a/web/scripts/app.js b/web/scripts/app.js index 55d4b120ebf..ebbb64be6ce 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -3,7 +3,7 @@ import { ComfyWidgets, getWidgetType } from "./widgets.js"; import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; -import { getPngMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; +import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; /** * @typedef {import("types/comfy").ComfyExtension} ComfyExtension @@ -1601,17 +1601,25 @@ export class ComfyApp { * @returns The workflow and node links */ async graphToPrompt() { - const workflow = this.graph.serialize(); - const output = {}; - // Process nodes in order of execution for (const outerNode of this.graph.computeExecutionOrder(false)) { const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode]; for (const node of innerNodes) { if (node.isVirtualNode) { // Don't serialize frontend only nodes but let them make changes if (node.applyToGraph) { - node.applyToGraph(workflow); + node.applyToGraph(); } + } + } + } + + const workflow = this.graph.serialize(); + const output = {}; + // Process nodes in order of execution + for (const outerNode of this.graph.computeExecutionOrder(false)) { + const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode]; + for (const node of innerNodes) { + if (node.isVirtualNode) { continue; } @@ -1809,6 +1817,15 @@ export class ComfyApp { importA1111(this.graph, pngInfo.parameters); } } + } else if (file.type === "image/webp") { + const pngInfo = await getWebpMetadata(file); + if (pngInfo) { + if (pngInfo.workflow) { + this.loadGraphData(JSON.parse(pngInfo.workflow)); + } else if (pngInfo.Workflow) { + this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node. + } + } } else if (file.type === "application/json" || file.name?.endsWith(".json")) { const reader = new FileReader(); reader.onload = async () => { diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index c5293dfa332..491caed79f5 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -47,6 +47,103 @@ export function getPngMetadata(file) { }); } +function parseExifData(exifData) { + // Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian) + const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949; + console.log(exifData); + + // Function to read 16-bit and 32-bit integers from binary data + function readInt(offset, isLittleEndian, length) { + let arr = exifData.slice(offset, offset + length) + if (length === 2) { + return new DataView(arr.buffer, arr.byteOffset, arr.byteLength).getUint16(0, isLittleEndian); + } else if (length === 4) { + return new DataView(arr.buffer, arr.byteOffset, arr.byteLength).getUint32(0, isLittleEndian); + } + } + + // Read the offset to the first IFD (Image File Directory) + const ifdOffset = readInt(4, isLittleEndian, 4); + + function parseIFD(offset) { + const numEntries = readInt(offset, isLittleEndian, 2); + const result = {}; + + for (let i = 0; i < numEntries; i++) { + const entryOffset = offset + 2 + i * 12; + const tag = readInt(entryOffset, isLittleEndian, 2); + const type = readInt(entryOffset + 2, isLittleEndian, 2); + const numValues = readInt(entryOffset + 4, isLittleEndian, 4); + const valueOffset = readInt(entryOffset + 8, isLittleEndian, 4); + + // Read the value(s) based on the data type + let value; + if (type === 2) { + // ASCII string + value = String.fromCharCode(...exifData.slice(valueOffset, valueOffset + numValues - 1)); + } + + result[tag] = value; + } + + return result; + } + + // Parse the first IFD + const ifdData = parseIFD(ifdOffset); + return ifdData; +} + +function splitValues(input) { + var output = {}; + for (var key in input) { + var value = input[key]; + var splitValues = value.split(':', 2); + output[splitValues[0]] = splitValues[1]; + } + return output; +} + +export function getWebpMetadata(file) { + return new Promise((r) => { + const reader = new FileReader(); + reader.onload = (event) => { + const webp = new Uint8Array(event.target.result); + const dataView = new DataView(webp.buffer); + + // Check that the WEBP signature is present + if (dataView.getUint32(0) !== 0x52494646 || dataView.getUint32(8) !== 0x57454250) { + console.error("Not a valid WEBP file"); + r(); + return; + } + + // Start searching for chunks after the WEBP signature + let offset = 12; + let txt_chunks = {}; + // Loop through the chunks in the WEBP file + while (offset < webp.length) { + const chunk_length = dataView.getUint32(offset + 4, true); + const chunk_type = String.fromCharCode(...webp.slice(offset, offset + 4)); + if (chunk_type === "EXIF") { + let data = parseExifData(webp.slice(offset + 8, offset + 8 + chunk_length)); + for (var key in data) { + var value = data[key]; + let index = value.indexOf(':'); + txt_chunks[value.slice(0, index)] = value.slice(index + 1); + } + } + + offset += 8 + chunk_length; + } + + r(txt_chunks); + }; + + reader.readAsArrayBuffer(file); + }); +} + export function getLatentMetadata(file) { return new Promise((r) => { const reader = new FileReader(); diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 5b719d1a986..30dcc9e77b1 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -719,20 +719,22 @@ export class ComfyUI { filename += ".json"; } } - const json = JSON.stringify(app.graph.serialize(), null, 2); // convert the data to a JSON string - const blob = new Blob([json], {type: "application/json"}); - const url = URL.createObjectURL(blob); - const a = $el("a", { - href: url, - download: filename, - style: {display: "none"}, - parent: document.body, + app.graphToPrompt().then(p=>{ + const json = JSON.stringify(p.workflow, null, 2); // convert the data to a JSON string + const blob = new Blob([json], {type: "application/json"}); + const url = URL.createObjectURL(blob); + const a = $el("a", { + href: url, + download: filename, + style: {display: "none"}, + parent: document.body, + }); + a.click(); + setTimeout(function () { + a.remove(); + window.URL.revokeObjectURL(url); + }, 0); }); - a.click(); - setTimeout(function () { - a.remove(); - window.URL.revokeObjectURL(url); - }, 0); }, }), $el("button", {