-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Controlnet union model basic implementation.
This is only the model code itself, it currently defaults to an empty embedding [0] * 6 which seems to work better than treating it like a regular controlnet. TODO: Add nodes to select the image type.
- Loading branch information
1 parent
bb663bc
commit faa5743
Showing
2 changed files
with
118 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,47 @@ | |
from ..ldm.modules.attention import SpatialTransformer | ||
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample | ||
from ..ldm.util import exists | ||
from ..ldm.cascade.common import OptimizedAttention | ||
from collections import OrderedDict | ||
import comfy.ops | ||
from comfy.ldm.modules.attention import optimized_attention | ||
|
||
class OptimizedAttention(nn.Module): | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong. |
||
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): | ||
super().__init__() | ||
self.heads = nhead | ||
self.c = c | ||
|
||
self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device) | ||
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device) | ||
|
||
def forward(self, x): | ||
x = self.in_proj(x) | ||
q, k, v = x.split(self.c, dim=2) | ||
out = optimized_attention(q, k, v, self.heads) | ||
return self.out_proj(out) | ||
|
||
class QuickGELU(nn.Module): | ||
def forward(self, x: torch.Tensor): | ||
return x * torch.sigmoid(1.702 * x) | ||
|
||
class ResBlockUnionControlnet(nn.Module): | ||
def __init__(self, dim, nhead, dtype=None, device=None, operations=None): | ||
super().__init__() | ||
self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations) | ||
self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device) | ||
self.mlp = nn.Sequential( | ||
OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()), | ||
("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))])) | ||
self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device) | ||
|
||
def attention(self, x: torch.Tensor): | ||
return self.attn(x) | ||
|
||
def forward(self, x: torch.Tensor): | ||
x = x + self.attention(self.ln_1(x)) | ||
x = x + self.mlp(self.ln_2(x)) | ||
return x | ||
|
||
class ControlledUnetModel(UNetModel): | ||
#implemented in the ldm unet | ||
|
@@ -53,6 +93,7 @@ def __init__( | |
transformer_depth_middle=None, | ||
transformer_depth_output=None, | ||
attn_precision=None, | ||
union_controlnet=False, | ||
device=None, | ||
operations=comfy.ops.disable_weight_init, | ||
**kwargs, | ||
|
@@ -280,14 +321,84 @@ def __init__( | |
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device) | ||
self._feature_size += ch | ||
|
||
if union_controlnet: | ||
self.num_control_type = 6 | ||
num_trans_channel = 320 | ||
num_trans_head = 8 | ||
num_trans_layer = 1 | ||
num_proj_channel = 320 | ||
# task_scale_factor = num_trans_channel ** 0.5 | ||
self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device)) | ||
|
||
self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)]) | ||
self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device) | ||
#----------------------------------------------------------------------------------------------------- | ||
|
||
control_add_embed_dim = 256 | ||
class ControlAddEmbedding(nn.Module): | ||
def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None): | ||
super().__init__() | ||
self.num_control_type = num_control_type | ||
self.in_dim = in_dim | ||
self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device) | ||
self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device) | ||
def forward(self, control_type, dtype, device): | ||
c_type = torch.zeros((self.num_control_type,), device=device) | ||
c_type[control_type] = 1.0 | ||
c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim)) | ||
return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type))) | ||
|
||
self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations) | ||
else: | ||
self.task_embedding = None | ||
self.control_add_embedding = None | ||
|
||
def union_controlnet_merge(self, hint, control_type, emb, context): | ||
This comment has been minimized.
Sorry, something went wrong.
huchenlei
Collaborator
|
||
# Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main | ||
inputs = [] | ||
condition_list = [] | ||
|
||
for idx in range(min(1, len(control_type))): | ||
controlnet_cond = self.input_hint_block(hint[idx], emb, context) | ||
feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) | ||
if idx < len(control_type): | ||
feat_seq += self.task_embedding[control_type[idx]] | ||
|
||
inputs.append(feat_seq.unsqueeze(1)) | ||
condition_list.append(controlnet_cond) | ||
|
||
x = torch.cat(inputs, dim=1) | ||
x = self.transformer_layes(x) | ||
controlnet_cond_fuser = None | ||
for idx in range(len(control_type)): | ||
alpha = self.spatial_ch_projs(x[:, idx]) | ||
alpha = alpha.unsqueeze(-1).unsqueeze(-1) | ||
o = condition_list[idx] + alpha | ||
if controlnet_cond_fuser is None: | ||
controlnet_cond_fuser = o | ||
else: | ||
controlnet_cond_fuser += o | ||
return controlnet_cond_fuser | ||
|
||
def make_zero_conv(self, channels, operations=None, dtype=None, device=None): | ||
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device)) | ||
|
||
def forward(self, x, hint, timesteps, context, y=None, **kwargs): | ||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) | ||
emb = self.time_embed(t_emb) | ||
|
||
guided_hint = self.input_hint_block(hint, emb, context) | ||
guided_hint = None | ||
if self.control_add_embedding is not None: | ||
control_type = kwargs.get("control_type", []) | ||
|
||
emb += self.control_add_embedding(control_type, emb.dtype, emb.device) | ||
if len(control_type) > 0: | ||
if len(hint.shape) < 5: | ||
hint = hint.unsqueeze(dim=0) | ||
guided_hint = self.union_controlnet_merge(hint, control_type, emb, context) | ||
|
||
if guided_hint is None: | ||
guided_hint = self.input_hint_block(hint, emb, context) | ||
|
||
out_output = [] | ||
out_middle = [] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Redefinition of
OptimizedAttention
on line 16. Is that intentional?