diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index f8a16159452..3feb9bab0cd 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -13,7 +13,47 @@ from ..ldm.modules.attention import SpatialTransformer from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample from ..ldm.util import exists +from ..ldm.cascade.common import OptimizedAttention +from collections import OrderedDict import comfy.ops +from comfy.ldm.modules.attention import optimized_attention + +class OptimizedAttention(nn.Module): + def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.heads = nhead + self.c = c + + self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device) + self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + + def forward(self, x): + x = self.in_proj(x) + q, k, v = x.split(self.c, dim=2) + out = optimized_attention(q, k, v, self.heads) + return self.out_proj(out) + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + +class ResBlockUnionControlnet(nn.Module): + def __init__(self, dim, nhead, dtype=None, device=None, operations=None): + super().__init__() + self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations) + self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device) + self.mlp = nn.Sequential( + OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()), + ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))])) + self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device) + + def attention(self, x: torch.Tensor): + return self.attn(x) + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x class ControlledUnetModel(UNetModel): #implemented in the ldm unet @@ -53,6 +93,7 @@ def __init__( transformer_depth_middle=None, transformer_depth_output=None, attn_precision=None, + union_controlnet=False, device=None, operations=comfy.ops.disable_weight_init, **kwargs, @@ -280,6 +321,65 @@ def __init__( self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device) self._feature_size += ch + if union_controlnet: + self.num_control_type = 6 + num_trans_channel = 320 + num_trans_head = 8 + num_trans_layer = 1 + num_proj_channel = 320 + # task_scale_factor = num_trans_channel ** 0.5 + self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device)) + + self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)]) + self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device) + #----------------------------------------------------------------------------------------------------- + + control_add_embed_dim = 256 + class ControlAddEmbedding(nn.Module): + def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None): + super().__init__() + self.num_control_type = num_control_type + self.in_dim = in_dim + self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device) + self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device) + def forward(self, control_type, dtype, device): + c_type = torch.zeros((self.num_control_type,), device=device) + c_type[control_type] = 1.0 + c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim)) + return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type))) + + self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations) + else: + self.task_embedding = None + self.control_add_embedding = None + + def union_controlnet_merge(self, hint, control_type, emb, context): + # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main + inputs = [] + condition_list = [] + + for idx in range(min(1, len(control_type))): + controlnet_cond = self.input_hint_block(hint[idx], emb, context) + feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) + if idx < len(control_type): + feat_seq += self.task_embedding[control_type[idx]] + + inputs.append(feat_seq.unsqueeze(1)) + condition_list.append(controlnet_cond) + + x = torch.cat(inputs, dim=1) + x = self.transformer_layes(x) + controlnet_cond_fuser = None + for idx in range(len(control_type)): + alpha = self.spatial_ch_projs(x[:, idx]) + alpha = alpha.unsqueeze(-1).unsqueeze(-1) + o = condition_list[idx] + alpha + if controlnet_cond_fuser is None: + controlnet_cond_fuser = o + else: + controlnet_cond_fuser += o + return controlnet_cond_fuser + def make_zero_conv(self, channels, operations=None, dtype=None, device=None): return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device)) @@ -287,7 +387,18 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs): t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) emb = self.time_embed(t_emb) - guided_hint = self.input_hint_block(hint, emb, context) + guided_hint = None + if self.control_add_embedding is not None: + control_type = kwargs.get("control_type", []) + + emb += self.control_add_embedding(control_type, emb.dtype, emb.device) + if len(control_type) > 0: + if len(hint.shape) < 5: + hint = hint.unsqueeze(dim=0) + guided_hint = self.union_controlnet_merge(hint, control_type, emb, context) + + if guided_hint is None: + guided_hint = self.input_hint_block(hint, emb, context) out_output = [] out_middle = [] diff --git a/comfy/controlnet.py b/comfy/controlnet.py index d0039513968..84286f1fac7 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -413,6 +413,12 @@ def load_controlnet(ckpt_path, model=None): if k in controlnet_data: new_sd[diffusers_keys[k]] = controlnet_data.pop(k) + if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet + controlnet_config["union_controlnet"] = True + for k in list(controlnet_data.keys()): + new_k = k.replace('.attn.in_proj_', '.attn.in_proj.') + new_sd[new_k] = controlnet_data.pop(k) + leftover_keys = controlnet_data.keys() if len(leftover_keys) > 0: logging.warning("leftover keys: {}".format(leftover_keys))