diff --git a/adv_control/control.py b/adv_control/control.py
index fc5b8a2..27e3d84 100644
--- a/adv_control/control.py
+++ b/adv_control/control.py
@@ -519,6 +519,11 @@ def copy(self):
 
 def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
+    # from pathlib import Path
+    # log_name = ckpt_path.split('\\')[-1]
+    # with open(Path(__file__).parent.parent.parent / rf"keys_{log_name}.txt", "w") as afile:
+    #     for key, value in controlnet_data.items():
+    #         afile.write(f"{key}:\t{value.shape}\n")
     control = None
     # check if a non-vanilla ControlNet
     controlnet_type = ControlWeightType.DEFAULT
@@ -538,6 +543,10 @@ def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroup=None, mo
         # SVD-ControlNet check
         elif "temporal_res_block" in key:
             has_temporal_res_block_key = True
+        # ControlNet++ check
+        elif "task_embedding" in key:
+            raise Exception("ControlNet++ model detected; must be loaded using the Load ControlNet++ Model nodes.")
+
     if has_controlnet_key and has_motion_modules_key:
         controlnet_type = ControlWeightType.SPARSECTRL
     elif has_controlnet_key and has_temporal_res_block_key:
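For context, the routing in load_controlnet amounts to a key-name scan over the checkpoint's state dict. A minimal standalone sketch of that logic, assuming the same loading call the real loader uses; the function name and return strings here are illustrative, not part of the codebase:

import comfy.utils

def detect_controlnet_flavor(ckpt_path: str) -> str:
    sd = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    if any("task_embedding" in k for k in sd):
        # the real code raises here instead, since ControlNet++ checkpoints
        # must go through load_controlnetplusplus, never load_controlnet
        return "controlnet++"
    has_controlnet = any("controlnet" in k for k in sd)
    if has_controlnet and any("motion_modules" in k for k in sd):
        return "sparsectrl"
    if has_controlnet and any("temporal_res_block" in k for k in sd):
        return "svd_controlnet"
    return "default"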
Exception(f"Control type '{pp_input.control_type}' is already present; ControlNet++ does not allow more than 1 of each type.") + self.controls[pp_input.control_type] = pp_input + + def clone(self) -> 'PlusPlusInputGroup': + cloned = PlusPlusInputGroup() + for key, value in self.controls.items(): + cloned.controls[key] = value.clone() + return cloned + + +class PlusPlusImageWrapper(AbstractPreprocWrapper): + error_msg = error_msg = "Invalid use of ControlNet++ Image Wrapper. The output of ControlNet++ Image Wrapper is NOT a usual image, but an object holding the images and extra info - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input." + def __init__(self, condhint: PlusPlusInputGroup): + super().__init__(condhint) + # just an IDE type hint + self.condhint: PlusPlusInputGroup + + def movedim(self, source: int, destination: int): + condhint = self.condhint.clone() + for pp_input in condhint.controls.values(): + pp_input.image = pp_input.image.movedim(source, destination) + return PlusPlusImageWrapper(condhint) + +# parts taken from comfy/cldm/cldm.py +class OptimizedAttention(nn.Module): + def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.heads = nhead + self.c = c + + self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device) + self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + + def forward(self, x): + x = self.in_proj(x) + q, k, v = x.split(self.c, dim=2) + out = optimized_attention(q, k, v, self.heads) + return self.out_proj(out) + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + +class ResBlockUnionControlnet(nn.Module): + def __init__(self, dim, nhead, dtype=None, device=None, operations=None): + super().__init__() + self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations) + self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device) + self.mlp = nn.Sequential( + OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()), + ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))])) + self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device) + + def attention(self, x: torch.Tensor): + return self.attn(x) + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class ControlAddEmbeddingAdv(nn.Module): + def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations: comfy.ops.disable_weight_init=None): + super().__init__() + self.num_control_type = num_control_type + self.in_dim = in_dim + self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device) + self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device) + + def forward(self, control_type, dtype, device): + if control_type is None: + control_type = torch.zeros((self.num_control_type,), device=device) + c_type = timestep_embedding(control_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim)) + return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type))) + + +class ControlNetPlusPlus(ControlNetCLDM): + def __init__(self, *args,**kwargs): + super().__init__(*args, **kwargs) + + operations: comfy.ops.disable_weight_init = kwargs.get("operations", 
+
+# parts taken from comfy/cldm/cldm.py
+class OptimizedAttention(nn.Module):
+    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.heads = nhead
+        self.c = c
+
+        self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
+        self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
+
+    def forward(self, x):
+        x = self.in_proj(x)
+        q, k, v = x.split(self.c, dim=2)
+        out = optimized_attention(q, k, v, self.heads)
+        return self.out_proj(out)
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+class ResBlockUnionControlnet(nn.Module):
+    def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
+        self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
+        self.mlp = nn.Sequential(
+            OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
+                         ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
+        self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
+
+    def attention(self, x: torch.Tensor):
+        return self.attn(x)
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class ControlAddEmbeddingAdv(nn.Module):
+    def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations: comfy.ops.disable_weight_init=None):
+        super().__init__()
+        self.num_control_type = num_control_type
+        self.in_dim = in_dim
+        self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
+        self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
+
+    def forward(self, control_type, dtype, device):
+        if control_type is None:
+            control_type = torch.zeros((self.num_control_type,), device=device)
+        c_type = timestep_embedding(control_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
+        return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
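A quick shape check of this embedding, assuming SD1.5-scale sizes (model_channels=320, so time_embed_dim=1280) and the 8-type union model; the weights are left uninitialized here, so only the shapes are meaningful:

import torch
import comfy.ops

embed = ControlAddEmbeddingAdv(in_dim=256, out_dim=1280, num_control_type=8,
                               dtype=torch.float32, operations=comfy.ops.disable_weight_init)
# one strength per control type, per batch item: openpose=1.0, canny/lineart/mlsd=0.5
control_type = torch.tensor([[1.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0]])
out = embed(control_type, torch.float32, "cpu")
# flatten -> (8,), timestep_embedding -> (8, 256), reshape -> (1, 2048), MLP -> (1, 1280)
print(out.shape)  # torch.Size([1, 1280]); added onto emb in ControlNetPlusPlus.forward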
+
+
+class ControlNetPlusPlus(ControlNetCLDM):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        operations: comfy.ops.disable_weight_init = kwargs.get("operations", comfy.ops.disable_weight_init)
+        device = kwargs.get("device", None)
+
+        time_embed_dim = self.model_channels * 4
+        control_add_embed_dim = 256
+
+        self.control_add_embedding = ControlAddEmbeddingAdv(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
+
+    def union_controlnet_merge(self, hint: list[Tensor], control_type, emb, context):
+        # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
+        indexes = torch.nonzero(control_type[0])
+        inputs = []
+        condition_list = []
+
+        for idx in range(indexes.shape[0]):
+            controlnet_cond = self.input_hint_block(hint[indexes[idx][0]], emb, context)
+            feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
+            if idx < indexes.shape[0]:
+                feat_seq += self.task_embedding[indexes[idx][0]]
+
+            inputs.append(feat_seq.unsqueeze(1))
+            condition_list.append(controlnet_cond)
+
+        x = torch.cat(inputs, dim=1)
+        x = self.transformer_layes(x)
+
+        controlnet_cond_fuser = None
+        for idx in range(indexes.shape[0]):
+            alpha = self.spatial_ch_projs(x[:, idx])
+            alpha = alpha.unsqueeze(-1).unsqueeze(-1)
+            o = condition_list[idx] + alpha
+            if controlnet_cond_fuser is None:
+                controlnet_cond_fuser = o
+            else:
+                controlnet_cond_fuser += o
+        return controlnet_cond_fuser
+
+    def forward(self, x: Tensor, hint: list[Tensor], timesteps, context, y: Tensor=None, **kwargs):
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
+        emb = self.time_embed(t_emb)
+
+        guided_hint = None
+        if self.control_add_embedding is not None:
+            control_type = kwargs.get("control_type", None)
+
+            emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
+            if control_type is not None:
+                guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
+
+        if guided_hint is None:
+            guided_hint = self.input_hint_block(hint[0], emb, context)
+
+        out_output = []
+        out_middle = []
+
+        hs = []
+        if self.num_classes is not None:
+            assert y.shape[0] == x.shape[0]
+            emb = emb + self.label_emb(y)
+
+        h = x
+        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+            if guided_hint is not None:
+                h = module(h, emb, context)
+                h += guided_hint
+                guided_hint = None
+            else:
+                h = module(h, emb, context)
+            out_output.append(zero_conv(h, emb, context))
+
+        h = self.middle_block(h, emb, context)
+        out_middle.append(self.middle_block_out(h, emb, context))
+
+        return {"middle": out_middle, "output": out_output}
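Since union_controlnet_merge is the heart of the union model, a comment-only shape walk-through may help; sizes assumed: batch 2, 512x512 hints, model_channels=320, two active control types:

# controlnet_cond = input_hint_block(hint[i])   -> (2, 320, 64, 64)  8x-downsampled hint features
# feat_seq = mean(controlnet_cond, dim=(2, 3))  -> (2, 320)          globally pooled descriptor
# feat_seq += task_embedding[i]                 -> (2, 320)          plus a learned per-type bias
# x = cat of unsqueezed feat_seqs, dim=1        -> (2, 2, 320)       one token per active type
# x = transformer_layes(x)                      -> (2, 2, 320)       tokens attend across types
# alpha = spatial_ch_projs(x[:, idx])           -> (2, 320) -> (2, 320, 1, 1) per-channel offset
# fused = sum of condition_list[idx] + alpha    -> (2, 320, 64, 64)  the single guided_hint fed to the UNet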
+
+
+class ControlNetPlusPlusAdvanced(ControlNet, AdvancedControlBase):
+    def __init__(self, control_model: ControlNetPlusPlus, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
+        super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+        AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())
+        self.add_compatible_weight(ControlWeightType.CONTROLNETPLUSPLUS)
+        # for IDE type hint purposes
+        self.control_model: ControlNetPlusPlus
+        self.cond_hint_original: Union[PlusPlusImageWrapper, PlusPlusInputGroup]
+        self.cond_hint: list[Union[Tensor, None]]
+        self.cond_hint_shape: Tensor = None
+        self.cond_hint_types: Tensor = None
+        # in case it is using the single loader
+        self.single_control_type: str = None
+
+    def get_universal_weights(self) -> ControlWeights:
+        # TODO: match actual layer count of model
+        raw_weights = [(self.weights.base_multiplier ** float(12 - i)) for i in range(13)]
+        return self.weights.copy_with_new_weights(raw_weights)
+
+    def verify_control_type(self, model_name: str, pp_group: PlusPlusInputGroup=None):
+        if pp_group is not None:
+            for pp_input in pp_group.controls.values():
+                if PlusPlusType.to_idx(pp_input.control_type) >= self.control_model.num_control_type:
+                    raise Exception(f"ControlNet++ model '{model_name}' does not support control_type '{pp_input.control_type}'.")
+        if self.single_control_type is not None:
+            if PlusPlusType.to_idx(self.single_control_type) >= self.control_model.num_control_type:
+                raise Exception(f"ControlNet++ model '{model_name}' does not support control_type '{self.single_control_type}'.")
+
+    def set_cond_hint_inject(self, *args, **kwargs):
+        to_return = super().set_cond_hint_inject(*args, **kwargs)
+        # if not single_control_type, expect PlusPlusImageWrapper
+        if self.single_control_type is None:
+            # check that cond_hint is wrapped, and unwrap it
+            if type(self.cond_hint_original) != PlusPlusImageWrapper:
+                raise Exception("ControlNet++ (Multi) expects image input from the Load ControlNet++ Model node, NOT from anything else. Images are provided to that node via ControlNet++ Input nodes.")
+            self.cond_hint_original = self.cond_hint_original.condhint.clone()
+        # otherwise, expect single image input (AKA, usual controlnet input)
+        else:
+            # check that cond_hint is not a PlusPlusImageWrapper
+            if type(self.cond_hint_original) == PlusPlusImageWrapper:
+                raise Exception("ControlNet++ (Single) expects usual image input, NOT the image input from a Load ControlNet++ Model (Multi) node.")
+            pp_group = PlusPlusInputGroup()
+            pp_input = PlusPlusInput(self.cond_hint_original, self.single_control_type, 1.0)
+            pp_group.add(pp_input)
+            self.cond_hint_original = pp_group
+        return to_return
+
+    def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number):
+        control_prev = None
+        if self.previous_controlnet is not None:
+            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
+
+        if self.timestep_range is not None:
+            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
+                if control_prev is not None:
+                    return control_prev
+                else:
+                    return None
+
+        dtype = self.control_model.dtype
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype
+
+        output_dtype = x_noisy.dtype
+
+        # make all cond_hints appropriate dimensions
+        # TODO: change this to not require cond_hint upscaling every step when self.sub_idxs is present
+        if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint_shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint_shape[3]:
+            if self.cond_hint is not None:
+                del self.cond_hint
+            self.cond_hint = [None] * self.control_model.num_control_type
+            self.cond_hint_types = torch.tensor([0.0] * self.control_model.num_control_type)
+            self.cond_hint_shape = None
+            compression_ratio = self.compression_ratio
+            # unlike normal controlnet, need to handle each input image tensor (for each type)
+            for pp_type, pp_input in self.cond_hint_original.controls.items():
+                pp_idx = PlusPlusType.to_idx(pp_type)
+                # if negative, means no type should be selected (single only)
+                if pp_idx < 0:
+                    pp_idx = 0
+                else:
+                    self.cond_hint_types[pp_idx] = pp_input.strength
+                # if self.cond_hint_original lengths greater or equal to latent count, subdivide
+                if self.sub_idxs is not None:
+                    actual_cond_hint_orig = pp_input.image
+                    if pp_input.image.size(0) < self.full_latent_length:
+                        actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
+                    self.cond_hint[pp_idx] = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
+                else:
+                    self.cond_hint[pp_idx] = comfy.utils.common_upscale(pp_input.image, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
+                self.cond_hint[pp_idx] = self.cond_hint[pp_idx].to(device=self.device, dtype=dtype)
+                self.cond_hint_shape = self.cond_hint[pp_idx].shape
+            # prepare cond_hint_controls to match batchsize
+            if self.cond_hint_types.count_nonzero() == 0:
+                self.cond_hint_types = None
+            else:
+                self.cond_hint_types = self.cond_hint_types.unsqueeze(0).to(device=self.device, dtype=dtype).repeat(x_noisy.shape[0], 1)
+        for i in range(len(self.cond_hint)):
+            if self.cond_hint[i] is not None:
+                if x_noisy.shape[0] != self.cond_hint[i].shape[0]:
+                    self.cond_hint[i] = broadcast_image_to_extend(self.cond_hint[i], x_noisy.shape[0], batched_number)
+        if self.cond_hint_types is not None:
+            self.cond_hint_types = broadcast_image_to_extend(self.cond_hint_types, x_noisy.shape[0], batched_number)
+
+        # prepare mask_cond_hint
+        self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
+
+        context = cond.get('crossattn_controlnet', cond['c_crossattn'])
+        y = cond.get('y', None)
+        if y is not None:
+            y = y.to(dtype)
+        timestep = self.model_sampling_current.timestep(t)
+        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
+
+        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, control_type=self.cond_hint_types)
+        return self.control_merge(control, control_prev, output_dtype)
+
+    def copy(self):
+        c = ControlNetPlusPlusAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        self.copy_to(c)
+        self.copy_to_advanced(c)
+        c.single_control_type = self.single_control_type
+        return c
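The bookkeeping in get_control_advanced reduces to a per-type strength vector plus a sparse list of resized hints. A sketch with assumed values (8-type model, latent batch of 4, 64x64 latents, compression_ratio 8), reusing PlusPlusType from this file:

import torch

strengths = torch.zeros(8)
strengths[PlusPlusType.to_idx(PlusPlusType.DEPTH)] = 1.0     # index 1
strengths[PlusPlusType.to_idx(PlusPlusType.THINLINE)] = 0.5  # index 3
cond_hint_types = strengths.unsqueeze(0).repeat(4, 1)  # (4, 8): one strength row per latent
# cond_hint would be a list of 8 slots where only indices 1 and 3 hold
# (4, 3, 512, 512) tensors upscaled to 8x the latent resolution; the rest stay None.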
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s) + k_out = "input_hint_block.{}{}".format(count * 2, s) + if k_in not in controlnet_data: + k_in = "controlnet_cond_embedding.conv_out{}".format(s) + loop = False + diffusers_keys[k_in] = k_out + count += 1 + + new_sd = {} + for k in diffusers_keys: + if k in controlnet_data: + new_sd[diffusers_keys[k]] = controlnet_data.pop(k) + + if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet + controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0] + for k in list(controlnet_data.keys()): + new_k = k.replace('.attn.in_proj_', '.attn.in_proj.') + new_sd[new_k] = controlnet_data.pop(k) + + leftover_keys = controlnet_data.keys() + if len(leftover_keys) > 0: + logger.warning("leftover ControlNet++ keys: {}".format(leftover_keys)) + controlnet_data = new_sd + elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format + raise Exception("Unexpected SD3 diffusers format for ControlNet++ model. Something is very wrong.") + + pth_key = 'control_model.zero_convs.0.0.weight' + pth = False + key = 'zero_convs.0.0.weight' + if pth_key in controlnet_data: + pth = True + key = pth_key + prefix = "control_model." + elif key in controlnet_data: + prefix = "" + else: + raise Exception("Unexpected T2IAdapter format for ControlNet++ model. Something is very wrong.") + + if controlnet_config is None: + model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True) + supported_inference_dtypes = model_config.supported_inference_dtypes + controlnet_config = model_config.unet_config + + load_device = comfy.model_management.get_torch_device() + if supported_inference_dtypes is None: + unet_dtype = comfy.model_management.unet_dtype() + else: + unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes) + + manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) + if manual_cast_dtype is not None: + controlnet_config["operations"] = comfy.ops.manual_cast + controlnet_config["dtype"] = unet_dtype + controlnet_config.pop("out_channels") + controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] + control_model = ControlNetPlusPlus(**controlnet_config) + + if pth: + if 'difference' in controlnet_data: + if model is not None: + comfy.model_management.load_models_gpu([model]) + model_sd = model.model_state_dict() + for x in controlnet_data: + c_m = "control_model." + if x.startswith(c_m): + sd_key = "diffusion_model.{}".format(x[len(c_m):]) + if sd_key in model_sd: + cd = controlnet_data[x] + cd += model_sd[sd_key].type(cd.dtype).to(cd.device) + else: + logger.warning("WARNING: Loaded a diff controlnet without a model. 
diff --git a/adv_control/nodes.py b/adv_control/nodes.py
index 66e7368..0355415 100644
--- a/adv_control/nodes.py
+++ b/adv_control/nodes.py
@@ -12,6 +12,7 @@
                              TimestepKeyframeNode, TimestepKeyframeInterpolationNode, TimestepKeyframeFromStrengthListNode)
 from .nodes_sparsectrl import SparseCtrlMergedLoaderAdvanced, SparseCtrlLoaderAdvanced, SparseIndexMethodNode, SparseSpreadMethodNode, RgbSparseCtrlPreprocessor, SparseWeightExtras
 from .nodes_reference import ReferenceControlNetNode, ReferenceControlFinetune, ReferencePreprocessorNode
+from .nodes_plusplus import PlusPlusLoaderAdvanced, PlusPlusLoaderSingle, PlusPlusInputNode
 from .nodes_loosecontrol import ControlNetLoaderWithLoraAdvanced
 from .nodes_deprecated import LoadImagesFromDirectory
 from .logger import logger
@@ -199,6 +200,10 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta
     "ACN_SparseCtrlIndexMethodNode": SparseIndexMethodNode,
     "ACN_SparseCtrlSpreadMethodNode": SparseSpreadMethodNode,
     "ACN_SparseCtrlWeightExtras": SparseWeightExtras,
+    # ControlNet++
+    "ACN_ControlNet++LoaderSingle": PlusPlusLoaderSingle,
+    "ACN_ControlNet++LoaderAdvanced": PlusPlusLoaderAdvanced,
+    "ACN_ControlNet++InputNode": PlusPlusInputNode,
     # Reference
     "ACN_ReferencePreprocessor": ReferencePreprocessorNode,
     "ACN_ReferenceControlNet": ReferenceControlNetNode,
@@ -238,6 +243,10 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta
     "ACN_SparseCtrlIndexMethodNode": "SparseCtrl Index Method 🛂🅐🅒🅝",
     "ACN_SparseCtrlSpreadMethodNode": "SparseCtrl Spread Method 🛂🅐🅒🅝",
     "ACN_SparseCtrlWeightExtras": "SparseCtrl Weight Extras 🛂🅐🅒🅝",
+    # ControlNet++
+    "ACN_ControlNet++LoaderSingle": "Load ControlNet++ Model (Single) 🛂🅐🅒🅝",
+    "ACN_ControlNet++LoaderAdvanced": "Load ControlNet++ Model (Multi) 🛂🅐🅒🅝",
+    "ACN_ControlNet++InputNode": "ControlNet++ Input 🛂🅐🅒🅝",
     # Reference
     "ACN_ReferencePreprocessor": "Reference Preproccessor 🛂🅐🅒🅝",
     "ACN_ReferenceControlNet": "Reference ControlNet 🛂🅐🅒🅝",
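For reference, ComfyUI discovers custom nodes through these two module-level dicts, and the entries added above follow that contract; sketched here with just one of the new nodes:

NODE_CLASS_MAPPINGS = {"ACN_ControlNet++InputNode": PlusPlusInputNode}
NODE_DISPLAY_NAME_MAPPINGS = {"ACN_ControlNet++InputNode": "ControlNet++ Input 🛂🅐🅒🅝"}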
diff --git a/adv_control/nodes_plusplus.py b/adv_control/nodes_plusplus.py
new file mode 100644
index 0000000..9fd0514
--- /dev/null
+++ b/adv_control/nodes_plusplus.py
@@ -0,0 +1,84 @@
+from torch import Tensor
+import math
+
+import folder_paths
+
+from .control_plusplus import load_controlnetplusplus, PlusPlusType, PlusPlusInput, PlusPlusInputGroup, PlusPlusImageWrapper
+from .utils import BIGMAX
+
+
+class PlusPlusLoaderAdvanced:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "plus_input": ("PLUS_INPUT", ),
+                "name": (folder_paths.get_filename_list("controlnet"), ),
+            }
+        }
+
+    RETURN_TYPES = ("CONTROL_NET", "IMAGE",)
+    FUNCTION = "load_controlnet_plusplus"
+
+    CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/ControlNet++"
+
+    def load_controlnet_plusplus(self, plus_input: PlusPlusInputGroup, name: str):
+        controlnet_path = folder_paths.get_full_path("controlnet", name)
+        controlnet = load_controlnetplusplus(controlnet_path)
+        controlnet.verify_control_type(name, plus_input)
+        return (controlnet, PlusPlusImageWrapper(plus_input),)
+
+
+class PlusPlusLoaderSingle:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "name": (folder_paths.get_filename_list("controlnet"), ),
+                "control_type": (PlusPlusType._LIST_WITH_NONE, {"default": PlusPlusType.NONE}, ),
+            }
+        }
+
+    RETURN_TYPES = ("CONTROL_NET",)
+    FUNCTION = "load_controlnet_plusplus"
+
+    CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/ControlNet++"
+
+    def load_controlnet_plusplus(self, name: str, control_type: str):
+        controlnet_path = folder_paths.get_full_path("controlnet", name)
+        controlnet = load_controlnetplusplus(controlnet_path)
+        controlnet.single_control_type = control_type
+        controlnet.verify_control_type(name)
+        return (controlnet,)
+
+
+class PlusPlusInputNode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "control_type": (PlusPlusType._LIST,),
+            },
+            "optional": {
+                "prev_plus_input": ("PLUS_INPUT",),
+                #"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": BIGMAX, "step": 0.01}),
+            }
+        }
+
+    RETURN_TYPES = ("PLUS_INPUT", )
+    FUNCTION = "wrap_images"
+
+    CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/ControlNet++"
+
+    def wrap_images(self, image: Tensor, control_type: str, strength=1.0, prev_plus_input: PlusPlusInputGroup=None):
+        if prev_plus_input is None:
+            prev_plus_input = PlusPlusInputGroup()
+        prev_plus_input = prev_plus_input.clone()
+
+        if math.isclose(strength, 0.0):
+            strength = 0.0000001
+        pp_input = PlusPlusInput(image, control_type, strength)
+        prev_plus_input.add(pp_input)
+
+        return (prev_plus_input,)
diff --git a/adv_control/utils.py b/adv_control/utils.py
index b8651f9..726233b 100644
--- a/adv_control/utils.py
+++ b/adv_control/utils.py
@@ -145,6 +145,7 @@ class ControlWeightType:
     UNIVERSAL = "universal"
     T2IADAPTER = "t2iadapter"
     CONTROLNET = "controlnet"
+    CONTROLNETPLUSPLUS = "controlnet++"
     CONTROLLORA = "controllora"
     CONTROLLLLITE = "controllllite"
     SVD_CONTROLNET = "svd_controlnet"
@@ -380,7 +381,7 @@ def default(cls, keyframe: TimestepKeyframe) -> 'TimestepKeyframeGroup':
 
 class AbstractPreprocWrapper:
     error_msg = "Invalid use of [InsertHere] output. The output of [InsertHere] preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
-    def __init__(self, condhint: Tensor):
+    def __init__(self, condhint):
         self.condhint = condhint
 
     def movedim(self, *args, **kwargs):
diff --git a/pyproject.toml b/pyproject.toml
index 4e3d8ab..4c102c5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 name = "comfyui-advanced-controlnet"
 description = "Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks."
-version = "1.1.0"
+version = "1.1.1"
 license = "LICENSE"
 dependencies = []
 
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index e69de29..0000000
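Taken together, the new nodes chain like this; a hypothetical wiring written as direct Python calls for illustration (in an actual workflow these are graph connections, and "controlnet-union.safetensors" is a placeholder filename that would need to exist in the controlnet folder):

import torch

group = PlusPlusInputNode().wrap_images(torch.zeros((1, 512, 512, 3)), PlusPlusType.DEPTH)[0]
group = PlusPlusInputNode().wrap_images(torch.zeros((1, 512, 512, 3)), PlusPlusType.THINLINE,
                                        prev_plus_input=group)[0]
control_net, image = PlusPlusLoaderAdvanced().load_controlnet_plusplus(group, "controlnet-union.safetensors")
# `image` is a PlusPlusImageWrapper: it must feed an Apply Advanced ControlNet node
# directly and cannot be used anywhere else that expects a normal IMAGE.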