Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into group-nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
pythongosssss committed Nov 3, 2023
2 parents a92409c + ae2acfc commit 4c928d2
Show file tree
Hide file tree
Showing 37 changed files with 979 additions and 2,716 deletions.
47 changes: 26 additions & 21 deletions comfy/cldm/cldm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(
model_channels,
hint_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
Expand All @@ -52,6 +51,7 @@ def __init__(
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
device=None,
operations=comfy.ops,
):
Expand Down Expand Up @@ -79,29 +79,24 @@ def __init__(
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]

if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks

if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")

self.attention_resolutions = attention_resolutions
transformer_depth = transformer_depth[:]

self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
Expand Down Expand Up @@ -180,11 +175,14 @@ def __init__(
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
operations=operations
dtype=self.dtype,
device=device,
operations=operations,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
num_transformers = transformer_depth.pop(0)
if num_transformers > 0:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
Expand All @@ -201,9 +199,9 @@ def __init__(
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, operations=operations
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
Expand All @@ -223,11 +221,13 @@ def __init__(
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
dtype=self.dtype,
device=device,
operations=operations
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
)
)
)
Expand All @@ -245,20 +245,23 @@ def __init__(
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
mid_block = [
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
),
SpatialTransformer( # always uses a self-attn
)]
if transformer_depth_middle >= 0:
mid_block += [SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, operations=operations
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
),
ResBlock(
ch,
Expand All @@ -267,9 +270,11 @@ def __init__(
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
),
)
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
self._feature_size += ch

Expand Down
2 changes: 2 additions & 0 deletions comfy/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def __call__(self, parser, namespace, values, option_string=None):
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")

parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
Expand Down
31 changes: 15 additions & 16 deletions comfy/clip_vision.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils
from .utils import load_torch_file, transformers_convert
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
from .utils import load_torch_file, transformers_convert, common_upscale
import os
import torch
import contextlib

import comfy.ops
import comfy.model_patcher
import comfy.model_management
import comfy.utils

def clip_preprocess(image, size=224):
    """Resize, center-crop, quantize and normalize an image batch.

    The input is indexed as channels-last: axes 1 and 2 are used as the
    spatial dimensions and the last axis is moved to position 1 before
    resizing. The shorter spatial side is scaled to ``size`` with antialiased
    bicubic interpolation, a ``size`` x ``size`` window is cut from the
    center, values are snapped to the 0..255 grid (and scaled back to 0..1),
    and finally the per-channel mean/std constants below are applied.

    Args:
        image: tensor indexed as (batch, height, width, 3).
        size: side length of the square output crop.

    Returns:
        Tensor indexed as (batch, 3, size, size) on the device/dtype of
        ``image``.
    """
    src_h, src_w = image.shape[1], image.shape[2]
    scale = (size / min(src_h, src_w))
    target_hw = (round(scale * src_h), round(scale * src_w))

    # Channels-last -> channels-first, then resize so the short side equals `size`.
    resized = torch.nn.functional.interpolate(
        image.movedim(-1, 1), size=target_hw, mode="bicubic", antialias=True
    )

    # Center crop to a square of side `size`.
    top = (resized.shape[2] - size)//2
    left = (resized.shape[3] - size)//2
    cropped = resized[:, :, top:top + size, left:left + size]

    # Clamp and round to the 8-bit grid, then go back to the 0..1 range.
    quantized = torch.clip((255. * cropped), 0, 255).round() / 255.0

    mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
    std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
    return (quantized - mean.view([3,1,1])) / std.view([3,1,1])

class ClipVisionModel():
def __init__(self, json_config):
Expand All @@ -23,25 +35,12 @@ def __init__(self, json_config):
self.model.to(self.dtype)

self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.processor = CLIPImageProcessor(crop_size=224,
do_center_crop=True,
do_convert_rgb=True,
do_normalize=True,
do_resize=True,
image_mean=[ 0.48145466,0.4578275,0.40821073],
image_std=[0.26862954,0.26130258,0.27577711],
resample=3, #bicubic
size=224)

def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)

def encode_image(self, image):
img = torch.clip((255. * image), 0, 255).round().int()
img = list(map(lambda a: a, img))
inputs = self.processor(images=img, return_tensors="pt")
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = inputs['pixel_values'].to(self.load_device)
pixel_values = clip_preprocess(image.to(self.load_device))

if self.dtype != torch.float32:
precision_scope = torch.autocast
Expand Down
64 changes: 64 additions & 0 deletions comfy/conds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import enum
import torch
import math
import comfy.utils


def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
    """Return the least common multiple of the integers ``a`` and ``b``.

    The result is always non-negative. If either argument is zero the result
    is 0, which matches the behaviour of ``math.lcm``.
    """
    if a == 0 or b == 0:
        # math.gcd(0, 0) == 0, so the division below would raise
        # ZeroDivisionError for lcm(0, 0); math.lcm defines this case as 0.
        return 0
    return abs(a * b) // math.gcd(a, b)

