Skip to content

Commit

Permalink
Controlnet union model basic implementation.
Browse files Browse the repository at this point in the history
This is only the model code itself; it currently defaults to an empty
embedding ([0] * 6), which seems to work better than treating it like a
regular controlnet.

TODO: Add nodes to select the image type.
  • Loading branch information
comfyanonymous committed Jul 9, 2024
1 parent bb663bc commit faa5743
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 1 deletion.
113 changes: 112 additions & 1 deletion comfy/cldm/cldm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,47 @@
from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists
from ..ldm.cascade.common import OptimizedAttention
from collections import OrderedDict
import comfy.ops
from comfy.ldm.modules.attention import optimized_attention

class OptimizedAttention(nn.Module):
    """Self-attention layer with a single fused q/k/v input projection.

    NOTE: this local definition shadows the ``OptimizedAttention`` name that
    is also imported from ``..ldm.cascade.common`` at the top of the file.

    Args:
        c: Width of the input features; also the width of each of q, k, v.
        nhead: Number of heads handed to ``optimized_attention``.
        dropout: Accepted for signature compatibility; not used here.
        dtype: Forwarded to the ``operations.Linear`` constructors.
        device: Forwarded to the ``operations.Linear`` constructors.
        operations: Namespace providing ``Linear``; must not be ``None``.
    """

    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead
        self.c = c

        linear_kwargs = {"bias": True, "dtype": dtype, "device": device}
        # One projection produces q, k and v side by side (3 * c features).
        self.in_proj = operations.Linear(c, 3 * c, **linear_kwargs)
        self.out_proj = operations.Linear(c, c, **linear_kwargs)

    def forward(self, x):
        """Project ``x`` to q/k/v, attend, and apply the output projection.

        The fused projection is split into chunks of ``self.c`` along dim 2
        (assumes ``x`` is laid out as (batch, sequence, c) -- TODO confirm).
        """
        fused_qkv = self.in_proj(x)
        query, key, value = fused_qkv.split(self.c, dim=2)
        attended = optimized_attention(query, key, value, self.heads)
        return self.out_proj(attended)

