diff --git a/comfy/controlnet.py b/comfy/controlnet.py index d2d2cefaae8..c0f9b651121 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -430,9 +430,9 @@ def load_controlnet_hunyuandit(controlnet_data): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT) return control -def load_controlnet_flux_xlabs(sd): +def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False): model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd) - control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = controlnet_load_state_dict(control_model, sd) extra_conds = ['y', 'guidance'] control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) @@ -457,6 +457,10 @@ def load_controlnet_flux_instantx(sd): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control +def convert_mistoline(sd): + return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) + + def load_controlnet(ckpt_path, model=None): controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT @@ -518,13 +522,15 @@ def load_controlnet(ckpt_path, model=None): if len(leftover_keys) > 0: logging.warning("leftover keys: {}".format(leftover_keys)) controlnet_data = new_sd - elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format + elif "controlnet_blocks.0.weight" in controlnet_data: if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data: - return load_controlnet_flux_xlabs(controlnet_data) + return load_controlnet_flux_xlabs_mistoline(controlnet_data) elif "pos_embed_input.proj.weight" in controlnet_data: - return load_controlnet_mmdit(controlnet_data) + return load_controlnet_mmdit(controlnet_data) #SD3 diffusers controlnet elif "controlnet_x_embedder.weight" in controlnet_data: return load_controlnet_flux_instantx(controlnet_data) + elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux + return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True) pth_key = 'control_model.zero_convs.0.0.weight' pth = False diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py index 2598e71729a..d8b77612969 100644 --- a/comfy/ldm/flux/controlnet.py +++ b/comfy/ldm/flux/controlnet.py @@ -1,4 +1,5 @@ #Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py +#modified to support different types of flux controlnets import torch import math @@ -12,22 +13,65 @@ from .model import Flux import comfy.ldm.common_dit +class MistolineCondDownsamplBlock(nn.Module): + def __init__(self, dtype=None, device=None, operations=None): + super().__init__() + self.encoder = nn.Sequential( + operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device) + ) + + def forward(self, x): + return self.encoder(x) + +class MistolineControlnetBlock(nn.Module): + def __init__(self, hidden_size, dtype=None, device=None, operations=None): + super().__init__() + self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device) + self.act = nn.SiLU() + + def forward(self, x): + return self.act(self.linear(x)) + class ControlNetFlux(Flux): - def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs): + def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image_model=None, dtype=None, device=None, operations=None, **kwargs): super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) self.main_model_double = 19 self.main_model_single = 38 + + self.mistoline = mistoline # add ControlNet blocks + if self.mistoline: + control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) + self.controlnet_blocks = nn.ModuleList([]) for _ in range(self.params.depth): - controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) - self.controlnet_blocks.append(controlnet_block) + self.controlnet_blocks.append(control_block()) self.controlnet_single_blocks = nn.ModuleList([]) for _ in range(self.params.depth_single_blocks): - self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)) + self.controlnet_single_blocks.append(control_block()) self.num_union_modes = num_union_modes self.controlnet_mode_embedder = None @@ -38,23 +82,26 @@ def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtyp self.latent_input = latent_input self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) if not self.latent_input: - self.input_hint_block = nn.Sequential( - operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), - nn.SiLU(), - operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device) - ) + if self.mistoline: + self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations) + else: + self.input_hint_block = nn.Sequential( + operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device) + ) def forward_orig( self, @@ -73,9 +120,6 @@ def forward_orig( # running on sequences img img = self.img_in(img) - if not self.latent_input: - controlnet_cond = self.input_hint_block(controlnet_cond) - controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) controlnet_cond = self.pos_embed_input(controlnet_cond) img = img + controlnet_cond @@ -131,9 +175,14 @@ def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs): patch_size = 2 if self.latent_input: hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size)) - hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + elif self.mistoline: + hint = hint * 2.0 - 1.0 + hint = self.input_cond_block(hint) else: hint = hint * 2.0 - 1.0 + hint = self.input_hint_block(hint) + + hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) bs, c, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))