Skip to content

Commit

Permalink
Merge PR #130 from Kosinkadink/develop - ControlNet++ support
Browse files Browse the repository at this point in the history
ControlNet++ support
  • Loading branch information
Kosinkadink committed Jul 15, 2024
2 parents 7a456aa + 36fdc79 commit 56000f3
Show file tree
Hide file tree
Showing 7 changed files with 583 additions and 2 deletions.
9 changes: 9 additions & 0 deletions adv_control/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,11 @@ def copy(self):

def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
# from pathlib import Path
# log_name = ckpt_path.split('\\')[-1]
# with open(Path(__file__).parent.parent.parent / rf"keys_{log_name}.txt", "w") as afile:
# for key, value in controlnet_data.items():
# afile.write(f"{key}:\t{value.shape}\n")
control = None
# check if a non-vanilla ControlNet
controlnet_type = ControlWeightType.DEFAULT
Expand All @@ -538,6 +543,10 @@ def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroup=None, mo
# SVD-ControlNet check
elif "temporal_res_block" in key:
has_temporal_res_block_key = True
# ControlNet++ check
elif "task_embedding" in key:
raise Exception("ControlNet++ model detected; must be loaded using the Load ControlNet++ Model nodes.")

if has_controlnet_key and has_motion_modules_key:
controlnet_type = ControlWeightType.SPARSECTRL
elif has_controlnet_key and has_temporal_res_block_key:
Expand Down
478 changes: 478 additions & 0 deletions adv_control/control_plusplus.py

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions adv_control/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
TimestepKeyframeNode, TimestepKeyframeInterpolationNode, TimestepKeyframeFromStrengthListNode)
from .nodes_sparsectrl import SparseCtrlMergedLoaderAdvanced, SparseCtrlLoaderAdvanced, SparseIndexMethodNode, SparseSpreadMethodNode, RgbSparseCtrlPreprocessor, SparseWeightExtras
from .nodes_reference import ReferenceControlNetNode, ReferenceControlFinetune, ReferencePreprocessorNode
from .nodes_plusplus import PlusPlusLoaderAdvanced, PlusPlusLoaderSingle, PlusPlusInputNode
from .nodes_loosecontrol import ControlNetLoaderWithLoraAdvanced
from .nodes_deprecated import LoadImagesFromDirectory
from .logger import logger
Expand Down Expand Up @@ -199,6 +200,10 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta
"ACN_SparseCtrlIndexMethodNode": SparseIndexMethodNode,
"ACN_SparseCtrlSpreadMethodNode": SparseSpreadMethodNode,
"ACN_SparseCtrlWeightExtras": SparseWeightExtras,
# ControlNet++
"ACN_ControlNet++LoaderSingle": PlusPlusLoaderSingle,
"ACN_ControlNet++LoaderAdvanced": PlusPlusLoaderAdvanced,
"ACN_ControlNet++InputNode": PlusPlusInputNode,
# Reference
"ACN_ReferencePreprocessor": ReferencePreprocessorNode,
"ACN_ReferenceControlNet": ReferenceControlNetNode,
Expand Down Expand Up @@ -238,6 +243,10 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta
"ACN_SparseCtrlIndexMethodNode": "SparseCtrl Index Method 🛂🅐🅒🅝",
"ACN_SparseCtrlSpreadMethodNode": "SparseCtrl Spread Method 🛂🅐🅒🅝",
"ACN_SparseCtrlWeightExtras": "SparseCtrl Weight Extras 🛂🅐🅒🅝",
# ControlNet++
"ACN_ControlNet++LoaderSingle": "Load ControlNet++ Model (Single) 🛂🅐🅒🅝",
"ACN_ControlNet++LoaderAdvanced": "Load ControlNet++ Model (Multi) 🛂🅐🅒🅝",
"ACN_ControlNet++InputNode": "ControlNet++ Input 🛂🅐🅒🅝",
# Reference
"ACN_ReferencePreprocessor": "Reference Preproccessor 🛂🅐🅒🅝",
"ACN_ReferenceControlNet": "Reference ControlNet 🛂🅐🅒🅝",
Expand Down
84 changes: 84 additions & 0 deletions adv_control/nodes_plusplus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from torch import Tensor
import math

import folder_paths

from .control_plusplus import load_controlnetplusplus, PlusPlusType, PlusPlusInput, PlusPlusInputGroup, PlusPlusImageWrapper
from .utils import BIGMAX


class PlusPlusLoaderAdvanced:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"plus_input": ("PLUS_INPUT", ),
"name": (folder_paths.get_filename_list("controlnet"), ),
}
}

RETURN_TYPES = ("CONTROL_NET", "IMAGE",)
FUNCTION = "load_controlnet_plusplus"

CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/ControlNet++"

def load_controlnet_plusplus(self, plus_input: PlusPlusInputGroup, name: str):
controlnet_path = folder_paths.get_full_path("controlnet", name)
controlnet = load_controlnetplusplus(controlnet_path)
controlnet.verify_control_type(name, plus_input)
return (controlnet, PlusPlusImageWrapper(plus_input),)


class PlusPlusLoaderSingle:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"name": (folder_paths.get_filename_list("controlnet"), ),
"control_type": (PlusPlusType._LIST_WITH_NONE, {"default": PlusPlusType.NONE}, ),
}
}

RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet_plusplus"

CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/ControlNet++"

def load_controlnet_plusplus(self, name: str, control_type: str):
controlnet_path = folder_paths.get_full_path("controlnet", name)
controlnet = load_controlnetplusplus(controlnet_path)
controlnet.single_control_type = control_type
controlnet.verify_control_type(name)
return (controlnet,)


class PlusPlusInputNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"control_type": (PlusPlusType._LIST,),
},
"optional": {
"prev_plus_input": ("PLUS_INPUT",),
#"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": BIGMAX, "step": 0.01}),
}
}

RETURN_TYPES = ("PLUS_INPUT", )
FUNCTION = "wrap_images"

CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/ControlNet++"

def wrap_images(self, image: Tensor, control_type: str, strength=1.0, prev_plus_input: PlusPlusInputGroup=None):
if prev_plus_input is None:
prev_plus_input = PlusPlusInputGroup()
prev_plus_input = prev_plus_input.clone()

if math.isclose(strength, 0.0):
strength = 0.0000001
pp_input = PlusPlusInput(image, control_type, strength)
prev_plus_input.add(pp_input)

return (prev_plus_input,)
3 changes: 2 additions & 1 deletion adv_control/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class ControlWeightType:
UNIVERSAL = "universal"
T2IADAPTER = "t2iadapter"
CONTROLNET = "controlnet"
CONTROLNETPLUSPLUS = "controlnet++"
CONTROLLORA = "controllora"
CONTROLLLLITE = "controllllite"
SVD_CONTROLNET = "svd_controlnet"
Expand Down Expand Up @@ -380,7 +381,7 @@ def default(cls, keyframe: TimestepKeyframe) -> 'TimestepKeyframeGroup':

class AbstractPreprocWrapper:
error_msg = "Invalid use of [InsertHere] output. The output of [InsertHere] preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
def __init__(self, condhint: Tensor):
def __init__(self, condhint):
self.condhint = condhint

def movedim(self, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui-advanced-controlnet"
description = "Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks."
version = "1.1.0"
version = "1.1.1"
license = "LICENSE"
dependencies = []

Expand Down
Empty file removed requirements.txt
Empty file.

0 comments on commit 56000f3

Please sign in to comment.