Commit
Merge PR #141 - ControlLLLite refactor + vanilla CN conversion w/ context_opts

ControlLLLite refactor + vanilla CN conversion when using sliding context
Kosinkadink committed Jul 20, 2024
2 parents d3c6ae0 + 07b8e3e commit a91a3ac
Showing 7 changed files with 547 additions and 312 deletions.
281 changes: 119 additions & 162 deletions adv_control/control.py
@@ -12,14 +12,17 @@
from comfy.model_patcher import ModelPatcher

from .control_sparsectrl import SparseModelPatcher, SparseControlNet, SparseCtrlMotionWrapper, SparseSettings, SparseConst
-from .control_lllite import LLLiteModule, LLLitePatch
+from .control_lllite import LLLiteModule, LLLitePatch, load_controllllite
from .control_svd import svd_unet_config_from_diffusers_unet, SVDControlNet, svd_unet_to_diffusers
from .utils import (AdvancedControlBase, TimestepKeyframeGroup, LatentKeyframeGroup, AbstractPreprocWrapper, ControlWeightType, ControlWeights, WeightTypeException,
manual_cast_clean_groupnorm, disable_weight_init_clean_groupnorm, prepare_mask_batch, get_properly_arranged_t2i_weights, load_torch_file_with_dict_factory,
broadcast_image_to_extend, extend_to_batch_size)
from .logger import logger


+ORIG_PREVIOUS_CONTROLNET = "_orig_previous_controlnet"


class ControlNetAdvanced(ControlNet, AdvancedControlBase):
def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, compression_ratio=compression_ratio, latent_format=latent_format, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
@@ -98,8 +101,10 @@ def copy(self):

@staticmethod
def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlNetAdvanced':
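        # (assumption, not stated in this diff) copy_to transfers state already set on the vanilla
        # object (e.g. cond_hint_original, strength, timestep_percent_range) so converted controlnets
        # keep their settings; the same pattern applies to the from_vanilla changes further down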
-        return ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
+        to_return = ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
                                       global_average_pooling=v.global_average_pooling, compression_ratio=v.compression_ratio, latent_format=v.latent_format, device=v.device, load_device=v.load_device, manual_cast_dtype=v.manual_cast_dtype)
+        v.copy_to(to_return)
+        return to_return


class T2IAdapterAdvanced(T2IAdapter, AdvancedControlBase):
@@ -166,8 +171,10 @@ def cleanup(self):

@staticmethod
def from_vanilla(v: T2IAdapter, timestep_keyframe: TimestepKeyframeGroup=None) -> 'T2IAdapterAdvanced':
-        return T2IAdapterAdvanced(t2i_model=v.t2i_model, timestep_keyframes=timestep_keyframe, channels_in=v.channels_in,
+        to_return = T2IAdapterAdvanced(t2i_model=v.t2i_model, timestep_keyframes=timestep_keyframe, channels_in=v.channels_in,
                                       compression_ratio=v.compression_ratio, upscale_algorithm=v.upscale_algorithm, device=v.device)
+        v.copy_to(to_return)
+        return to_return


class ControlLoraAdvanced(ControlLora, AdvancedControlBase):
@@ -194,8 +201,10 @@ def cleanup(self):

@staticmethod
def from_vanilla(v: ControlLora, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlLoraAdvanced':
-        return ControlLoraAdvanced(control_weights=v.control_weights, timestep_keyframes=timestep_keyframe,
+        to_return = ControlLoraAdvanced(control_weights=v.control_weights, timestep_keyframes=timestep_keyframe,
                                        global_average_pooling=v.global_average_pooling, device=v.device)
+        v.copy_to(to_return)
+        return to_return


class SVDControlNetAdvanced(ControlNetAdvanced):
@@ -408,115 +417,6 @@ def copy(self):
return c


class ControlLLLiteAdvanced(ControlBase, AdvancedControlBase):
# This ControlNet is more of an attention patch than a traditional controlnet
def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch, timestep_keyframes: TimestepKeyframeGroup, device=None):
super().__init__(device)
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite(), require_model=True)
self.patch_attn1 = patch_attn1.set_control(self)
self.patch_attn2 = patch_attn2.set_control(self)
self.latent_dims_div2 = None
self.latent_dims_div4 = None

def patch_model(self, model: ModelPatcher):
model.set_model_attn1_patch(self.patch_attn1)
model.set_model_attn2_patch(self.patch_attn2)

def set_cond_hint_inject(self, *args, **kwargs):
to_return = super().set_cond_hint_inject(*args, **kwargs)
# cond hint for LLLite needs to be scaled between (-1, 1) instead of (0, 1)
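        # e.g. hint values 0.0 / 0.5 / 1.0 map to -1.0 / 0.0 / 1.0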
self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0
return to_return

def pre_run_advanced(self, *args, **kwargs):
AdvancedControlBase.pre_run_advanced(self, *args, **kwargs)
#logger.error(f"in cn: {id(self.patch_attn1)},{id(self.patch_attn2)}")
self.patch_attn1.set_control(self)
self.patch_attn2.set_control(self)
#logger.warn(f"in pre_run_advanced: {id(self)}")

def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
# normal ControlNet stuff
control_prev = None
if self.previous_controlnet is not None:
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

if self.timestep_range is not None:
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
return control_prev

dtype = x_noisy.dtype
# prepare cond_hint
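        # (note) the hint is kept at pixel scale, i.e. 8x the latent height/width, hence the *8 checks and scaling below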
if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
            # if self.cond_hint_original length is greater than or equal to the real latent count, subdivide it before scaling
if self.sub_idxs is not None:
actual_cond_hint_orig = self.cond_hint_original
if self.cond_hint_original.size(0) < self.full_latent_length:
actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
else:
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to_extend(self.cond_hint, x_noisy.shape[0], batched_number)
# some special logic here compared to other controlnets:
        # * The cond_emb in attn patches will divide latent dims by 2 or 4 (integer division)
        # * Due to this rounding loss, the cond_emb will become smaller than the x input if latent dims are not divisible by 2 or 4
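        # (illustrative) e.g. latent h,w = (17, 25): div-2 dims round up to (18, 26) and div-4 dims to (20, 28),
        # presumably letting the attn patches re-fit cond_emb to x despite the rounding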
divisible_by_2_h = x_noisy.shape[2]%2==0
divisible_by_2_w = x_noisy.shape[3]%2==0
if not (divisible_by_2_h and divisible_by_2_w):
#logger.warn(f"{x_noisy.shape} not divisible by 2!")
new_h = (x_noisy.shape[2]//2)*2
new_w = (x_noisy.shape[3]//2)*2
if not divisible_by_2_h:
new_h += 2
if not divisible_by_2_w:
new_w += 2
self.latent_dims_div2 = (new_h, new_w)
divisible_by_4_h = x_noisy.shape[2]%4==0
divisible_by_4_w = x_noisy.shape[3]%4==0
if not (divisible_by_4_h and divisible_by_4_w):
#logger.warn(f"{x_noisy.shape} not divisible by 4!")
new_h = (x_noisy.shape[2]//4)*4
new_w = (x_noisy.shape[3]//4)*4
if not divisible_by_4_h:
new_h += 4
if not divisible_by_4_w:
new_w += 4
self.latent_dims_div4 = (new_h, new_w)
# prepare mask
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
# done preparing; model patches will take care of everything now.
# return normal controlnet stuff
return control_prev

def cleanup_advanced(self):
super().cleanup_advanced()
self.patch_attn1.cleanup()
self.patch_attn2.cleanup()
self.latent_dims_div2 = None
self.latent_dims_div4 = None

def copy(self):
c = ControlLLLiteAdvanced(self.patch_attn1, self.patch_attn2, self.timestep_keyframes)
self.copy_to(c)
self.copy_to_advanced(c)
return c

# deepcopy needs to properly keep track of objects to work between model.clone calls!
# def __deepcopy__(self, *args, **kwargs):
# self.cleanup_advanced()
# return self

# def get_models(self):
# # get_models is called once at the start of every KSampler run - use to reset already_patched status
# out = super().get_models()
# logger.error(f"in get_models! {id(self)}")
# return out


def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
# from pathlib import Path
@@ -591,6 +491,112 @@ def convert_to_advanced(control, timestep_keyframe: TimestepKeyframeGroup=None):
return control


def convert_all_to_advanced(conds: list[list[dict[str]]]) -> tuple[bool, list]:
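    # expected conds layout, inferred from the loops below (an assumption, not documented in this diff):
    #   conds = [cond_list, ...]; each cond_list holds sub_conds shaped like
    #   [cond_tensor, {"control": ControlBase, ...}]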
cache = {}
modified = False
new_conds = []
for cond in conds:
converted_cond = None
if cond is not None:
need_to_convert = False
# first, check if there is even a need to convert
for sub_cond in cond:
actual_cond = sub_cond[1]
if "control" in actual_cond:
if not are_all_advanced_controlnet(actual_cond["control"]):
need_to_convert = True
break
if not need_to_convert:
converted_cond = cond
else:
converted_cond = []
for sub_cond in cond:
new_sub_cond: list = []
for actual_cond in sub_cond:
if not type(actual_cond) == dict:
new_sub_cond.append(actual_cond)
continue
if "control" not in actual_cond:
new_sub_cond.append(actual_cond)
elif are_all_advanced_controlnet(actual_cond["control"]):
new_sub_cond.append(actual_cond)
else:
actual_cond = actual_cond.copy()
actual_cond["control"] = _convert_all_control_to_advanced(actual_cond["control"], cache)
new_sub_cond.append(actual_cond)
modified = True
converted_cond.append(new_sub_cond)
new_conds.append(converted_cond)
return modified, new_conds
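
# (illustrative usage, an assumption rather than code from this diff) a caller enabling sliding
# context might convert before sampling and restore afterwards:
#   modified, conds = convert_all_to_advanced([positive, negative])
#   ...  # sample with context_opts
#   if modified:
#       restore_all_controlnet_conns(conds)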


def _convert_all_control_to_advanced(input_object: ControlBase, cache: dict):
output_object = input_object
# iteratively convert to advanced, if needed
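    # e.g. (illustrative) for a chain c2 -> c1 linked via previous_controlnet: each vanilla cn is
    # replaced by an advanced copy, the follower is re-pointed at that copy, and the original link is
    # stashed under ORIG_PREVIOUS_CONTROLNET so restore_all_controlnet_conns can undo it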
next_cn = None
curr_cn = input_object
iter = 0
while curr_cn is not None:
if not is_advanced_controlnet(curr_cn):
# if already in cache, then conversion was done before, so just link it and exit
if curr_cn in cache:
new_cn = cache[curr_cn]
if next_cn is not None:
setattr(next_cn, ORIG_PREVIOUS_CONTROLNET, next_cn.previous_controlnet)
next_cn.previous_controlnet = new_cn
if iter == 0: # if was top-level controlnet, that's the new output
output_object = new_cn
break
try:
# convert to advanced, and assign previous_controlnet (convert doesn't transfer it)
new_cn = convert_to_advanced(curr_cn)
except Exception as e:
raise Exception("Failed to automatically convert a ControlNet to Advanced to support sliding window context.", e)
new_cn.previous_controlnet = curr_cn.previous_controlnet
if iter == 0: # if was top-level controlnet, that's the new output
output_object = new_cn
# if next_cn is present, then it needs to be pointed to new_cn
if next_cn is not None:
setattr(next_cn, ORIG_PREVIOUS_CONTROLNET, next_cn.previous_controlnet)
next_cn.previous_controlnet = new_cn
# add to cache
cache[curr_cn] = new_cn
curr_cn = new_cn
next_cn = curr_cn
curr_cn = curr_cn.previous_controlnet
iter += 1
return output_object


def restore_all_controlnet_conns(conds: list[list[dict[str]]]):
# if a cn has an _orig_previous_controlnet property, restore it and delete
for main_cond in conds:
if main_cond is not None:
for cond in main_cond:
if "control" in cond[1]:
_restore_all_controlnet_conns(cond[1]["control"])


def _restore_all_controlnet_conns(input_object: ControlBase):
# restore original previous_controlnet if needed
curr_cn = input_object
while curr_cn is not None:
if hasattr(curr_cn, ORIG_PREVIOUS_CONTROLNET):
curr_cn.previous_controlnet = getattr(curr_cn, ORIG_PREVIOUS_CONTROLNET)
delattr(curr_cn, ORIG_PREVIOUS_CONTROLNET)
curr_cn = curr_cn.previous_controlnet


def are_all_advanced_controlnet(input_object: ControlBase):
    # iteratively check if linked controlnet objects are all advanced
curr_cn = input_object
while curr_cn is not None:
if not is_advanced_controlnet(curr_cn):
return False
curr_cn = curr_cn.previous_controlnet
return True


def is_advanced_controlnet(input_object):
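    # duck-typed check: the Advanced classes mix in AdvancedControlBase, which is assumed to define sub_idxs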
return hasattr(input_object, "sub_idxs")

@@ -749,55 +755,6 @@ class WeightsLoader(torch.nn.Module):
return control


def load_controllllite(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None):
if controlnet_data is None:
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
# adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
# first, split weights for each module
module_weights = {}
for key, value in controlnet_data.items():
fragments = key.split(".")
module_name = fragments[0]
weight_name = ".".join(fragments[1:])
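        # e.g. a hypothetical key "moduleA.down.0.weight" yields module_name "moduleA"
        # and weight_name "down.0.weight"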

if module_name not in module_weights:
module_weights[module_name] = {}
module_weights[module_name][weight_name] = value

# next, load each module
modules = {}
for module_name, weights in module_weights.items():
# kohya planned to do something about how these should be chosen, so I'm not touching this
# since I am not familiar with the logic for this
if "conditioning1.4.weight" in weights:
depth = 3
elif weights["conditioning1.2.weight"].shape[-1] == 4:
depth = 2
else:
depth = 1

module = LLLiteModule(
name=module_name,
is_conv2d=weights["down.0.weight"].ndim == 4,
in_dim=weights["down.0.weight"].shape[1],
depth=depth,
cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
mlp_dim=weights["down.0.weight"].shape[0],
)
# load weights into module
module.load_state_dict(weights)
modules[module_name] = module
if len(modules) == 1:
module.is_first = True

#logger.info(f"loaded {ckpt_path} successfully, {len(modules)} modules")

patch_attn1 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN1)
patch_attn2 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN2)
control = ControlLLLiteAdvanced(patch_attn1=patch_attn1, patch_attn2=patch_attn2, timestep_keyframes=timestep_keyframe)
return control


def load_svdcontrolnet(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
if controlnet_data is None:
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)