class QuickGELU(nn.Module):
    """Activation computing ``x * sigmoid(1.702 * x)`` element-wise."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate

class ResBlockUnionControlnet(nn.Module):
    """Residual block: attention over ``ln_1(x)``, then an MLP over ``ln_2(x)``.

    Both sub-layers are added back onto their input. The MLP expands to
    ``4 * dim`` features, applies ``QuickGELU`` and projects back to ``dim``.

    Args:
        dim: Feature width of the block.
        nhead: Number of attention heads.
        dtype: Forwarded to every layer constructor.
        device: Forwarded to every layer constructor.
        operations: Namespace providing ``Linear`` and ``LayerNorm``.
    """

    def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        hidden_dim = dim * 4

        self.attn = OptimizedAttention(dim, nhead, operations=operations, **factory)
        self.ln_1 = operations.LayerNorm(dim, **factory)

        # Sub-module names are part of the parameter names, so keep them as-is.
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = operations.Linear(dim, hidden_dim, **factory)
        mlp_layers["gelu"] = QuickGELU()
        mlp_layers["c_proj"] = operations.Linear(hidden_dim, dim, **factory)
        self.mlp = nn.Sequential(mlp_layers)

        self.ln_2 = operations.LayerNorm(dim, **factory)

    def attention(self, x: torch.Tensor):
        """Run the attention sub-layer on ``x``."""
        return self.attn(x)

    def forward(self, x: torch.Tensor):
        """Apply the attention and MLP sub-layers, each with a residual add."""
        attn_out = self.attention(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out

class ControlledUnetModel(UNetModel):
#implemented in the ldm unet
Expand Down Expand Up @@ -53,6 +93,7 @@ def __init__(
transformer_depth_middle=None,
transformer_depth_output=None,
attn_precision=None,
union_controlnet=False,
device=None,
operations=comfy.ops.disable_weight_init,
**kwargs,
Expand Down Expand Up @@ -280,14 +321,84 @@ def __init__(
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
self._feature_size += ch

if union_controlnet:
self.num_control_type = 6
num_trans_channel = 320
num_trans_head = 8
num_trans_layer = 1
num_proj_channel = 320
# task_scale_factor = num_trans_channel ** 0.5
self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))

self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
#-----------------------------------------------------------------------------------------------------

control_add_embed_dim = 256
class ControlAddEmbedding(nn.Module):
    """Embed the set of active control types into a single vector.

    Args:
        in_dim: Embedding width used per control-type slot.
        out_dim: Width of the produced embedding.
        num_control_type: Number of control-type slots.
        dtype: Forwarded to the ``operations.Linear`` constructors.
        device: Forwarded to the ``operations.Linear`` constructors.
        operations: Namespace providing ``Linear``; must not be ``None``.
    """

    def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
        super().__init__()
        self.num_control_type = num_control_type
        self.in_dim = in_dim
        factory = {"dtype": dtype, "device": device}
        self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, **factory)
        self.linear_2 = operations.Linear(out_dim, out_dim, **factory)

    def forward(self, control_type, dtype, device):
        """Return the embedding for the slots selected by ``control_type``.

        A zero vector with one entry per slot is built on ``device``, the
        entries indexed by ``control_type`` are set to 1.0, each entry is
        expanded with ``timestep_embedding`` and the flattened result is fed
        through ``linear_1`` -> SiLU -> ``linear_2``.
        """
        slots = torch.zeros((self.num_control_type,), device=device)
        slots[control_type] = 1.0

        expanded = timestep_embedding(slots.flatten(), self.in_dim, repeat_only=False)
        # Concatenate the per-slot embeddings into one row.
        flat = expanded.to(dtype).reshape((-1, self.num_control_type * self.in_dim))

        hidden = torch.nn.functional.silu(self.linear_1(flat))
        return self.linear_2(hidden)

self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
else:
self.task_embedding = None
self.control_add_embedding = None

def union_controlnet_merge(self, hint, control_type, emb, context):
    """Fuse the encoded hint(s) with their control-type task embedding.

    Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main

    Each processed hint is encoded with ``self.input_hint_block``, averaged
    over dims (2, 3), offset by the task embedding of its control type and
    run through ``self.transformer_layes``. The resulting per-hint token is
    projected with ``self.spatial_ch_projs`` and added back onto the encoded
    hint; the per-hint results are summed.

    Only the first hint is currently encoded (``min(1, ...)``), so the
    fusion loop walks the hints that were actually encoded instead of every
    entry of ``control_type``. Walking ``control_type`` raised
    "index 1 is out of bounds for dimension 1 with size 1" as soon as more
    than one control type was supplied (e.g. ``[1, 0, 0, 0, 0, 0]``).

    Args:
        hint: Indexable collection of hints; ``hint[idx]`` is passed to
            ``self.input_hint_block``.
        control_type: Sequence of indices into ``self.task_embedding``.
        emb: Forwarded unchanged to ``self.input_hint_block``.
        context: Forwarded unchanged to ``self.input_hint_block``.

    Returns:
        The fused hint tensor, or ``None`` when ``control_type`` is empty so
        the caller can fall back to the plain hint block.
    """
    if len(control_type) == 0:
        # Nothing to merge; torch.cat on an empty list would raise.
        return None

    tokens = []
    encoded_hints = []
    for idx in range(min(1, len(control_type))):
        encoded = self.input_hint_block(hint[idx], emb, context)
        token = torch.mean(encoded, dim=(2, 3))
        token += self.task_embedding[control_type[idx]]

        tokens.append(token.unsqueeze(1))
        encoded_hints.append(encoded)

    x = torch.cat(tokens, dim=1)
    # "transformer_layes" (sic) is the attribute name assigned in __init__.
    x = self.transformer_layes(x)

    fused = None
    # Iterate over what was encoded above, not over len(control_type).
    for idx, encoded in enumerate(encoded_hints):
        alpha = self.spatial_ch_projs(x[:, idx])
        alpha = alpha.unsqueeze(-1).unsqueeze(-1)
        contribution = encoded + alpha
        if fused is None:
            fused = contribution
        else:
            fused += contribution
    return fused

def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
    """Build a kernel-size-1, padding-0 ``conv_nd`` wrapped in ``TimestepEmbedSequential``.

    The convolution maps ``channels`` to ``channels`` in ``self.dims``
    dimensions; ``dtype`` and ``device`` are forwarded to its constructor.
    """
    conv = operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device)
    return TimestepEmbedSequential(conv)

def forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)

guided_hint = self.input_hint_block(hint, emb, context)
guided_hint = None
if self.control_add_embedding is not None:
control_type = kwargs.get("control_type", [])

emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
if len(control_type) > 0:
if len(hint.shape) < 5:
hint = hint.unsqueeze(dim=0)
guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)

if guided_hint is None:
guided_hint = self.input_hint_block(hint, emb, context)

out_output = []
out_middle = []
Expand Down
6 changes: 6 additions & 0 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,12 @@ def load_controlnet(ckpt_path, model=None):
if k in controlnet_data:
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)

if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
controlnet_config["union_controlnet"] = True
for k in list(controlnet_data.keys()):
new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
new_sd[new_k] = controlnet_data.pop(k)

leftover_keys = controlnet_data.keys()
if len(leftover_keys) > 0:
logging.warning("leftover keys: {}".format(leftover_keys))
Expand Down

0 comments on commit faa5743

Please sign in to comment.