class CONDRegular:
    """Wrapper around a single conditioning tensor stored in ``self.cond``."""

    def __init__(self, cond):
        self.cond = cond

    def _copy_with(self, cond):
        """Build a new instance of the same (sub)class holding ``cond``."""
        return type(self)(cond)

    def process_cond(self, batch_size, device, **kwargs):
        """Return a copy whose tensor is repeated to ``batch_size`` and moved to ``device``.

        Extra keyword arguments are accepted and ignored.
        """
        batched = comfy.utils.repeat_to_batch_size(self.cond, batch_size)
        return self._copy_with(batched.to(device))

    def can_concat(self, other):
        """Two conds can be concatenated only when their shapes are identical."""
        return self.cond.shape == other.cond.shape

    def concat(self, others):
        """Concatenate this cond's tensor with those of ``others`` via ``torch.cat``."""
        return torch.cat([self.cond] + [item.cond for item in others])

class CONDNoiseShape(CONDRegular):
    """Cond whose tensor is cropped to an ``area`` window before batching."""

    def process_cond(self, batch_size, device, area, **kwargs):
        """Crop the last two axes to ``area``, repeat to ``batch_size``, move to ``device``.

        ``area`` is indexed as (extent_axis2, extent_axis3, start_axis2,
        start_axis3): the slice on axis 2 is ``area[2]:area[2] + area[0]`` and
        the slice on axis 3 is ``area[3]:area[3] + area[1]``.
        """
        axis2_window = slice(area[2], area[0] + area[2])
        axis3_window = slice(area[3], area[1] + area[3])
        cropped = self.cond[:, :, axis2_window, axis3_window]
        batched = comfy.utils.repeat_to_batch_size(cropped, batch_size)
        return self._copy_with(batched.to(device))


class CONDCrossAttn(CONDRegular):
    """Cond that may be concatenated with conds whose axis-1 length differs.

    Tensors with a shorter axis 1 are tiled along that axis up to the least
    common multiple of all lengths before concatenation.
    """

    def can_concat(self, other):
        """Decide whether ``other`` can be batched together with this cond."""
        shape_a = self.cond.shape
        shape_b = other.cond.shape
        if shape_a == shape_b:
            return True

        if shape_a[0] != shape_b[0] or shape_a[2] != shape_b[2]: #these 2 cases should not happen
            return False

        common_len = lcm(shape_a[1], shape_b[1])
        pad_factor = common_len // min(shape_a[1], shape_b[1])
        #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
        return pad_factor <= 4

    def concat(self, others):
        """Tile every tensor along axis 1 to a common length, then ``torch.cat`` them."""
        tensors = [self.cond] + [item.cond for item in others]

        target_len = self.cond.shape[1]
        for t in tensors[1:]:
            target_len = lcm(target_len, t.shape[1])

        #padding with repeat doesn't change result
        padded = [
            t.repeat(1, target_len // t.shape[1], 1) if t.shape[1] < target_len else t
            for t in tensors
        ]
        return torch.cat(padded)
16 changes: 14 additions & 2 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(self, control_model, global_average_pooling=False, device=None):
self.control_model = control_model
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
self.model_sampling_current = None

def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
Expand All @@ -156,10 +157,13 @@ def get_control(self, x_noisy, t, cond, batched_number):


context = cond['c_crossattn']
y = cond.get('c_adm', None)
y = cond.get('y', None)
if y is not None:
y = y.to(self.control_model.dtype)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)

def copy(self):
Expand All @@ -172,6 +176,14 @@ def get_models(self):
out.append(self.control_model_wrapped)
return out

def pre_run(self, model, percent_to_timestep_function):
    """Run the parent class's pre_run, then remember the model's sampling object.

    ``model.model_sampling`` is stored on ``self.model_sampling_current``;
    get_control reads it to call ``timestep(t)`` and ``calculate_input(t, x_noisy)``.
    """
    super().pre_run(model, percent_to_timestep_function)
    self.model_sampling_current = model.model_sampling

def cleanup(self):
    """Drop the sampling object stored by pre_run, then run the parent class's cleanup."""
    # Reset to the same None value that __init__ assigns.
    self.model_sampling_current = None
    super().cleanup()

class ControlLoraOps:
class Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
Expand Down
10 changes: 8 additions & 2 deletions comfy/extra_samplers/uni_pc.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,12 @@ def marginal_lambda(self, t):
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
return log_mean_coeff - log_std

def predict_eps_sigma(model, input, sigma_in, **kwargs):
    """Compute ``(scaled_input - model(scaled_input, sigma_in)) / sigma``.

    ``sigma_in`` is reshaped to keep its first axis and gain one trailing
    singleton axis per remaining axis of ``input`` so it broadcasts against
    it. ``input`` is multiplied by ``sqrt(sigma**2 + 1)`` before being handed
    to ``model``; the model receives the original, un-reshaped ``sigma_in``
    and any extra keyword arguments.
    """
    broadcast_shape = sigma_in.shape[:1] + (1,) * (input.ndim - 1)
    sigma = sigma_in.view(broadcast_shape)
    scaled_input = input * ((sigma ** 2 + 1.0) ** 0.5)
    model_out = model(scaled_input, sigma_in, **kwargs)
    return (scaled_input - model_out) / sigma


def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
timesteps = sigmas.clone()
if sigmas[-1] == 0:
Expand All @@ -874,14 +880,14 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
model_type = "noise"

model_fn = model_wrapper(
model.predict_eps_sigma,
lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
ns,
model_type=model_type,
guidance_type="uncond",
model_kwargs=extra_args,
)

order = min(3, len(timesteps) - 1)
order = min(3, len(timesteps) - 2)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
x /= ns.marginal_alpha(timesteps[-1])
Expand Down
Loading

0 comments on commit 4c928d2

Please sign in to comment.